Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-09-11 10:26:46 +03:00 committed by GitHub
commit 8eb7ca051a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 157 additions and 10 deletions

View File

@ -1101,3 +1101,45 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
if sigmas[i + 1] > 0: if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], temp[0])
dt = sigma_down - sigmas[i]
x = denoised + d * sigma_down
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x

View File

@ -570,7 +570,7 @@ class Sampler:
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"] "ipndm", "ipndm_v", "deis"]

View File

@ -183,17 +183,10 @@ class PreviewAudio(SaveAudio):
} }
class LoadAudio: class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory() input_dir = folder_paths.get_input_directory()
files = [ files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
return {"required": {"audio": (sorted(files), {"audio_upload": True})}} return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "audio" CATEGORY = "audio"

View File

@ -2,7 +2,9 @@ from __future__ import annotations
import os import os
import time import time
import mimetypes
import logging import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection from collections.abc import Collection
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@ -44,6 +46,10 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
extension_mimetypes_cache = {
"webp" : "image",
}
def map_legacy(folder_name: str) -> str: def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"} legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name) return legacy.get(folder_name, folder_name)
@ -89,6 +95,28 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory() return get_input_directory()
return None return None
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]
if content_type in content_types:
result.append(file)
return result
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir # otherwise use default_path as base_dir

View File

@ -12,6 +12,8 @@ import json
import glob import glob
import struct import struct
import ssl import ssl
import socket
import ipaddress
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
@ -80,6 +82,32 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware return cors_middleware
def is_loopback(host):
if host is None:
return False
try:
if ipaddress.ip_address(host).is_loopback:
return True
else:
return False
except:
pass
loopback = False
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
for family, _, _, _, sockaddr in r:
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
return loopback
else:
loopback = True
except socket.gaierror:
pass
return loopback
def create_origin_only_middleware(): def create_origin_only_middleware():
@web.middleware @web.middleware
async def origin_only_middleware(request: web.Request, handler): async def origin_only_middleware(request: web.Request, handler):
@ -93,12 +121,16 @@ def create_origin_only_middleware():
parsed = urllib.parse.urlparse(origin) parsed = urllib.parse.urlparse(origin)
origin_domain = parsed.netloc.lower() origin_domain = parsed.netloc.lower()
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain) host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
loopback = is_loopback(host_domain_parsed.hostname)
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
host_domain = host_domain_parsed.hostname host_domain = host_domain_parsed.hostname
if host_domain_parsed.port is None: if host_domain_parsed.port is None:
origin_domain = parsed.hostname origin_domain = parsed.hostname
if host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0: if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain: if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain)) logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403) return web.Response(status=403)

View File

View File

@ -0,0 +1,52 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}
@pytest.fixture(scope="module")
def mock_dir(file_extensions):
with tempfile.TemporaryDirectory() as directory:
for content_type, extensions in file_extensions.items():
for extension in extensions:
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
f.write(f"Sample {content_type} file in {extension} format")
yield directory
def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
for extension in extensions:
assert f"sample_{content_type}.{extension}" in filtered_files
def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
assert len(filtered_files) == len(extensions)
def test_handles_bad_extensions():
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_extension():
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []