@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
    '''
    SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
    Arxiv: https://arxiv.org/abs/2305.14267

    Each step with a non-zero target sigma evaluates the model twice: once at
    the current sigma and once at an intermediate sigma located a fraction
    ``r`` of the way through the step in ``-log(sigma)`` space.

    Args:
        model: callable invoked as ``model(x, sigma * s_in, **extra_args)``;
            its return value is used as the denoised prediction.
        x: latent tensor to be updated; ``x.shape[0]`` sizes the per-sample
            sigma vector ``s_in``.
        sigmas: indexable schedule of scalar tensors (``.log()`` is called on
            the entries). When ``sigmas[i + 1] == 0`` the step just returns
            the denoised prediction.
        extra_args: extra keyword arguments forwarded to ``model``; its
            ``"seed"`` entry (if any) seeds the default noise sampler.
        callback: optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step.
        disable: forwarded to ``trange`` as its ``disable`` argument.
        eta: stochasticity factor; it scales the step (``h * (eta + 1)``) and
            the noise coefficients.
        s_noise: multiplier applied to every injected noise term.
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x, seed=seed)``.
        r: relative position of the intermediate evaluation inside the step.

    Returns:
        The updated ``x`` after iterating over all ``len(sigmas) - 1`` steps.
    '''
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    # Vector of ones, one entry per batch element, used to broadcast a scalar sigma.
    s_in = x.new_ones([x.shape[0]])

    # Noise is only drawn and added when both knobs are strictly positive.
    inject_noise = eta > 0 and s_noise > 0

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final step: jump straight to the denoised prediction.
            x = denoised
        else:
            # Work in t = -log(sigma); h is the step size in that variable.
            t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
            h = t_next - t
            h_eta = h * (eta + 1)
            # Intermediate point of the step and its sigma.
            s = t + r * h
            fac = 1 / (2 * r)
            sigma_s = s.neg().exp()

            coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
            if inject_noise:
                noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
                noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
                # One noise draw per sub-interval: [sigma_i, sigma_s] and [sigma_s, sigma_{i+1}].
                noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])

            # Step 1
            x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
            if inject_noise:
                x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
            denoised_2 = model(x_2, sigma_s * s_in, **extra_args)

            # Step 2
            # Blend of the two model evaluations, weighted by fac = 1 / (2 * r).
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (coeff_2 + 1) * x - coeff_2 * denoised_d
            if inject_noise:
                x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
    return x


@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
    '''
    SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
    Arxiv: https://arxiv.org/abs/2305.14267

    Each step with a non-zero target sigma evaluates the model three times:
    at the current sigma and at two intermediate sigmas located fractions
    ``r_1`` and ``r_2`` of the way through the step in ``-log(sigma)`` space.

    Args:
        model: callable invoked as ``model(x, sigma * s_in, **extra_args)``;
            its return value is used as the denoised prediction.
        x: latent tensor to be updated; ``x.shape[0]`` sizes the per-sample
            sigma vector ``s_in``.
        sigmas: indexable schedule of scalar tensors (``.log()`` is called on
            the entries). When ``sigmas[i + 1] == 0`` the step just returns
            the denoised prediction.
        extra_args: extra keyword arguments forwarded to ``model``; its
            ``"seed"`` entry (if any) seeds the default noise sampler.
        callback: optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step.
        disable: forwarded to ``trange`` as its ``disable`` argument.
        eta: stochasticity factor; it scales the step (``h * (eta + 1)``) and
            the noise coefficients.
        s_noise: multiplier applied to every injected noise term.
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x, seed=seed)``.
        r_1: relative position of the first intermediate evaluation.
        r_2: relative position of the second intermediate evaluation.

    Returns:
        The updated ``x`` after iterating over all ``len(sigmas) - 1`` steps.
    '''
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    # Vector of ones, one entry per batch element, used to broadcast a scalar sigma.
    s_in = x.new_ones([x.shape[0]])

    # Noise is only drawn and added when both knobs are strictly positive.
    inject_noise = eta > 0 and s_noise > 0

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final step: jump straight to the denoised prediction.
            x = denoised
        else:
            # Work in t = -log(sigma); h is the step size in that variable.
            t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
            h = t_next - t
            h_eta = h * (eta + 1)
            # Two intermediate points of the step and their sigmas.
            s_1 = t + r_1 * h
            s_2 = t + r_2 * h
            sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()

            coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
            if inject_noise:
                noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
                noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
                noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
                # One noise draw per sub-interval of the step.
                noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])

            # Step 1
            x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
            if inject_noise:
                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2
            x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
            if inject_noise:
                x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
            denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)

            # Step 3
            x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
            if inject_noise:
                x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
    return x
# Optional dependency: the flag records whether the `cublas_ops` package
# could be imported, so the rest of the module can test it instead of
# re-attempting the import.
CUBLAS_IS_AVAILABLE = False
try:
    from cublas_ops import CublasLinear
    CUBLAS_IS_AVAILABLE = True
except ImportError:
    pass

# The operations class is only defined when its base class was importable.
if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        """Operations set whose Linear layer is backed by CublasLinear."""

        class Linear(CublasLinear, disable_weight_init.Linear):
            """Linear layer combining CublasLinear with disable_weight_init.Linear.

            CublasLinear is listed first, so ``super()`` calls below resolve to
            it before disable_weight_init.Linear (assuming it defines the
            method in question).
            """

            def reset_parameters(self):
                # Deliberately a no-op: this override performs no weight initialisation.
                return None

            def forward_comfy_cast_weights(self, input):
                # No explicit weight casting here; defer to the next forward() in the MRO.
                return super().forward(input)

            def forward(self, *args, **kwargs):
                # Always take the next forward() in the MRO, without any extra dispatch.
                return super().forward(*args, **kwargs)
import torch
import comfy.model_management
import numbers

RMSNorm = None

# Feature probe: older torch builds expose neither attribute. If either is
# missing, both names fall back to the pure-python implementations below.
try:
    rms_norm_torch = torch.nn.functional.rms_norm
    RMSNorm = torch.nn.RMSNorm
except AttributeError:
    rms_norm_torch = None


def rms_norm(x, weight=None, eps=1e-6):
    """Apply RMS normalisation to ``x``.

    Uses ``torch.nn.functional.rms_norm`` when it exists and the code is not
    being traced/scripted; otherwise normalises over the last dimension by
    hand.

    Args:
        x: input tensor.
        weight: optional scale; cast to ``x``'s dtype and device before use.
        eps: value added to the mean square for numerical stability. ``None``
            selects ``torch.finfo(x.dtype).eps``, matching the native function.

    Returns:
        The normalised (and, if ``weight`` is given, scaled) tensor.
    """
    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        if weight is None:
            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
        else:
            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
    else:
        # The native function accepts eps=None; mirror that here instead of
        # failing on `tensor + None`.
        if eps is None:
            eps = torch.finfo(x.dtype).eps
        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        if weight is None:
            return r
        else:
            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)


if RMSNorm is None:
    class RMSNorm(torch.nn.Module):
        """Stand-in for ``torch.nn.RMSNorm`` on torch builds that lack it.

        The constructor mirrors the native signature. ``forward`` delegates to
        :func:`rms_norm`, whose manual path normalises over the last dimension
        only; ``normalized_shape`` is used solely to size the weight.
        """

        def __init__(
            self,
            normalized_shape,
            eps=None,
            elementwise_affine=True,
            device=None,
            dtype=None,
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            if isinstance(normalized_shape, numbers.Integral):
                # mypy error: incompatible types in assignment
                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
            self.eps = eps
            self.elementwise_affine = elementwise_affine
            if self.elementwise_affine:
                # Left uninitialised (torch.empty); presumably filled when a
                # checkpoint is loaded -- TODO confirm.
                self.weight = torch.nn.Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
            else:
                self.register_parameter("weight", None)
            # Unlike the native module, this constructor never calls
            # reset_parameters(), so subclasses that set `bias` there would
            # end up without the attribute. RMSNorm has no bias; define it
            # explicitly so such code can read it safely.
            self.bias = None

        def forward(self, x):
            return rms_norm(x, self.weight, self.eps)
# from https://github.com/bebebe666/OptimalSteps


import numpy as np
import torch


def loglinear_interp(t_steps, num_steps):
    """
    Performs log-linear interpolation of a given array of decreasing numbers.

    Args:
        t_steps: array-like of positive values in decreasing order.
        num_steps: number of samples in the resampled schedule.

    Returns:
        A new numpy array of length ``num_steps``, in decreasing order, whose
        logarithms are linearly interpolated from those of ``t_steps``.
    """
    # np.interp needs increasing sample positions, so work on the reversed
    # (increasing) values and flip the result back at the end.
    log_levels = np.log(np.asarray(t_steps)[::-1])
    src_pos = np.linspace(0, 1, len(log_levels))
    dst_pos = np.linspace(0, 1, num_steps)
    return np.exp(np.interp(dst_pos, src_pos, log_levels))[::-1].copy()


# Reference noise schedules per model family, from high to low noise.
NOISE_LEVELS = {
    "FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
    "Wan": [1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
}


class OptimalStepsScheduler:
    """Scheduler node producing sigmas from the OptimalSteps reference tables."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model_type": (list(NOISE_LEVELS), ),  # stays in sync with the table above
                     "steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                     }
                }
    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, model_type, steps, denoise):
        """Build the sigma schedule for ``model_type``.

        Args:
            model_type: key of ``NOISE_LEVELS`` (``KeyError`` if unknown).
            steps: number of sampling steps of the full schedule.
            denoise: fraction of the schedule to keep, counted from its
                low-noise end; ``<= 0`` yields an empty schedule.

        Returns:
            One-element tuple holding a float32 tensor of
            ``round(steps * denoise) + 1`` sigmas (``steps + 1`` when
            ``denoise >= 1``) whose last entry is forced to 0.
        """
        if denoise <= 0.0:
            return (torch.tensor([], dtype=torch.float32),)
        total_steps = steps if denoise >= 1.0 else round(steps * denoise)

        levels = NOISE_LEVELS[model_type]
        if len(levels) == steps + 1:
            # Table already has the requested resolution; copy so the
            # in-place edit below never touches the module-level table.
            sigmas = np.array(levels, dtype=np.float64)
        else:
            sigmas = loglinear_interp(levels, steps + 1)

        # Keep only the tail of the schedule and make it end exactly at zero.
        sigmas = sigmas[-(total_steps + 1):]
        sigmas[-1] = 0
        return (torch.tensor(sigmas, dtype=torch.float32),)


NODE_CLASS_MAPPINGS = {
    "OptimalStepsScheduler": OptimalStepsScheduler,
}
-__version__ = "0.3.27" +__version__ = "0.3.28" diff --git a/folder_paths.py b/folder_paths.py index 72c70f594..9a525e5a1 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -85,6 +85,7 @@ cache_helper = CacheHelper() extension_mimetypes_cache = { "webp" : "image", + "fbx" : "model", } def map_legacy(folder_name: str) -> str: @@ -140,11 +141,14 @@ def get_directory_by_type(type_name: str) -> str | None: return get_input_directory() return None -def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]: +def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]: """ Example: files = os.listdir(folder_paths.get_input_directory()) - filter_files_content_types(files, ["image", "audio", "video"]) + videos = filter_files_content_types(files, ["video"]) + + Note: + - 'model' in MIME context refers to 3D models, not files containing trained weights and parameters """ global extension_mimetypes_cache result = [] diff --git a/main.py b/main.py index ff582ff22..295d0f4e2 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from app.logger import setup_logger import itertools import utils.extra_config import logging +import sys import comfyui_manager if __name__ == "__main__": @@ -310,6 +311,7 @@ def start_comfyui(asyncio_loop=None): if __name__ == "__main__": # Running directly, just start ComfyUI. 
+ logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) event_loop, _, start_all_func = start_comfyui() diff --git a/nodes.py b/nodes.py index 5b8077ab7..68505b952 100644 --- a/nodes.py +++ b/nodes.py @@ -1655,6 +1655,7 @@ class LoadImage: def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = folder_paths.filter_files_content_types(files, ["image"]) return {"required": {"image": (sorted(files), {"image_upload": True})}, } @@ -2284,6 +2285,7 @@ def init_builtin_extra_nodes(): "nodes_hunyuan3d.py", "nodes_primitive.py", "nodes_cfg.py", + "nodes_optimalsteps.py" ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index db9e776cd..6eb1704db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.27" +version = "0.3.28" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/tests-unit/folder_paths_test/filter_by_content_types_test.py b/tests-unit/folder_paths_test/filter_by_content_types_test.py index 423677a60..683f9fc11 100644 --- a/tests-unit/folder_paths_test/filter_by_content_types_test.py +++ b/tests-unit/folder_paths_test/filter_by_content_types_test.py @@ -1,14 +1,17 @@ import pytest import os import tempfile -from folder_paths import filter_files_content_types +from folder_paths import filter_files_content_types, extension_mimetypes_cache +from unittest.mock import patch + @pytest.fixture(scope="module") def file_extensions(): return { 'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'], 'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'], - 'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'] + 'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 
@pytest.fixture
def patched_mimetype_cache(file_extensions):
    """Patch folder_paths.extension_mimetypes_cache so model extensions map to "model".

    Works on a copy of the real cache, so the original dict is left untouched;
    the patch is active only while the requesting test runs.
    """
    # Mock model file extensions since they may not be in the test-runner system's mimetype cache
    new_cache = extension_mimetypes_cache.copy()
    for extension in file_extensions["model"]:
        new_cache[extension] = "model"

    # Yield inside the context manager so the patch is reverted on teardown.
    with patch("folder_paths.extension_mimetypes_cache", new_cache):
        yield