Merge branch 'master' of github.com:comfyanonymous/ComfyUI

commit ffb4ed9cf2
doctorpangloss, 2024-09-13 12:45:23 -07:00
80 changed files with 8866 additions and 7391 deletions

25
.github/workflows/test-unit.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
name: Unit Tests

on:
  push:
    branches: [ main, master ]
  pull_request:
    branches: [ main, master ]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install requirements
        run: |
          python -m pip install --upgrade pip
          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          pip install -r requirements.txt
      - name: Run Unit Tests
        run: |
          pip install -r tests-unit/requirements.txt
          python -m pytest tests-unit

1
.gitignore vendored
View File

@@ -15,6 +15,7 @@
.vscode/
.idea/
venv/
+.venv/
/web/extensions/*
!/web/extensions/logging.js.example
!/web/extensions/core/

View File

@@ -631,6 +631,8 @@ The default installation includes a fast latent preview method that's low-resolu
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
+| P | Pin/Unpin selected nodes |
+| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |

View File

@@ -4,16 +4,21 @@ import os
import re
import shutil
import uuid
+from urllib import parse
from aiohttp import web
from .app_settings import AppSettings
from ..cli_args import args
-from ..cmd.folder_paths import user_directory
+from ..cmd import folder_paths
+default_user = "default"
class UserManager():
    def __init__(self):
+        user_directory = folder_paths.get_user_directory()
        self.default_user = "default"
        self.users_file = os.path.join(user_directory, "users.json")
        self.settings = AppSettings(self)
@@ -21,14 +26,17 @@ class UserManager():
            os.mkdir(user_directory)
        if args.multi_user:
-            if os.path.isfile(self.users_file):
-                with open(self.users_file) as f:
+            if os.path.isfile(self.get_users_file()):
+                with open(self.get_users_file()) as f:
                    self.users = json.load(f)
            else:
                self.users = {}
        else:
            self.users = {"default": "default"}

+    def get_users_file(self):
+        return os.path.join(folder_paths.get_user_directory(), "users.json")

    def get_request_user_id(self, request):
        user = "default"
        if args.multi_user and "comfy-user" in request.headers:
@@ -40,6 +48,7 @@ class UserManager():
        return user

    def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
+        user_directory = folder_paths.get_user_directory()
        if type == "userdata":
            root_dir = user_directory
@@ -54,6 +63,10 @@
            raise PermissionError()

        if file is not None:
+            # Check if filename is url encoded
+            if "%" in file:
+                file = parse.unquote(file)

            # prevent leaving /{type}/{user}
            path = os.path.abspath(os.path.join(user_root, file))
            if os.path.commonpath((user_root, path)) != user_root:
@@ -75,7 +88,7 @@
            self.users[user_id] = name

-            with open(self.users_file, "w") as f:
+            with open(self.get_users_file(), "w") as f:
                json.dump(self.users, f)

            return user_id
@@ -108,23 +121,40 @@ class UserManager():
        async def listuserdata(request):
            directory = request.rel_url.query.get('dir', '')
            if not directory:
-                return web.Response(status=400)
+                return web.Response(status=400, text="Directory not provided")

            path = self.get_request_user_filepath(request, directory)
            if not path:
-                return web.Response(status=403)
+                return web.Response(status=403, text="Invalid directory")

            if not os.path.exists(path):
-                return web.Response(status=404)
+                return web.Response(status=404, text="Directory not found")

            recurse = request.rel_url.query.get('recurse', '').lower() == "true"
-            results = glob.glob(os.path.join(
-                glob.escape(path), '**/*'), recursive=recurse)
-            results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
+            full_info = request.rel_url.query.get('full_info', '').lower() == "true"
+
+            # Use different patterns based on whether we're recursing or not
+            if recurse:
+                pattern = os.path.join(glob.escape(path), '**', '*')
+            else:
+                pattern = os.path.join(glob.escape(path), '*')
+            results = glob.glob(pattern, recursive=recurse)
+
+            if full_info:
+                results = [
+                    {
+                        'path': os.path.relpath(x, path).replace(os.sep, '/'),
+                        'size': os.path.getsize(x),
+                        'modified': os.path.getmtime(x)
+                    } for x in results if os.path.isfile(x)
+                ]
+            else:
+                results = [
+                    os.path.relpath(x, path).replace(os.sep, '/')
+                    for x in results
+                    if os.path.isfile(x)
+                ]

            split_path = request.rel_url.query.get('split', '').lower() == "true"
-            if split_path:
-                results = [[x] + x.split(os.sep) for x in results]
+            if split_path and not full_info:
+                results = [[x] + x.split('/') for x in results]

            return web.json_response(results)
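
For reference, a minimal client-side sketch of the new query parameters handled above; the /userdata route path and port are assumptions, since the route registration is not part of this hunk.

    import json
    from urllib import parse, request

    # Assumed endpoint; adjust host/port/path to your deployment.
    base = "http://127.0.0.1:8188/userdata"
    query = parse.urlencode({
        "dir": "workflows",    # directory under the requesting user's data root
        "recurse": "true",     # walk subdirectories
        "full_info": "true",   # return dicts with path/size/modified instead of bare paths
    })
    with request.urlopen(f"{base}?{query}") as resp:
        for entry in json.load(resp):
            print(entry["path"], entry["size"], entry["modified"])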

View File

@@ -71,6 +71,7 @@ class CacheKeySetInputSignature(CacheKeySet):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.is_changed_cache = is_changed_cache
+        self.immediate_node_signature = {}
        self.add_keys(node_ids)

    def include_node_id_in_input(self) -> bool:
@@ -98,11 +99,13 @@ class CacheKeySetInputSignature(CacheKeySet):
        if not dynprompt.has_node(node_id):
            # This node doesn't exist -- we can't cache it.
            return [float("NaN")]
+        if node_id in self.immediate_node_signature:  # reduce repeated calls of ancestors
+            return self.immediate_node_signature[node_id]
        node = dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, self.is_changed_cache.get(node_id)]
-        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
+        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values():
            signature.append(node_id)
        inputs = node["inputs"]
        for key in sorted(inputs.keys()):
@@ -112,6 +115,7 @@ class CacheKeySetInputSignature(CacheKeySet):
                signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
            else:
                signature.append((key, inputs[key]))
+        self.immediate_node_signature[node_id] = signature
        return signature

# This function returns a list of all ancestors of the given node. The order of the list is

View File

@@ -215,6 +215,8 @@ def _create_parser() -> EnhancedConfigArgParser:
        default=None
    )

+    parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
+
    # now give plugins a chance to add configuration
    for entry_point in entry_points().select(group='comfyui.custom_config'):
        try:

View File

@@ -112,6 +112,7 @@ class Configuration(dict):
        executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
        preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
        openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes
+        user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path.
    """

    def __init__(self, **kwargs):
@@ -200,6 +201,7 @@ class Configuration(dict):
        self.executor_factory: str = "ThreadPoolExecutor"
        self.openai_api_key: Optional[str] = None
+        self.user_directory: Optional[str] = None

    def __getattr__(self, item):
        if item not in self:

View File

@@ -229,7 +229,13 @@ def merge_result_data(results, obj):
    # merge node execution results
    for i, is_list in zip(range(len(results[0])), output_is_list):
        if is_list:
-            output.append([x for o in results for x in o[i]])
+            value = []
+            for o in results:
+                if isinstance(o[i], ExecutionBlocker):
+                    value.append(o[i])
+                else:
+                    value.extend(o[i])
+            output.append(value)
        else:
            output.append([o[i] for o in results])
    return output
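
To illustrate the behavioral change, a small stand-alone sketch; ExecutionBlocker here is a stand-in class, not the project's own import.

    class ExecutionBlocker:  # stand-in for illustration only
        pass

    results = [([1, 2],), (ExecutionBlocker(),)]  # second node's list output was blocked
    value = []
    for o in results:
        if isinstance(o[0], ExecutionBlocker):
            value.append(o[0])   # keep the blocker as a single, opaque element
        else:
            value.extend(o[0])   # flatten ordinary list outputs as before
    print(value)  # [1, 2, <ExecutionBlocker>]; the old comprehension would have tried to iterate the blocker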

View File

@@ -1,17 +1,20 @@
from __future__ import annotations

import logging
+import mimetypes
import os
import time
-from typing import Optional, List, Final
+from typing import Optional, List, Final, Literal

from .folder_paths_pre import get_base_path
from ..component_model.files import get_package_as_path
from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse
+from ..component_model.folder_path_types import extension_mimetypes_cache as _extension_mimetypes_cache
from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions
from ..component_model.module_property import module_property

supported_pt_extensions: Final[frozenset[str]] = _supported_pt_extensions
+extension_mimetypes_cache: Final[dict[str, str]] = _extension_mimetypes_cache

# todo: this needs to be wrapped in a context and configurable
# todo: this needs to be wrapped in a context and configurable # todo: this needs to be wrapped in a context and configurable
@@ -87,6 +90,15 @@ def get_input_directory():
    return input_directory

+
+def get_user_directory() -> str:
+    return user_directory
+
+
+def set_user_directory(user_dir: str) -> None:
+    global user_directory
+    user_directory = user_dir
+
# NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
    if type_name == "output":
@@ -277,18 +289,25 @@ def get_filename_list(folder_name):

def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
-    def map_filename(filename):
+    def map_filename(filename: str) -> tuple[int, str]:
        prefix_len = len(os.path.basename(filename_prefix))
        prefix = filename[:prefix_len + 1]
        try:
            digits = int(filename[prefix_len + 1:].split('_')[0])
        except:
            digits = 0
-        return (digits, prefix)
+        return digits, prefix

-    def compute_vars(input, image_width, image_height):
+    def compute_vars(input: str, image_width: int, image_height: int) -> str:
        input = input.replace("%width%", str(image_width))
        input = input.replace("%height%", str(image_height))
+        now = time.localtime()
+        input = input.replace("%year%", str(now.tm_year))
+        input = input.replace("%month%", str(now.tm_mon).zfill(2))
+        input = input.replace("%day%", str(now.tm_mday).zfill(2))
+        input = input.replace("%hour%", str(now.tm_hour).zfill(2))
+        input = input.replace("%minute%", str(now.tm_min).zfill(2))
+        input = input.replace("%second%", str(now.tm_sec).zfill(2))
        return input

    filename_prefix = compute_vars(filename_prefix, image_width, image_height)
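
A quick illustration of the new date tokens; the prefix and printed value are made-up examples.

    import time

    def expand_date_tokens(prefix: str) -> str:
        # Mirrors the substitutions added to compute_vars above, for illustration only.
        now = time.localtime()
        replacements = {
            "%year%": str(now.tm_year),
            "%month%": str(now.tm_mon).zfill(2),
            "%day%": str(now.tm_mday).zfill(2),
            "%hour%": str(now.tm_hour).zfill(2),
            "%minute%": str(now.tm_min).zfill(2),
            "%second%": str(now.tm_sec).zfill(2),
        }
        for token, value in replacements.items():
            prefix = prefix.replace(token, value)
        return prefix

    # e.g. "ComfyUI_%year%-%month%-%day%" -> "ComfyUI_2024-09-13"
    print(expand_date_tokens("ComfyUI_%year%-%month%-%day%"))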
@@ -328,3 +347,27 @@ def create_directories():

def invalidate_cache(folder_name):
    global _filename_list_cache
    _filename_list_cache.pop(folder_name, None)
+
+
+def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
+    """
+    Example:
+        files = os.listdir(folder_paths.get_input_directory())
+        filter_files_content_types(files, ["image", "audio", "video"])
+    """
+    global extension_mimetypes_cache
+    result = []
+    for file in files:
+        extension = file.split('.')[-1]
+        if extension not in extension_mimetypes_cache:
+            mime_type, _ = mimetypes.guess_type(file, strict=False)
+            if not mime_type:
+                continue
+            content_type = mime_type.split('/')[0]
+            extension_mimetypes_cache[extension] = content_type
+        else:
+            content_type = extension_mimetypes_cache[extension]
+        if content_type in content_types:
+            result.append(file)
+    return result

View File

@@ -125,6 +125,11 @@ async def main(from_script_dir: Optional[Path] = None):
    folder_paths.set_temp_directory(temp_dir)
    cleanup_temp()

+    if args.user_directory:
+        user_dir = os.path.abspath(args.user_directory)
+        logging.info(f"Setting user directory to: {user_dir}")
+        folder_paths.set_user_directory(user_dir)
+
    # configure extra model paths earlier
    try:
        extra_model_paths_config_path = os.path.join(os_getcwd, "extra_model_paths.yaml")
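
A minimal sketch of what the flag changes at runtime; the comfy.cmd.folder_paths module path is inferred from the relative imports elsewhere in this diff, and the launch command is only an assumption about how this fork is normally started.

    # Assumed launch:  python -m comfy.cmd.main --user-directory /absolute/path/to/user
    import os
    from comfy.cmd import folder_paths  # inferred from "from ..cmd import folder_paths" above

    folder_paths.set_user_directory(os.path.abspath("/absolute/path/to/user"))
    print(folder_paths.get_user_directory())  # users.json and per-user settings now resolve here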

View File

@@ -2,13 +2,16 @@ from __future__ import annotations

import asyncio
import glob
+import ipaddress
import json
import logging
import mimetypes
import os
+import socket
import struct
import sys
import traceback
+import urllib
import uuid

from asyncio import Future, AbstractEventLoop
from enum import Enum
@@ -102,6 +105,69 @@ def create_cors_middleware(allowed_origin: str):
    return cors_middleware

+
+def is_loopback(host):
+    if host is None:
+        return False
+    try:
+        if ipaddress.ip_address(host).is_loopback:
+            return True
+        else:
+            return False
+    except:
+        pass
+
+    loopback = False
+    for family in (socket.AF_INET, socket.AF_INET6):
+        try:
+            r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
+            for family, _, _, _, sockaddr in r:
+                if not ipaddress.ip_address(sockaddr[0]).is_loopback:
+                    return loopback
+                else:
+                    loopback = True
+        except socket.gaierror:
+            pass
+
+    return loopback
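
As a quick orientation, the literal-address branch of the check above run on example values; hostnames such as "localhost" fall through to the getaddrinfo() loop instead.

    import ipaddress

    for candidate in ("127.0.0.1", "::1", "192.168.1.20"):
        print(candidate, ipaddress.ip_address(candidate).is_loopback)
    # 127.0.0.1 True, ::1 True, 192.168.1.20 False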
+
+def create_origin_only_middleware():
+    @web.middleware
+    async def origin_only_middleware(request: web.Request, handler):
+        # this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
+        # in that case the Host and Origin hostnames won't match
+        # I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
+        if 'Host' in request.headers and 'Origin' in request.headers:
+            host = request.headers['Host']
+            origin = request.headers['Origin']
+            host_domain = host.lower()
+            parsed = urllib.parse.urlparse(origin)
+            origin_domain = parsed.netloc.lower()
+            host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
+
+            # limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
+            loopback = is_loopback(host_domain_parsed.hostname)
+
+            if parsed.port is None:  # if origin doesn't have a port strip it from the host to handle weird browsers, same for host
+                host_domain = host_domain_parsed.hostname
+            if host_domain_parsed.port is None:
+                origin_domain = parsed.hostname
+
+            if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
+                if host_domain != origin_domain:
+                    logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
+                    return web.Response(status=403)
+
+        if request.method == "OPTIONS":
+            response = web.Response()
+        else:
+            response = await handler(request)
+        return response
+
+    return origin_only_middleware
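
A worked example of the Host/Origin comparison performed above; the header values are made up.

    from urllib import parse

    host = "127.0.0.1:8188"              # request Host header (example value)
    origin = "https://attacker.example"  # request Origin header (example value)

    host_parsed = parse.urlsplit("//" + host.lower())
    origin_parsed = parse.urlparse(origin.lower())

    # The origin has no explicit port, so the host is reduced to its bare hostname before comparing:
    print(host_parsed.hostname, origin_parsed.netloc)  # 127.0.0.1 vs attacker.example -> mismatch -> 403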
class PromptServer(ExecutorToClientProgress):
    instance: Optional['PromptServer'] = None
@@ -129,6 +195,8 @@ class PromptServer(ExecutorToClientProgress):
        middlewares = [cache_control]
        if args.enable_cors_header:
            middlewares.append(create_cors_middleware(args.enable_cors_header))
+        else:
+            middlewares.append(create_origin_only_middleware())

        max_upload_size = round(args.max_upload_size * 1024 * 1024)
        self.app: web.Application = web.Application(client_max_size=max_upload_size,

View File

@@ -5,6 +5,9 @@ import os
from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple

supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
+extension_mimetypes_cache = {
+    "webp": "image",
+}


@dataclasses.dataclass

View File

@@ -513,7 +513,9 @@ def load_controlnet_flux_instantx(sd):
    if union_cnet in new_sd:
        num_union_modes = new_sd[union_cnet].shape[0]

-    control_model = controlnet_flux.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
+    control_model = controlnet_flux.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
    control_model = controlnet_load_state_dict(control_model, new_sd)

    latent_format = latent_formats.Flux()

View File

@@ -93,30 +93,44 @@ class TopologicalSort:
        self.add_strong_link(from_node_id, from_socket, to_node_id)

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        self.add_node(from_node_id)
-        if to_node_id not in self.blocking[from_node_id]:
-            self.blocking[from_node_id][to_node_id] = {}
-            self.blockCount[to_node_id] += 1
-        self.blocking[from_node_id][to_node_id][from_socket] = True
+        if not self.is_cached(from_node_id):
+            self.add_node(from_node_id)
+            if to_node_id not in self.blocking[from_node_id]:
+                self.blocking[from_node_id][to_node_id] = {}
+                self.blockCount[to_node_id] += 1
+            self.blocking[from_node_id][to_node_id][from_socket] = True

-    def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
-        if unique_id in self.pendingNodes:
-            return
-        self.pendingNodes[unique_id] = True
-        self.blockCount[unique_id] = 0
-        self.blocking[unique_id] = {}
+    def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
+        node_ids = [node_unique_id]
+        links = []

-        inputs = self.dynprompt.get_node(unique_id)["inputs"]
-        for input_name in inputs:
-            value = inputs[input_name]
-            if is_link(value):
-                from_node_id, from_socket = value
-                if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
-                    continue
-                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
-                is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
-                if include_lazy or not is_lazy:
-                    self.add_strong_link(from_node_id, from_socket, unique_id)
+        while len(node_ids) > 0:
+            unique_id = node_ids.pop()
+            if unique_id in self.pendingNodes:
+                continue
+
+            self.pendingNodes[unique_id] = True
+            self.blockCount[unique_id] = 0
+            self.blocking[unique_id] = {}
+
+            inputs = self.dynprompt.get_node(unique_id)["inputs"]
+            for input_name in inputs:
+                value = inputs[input_name]
+                if is_link(value):
+                    from_node_id, from_socket = value
+                    if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
+                        continue
+                    input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
+                    is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
+                    if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
+                        node_ids.append(from_node_id)
+                        links.append((from_node_id, from_socket, unique_id))
+
+        for link in links:
+            self.add_strong_link(*link)
+
+    def is_cached(self, node_id):
+        return False

    def get_ready_nodes(self):
        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
@@ -142,11 +156,8 @@ class ExecutionList(TopologicalSort):
        self.output_cache = output_cache
        self.staged_node_id = None

-    def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        if self.output_cache.get(from_node_id) is not None:
-            # Nothing to do
-            return
-        super().add_strong_link(from_node_id, from_socket, to_node_id)
+    def is_cached(self, node_id):
+        return self.output_cache.get(node_id) is not None

    def stage_node_execution(self):
        assert self.staged_node_id is None
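
The add_node rewrite above replaces recursion with an explicit worklist, so long upstream chains no longer risk Python's recursion limit; a generic sketch of that pattern (names are illustrative, not this codebase's API).

    def expand(start, upstream_of, is_cached):
        # Collect reachable, uncached producers of `start` without recursing.
        pending, links, stack = set(), [], [start]
        while stack:
            node = stack.pop()
            if node in pending:
                continue
            pending.add(node)
            for producer, socket in upstream_of(node):
                if not is_cached(producer):
                    stack.append(producer)
                    links.append((producer, socket, node))
        return pending, links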

View File

@@ -1126,3 +1126,45 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
+
+
+@torch.no_grad()
+def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+
+    temp = [0]
+    def post_cfg_function(args):
+        temp[0] = args["uncond_denoised"]
+        return args["denoised"]
+
+    model_options = extra_args.get("model_options", {}).copy()
+    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+
+    s_in = x.new_ones([x.shape[0]])
+    sigma_fn = lambda t: t.neg().exp()
+    t_fn = lambda sigma: sigma.log().neg()
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigma_down == 0:
+            # Euler method
+            d = to_d(x, sigmas[i], temp[0])
+            dt = sigma_down - sigmas[i]
+            x = denoised + d * sigma_down
+        else:
+            # DPM-Solver++(2S)
+            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
+            # r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
+            r = 1 / 2
+            h = t_next - t
+            s = t + r * h
+            x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
+            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+            x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
+        # Noise addition
+        if sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+    return x

View File

@@ -16,7 +16,7 @@ except:
    rms_norm_torch = None


def rms_norm(x, weight, eps=1e-6):
-    if rms_norm_torch is not None:
+    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        return rms_norm_torch(x, weight.shape, weight=ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
    else:
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)

View File

@@ -53,7 +53,7 @@ class MistolineControlnetBlock(nn.Module):

class ControlNetFlux(Flux):
-    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        self.main_model_double = 19
@@ -81,7 +81,12 @@ class ControlNetFlux(Flux):
        self.gradient_checkpointing = False
        self.latent_input = latent_input
-        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        if control_latent_channels is None:
+            control_latent_channels = self.in_channels
+        else:
+            control_latent_channels *= 2 * 2 #patch size
+
+        self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        if not self.latent_input:
            if self.mistoline:
                self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
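
A worked example of the channel arithmetic above; the weight width is illustrative.

    pos_embed_weight_cols = 128                              # new_sd["pos_embed_input.weight"].shape[1] (example)
    control_latent_channels = pos_embed_weight_cols // 4     # 32 latent channels, as computed in the loader
    linear_in_features = control_latent_channels * 2 * 2     # back to 128 after the 2x2 patch expansion
    print(control_latent_channels, linear_in_features)       # 32 128, matching the checkpoint's Linear input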

View File

@@ -199,9 +199,13 @@ def load_lora(lora, to_load):

def model_lora_keys_clip(model, key_map={}):
    sdk = model.state_dict().keys()
+    for k in sdk:
+        if k.endswith(".weight"):
+            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
+    clip_g_present = False
    for b in range(32): # TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@@ -225,6 +229,7 @@ def model_lora_keys_clip(model, key_map={}):
            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
+                clip_g_present = True
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # SDXL base
                    key_map[lora_key] = k
@@ -240,10 +245,18 @@ def model_lora_keys_clip(model, key_map={}):
    for k in sdk:
        if k.endswith(".weight"):
-            if k.startswith("t5xxl.transformer."): # OneTrainer SD3 lora
+            if k.startswith("t5xxl.transformer."): # OneTrainer SD3 and Flux lora
                l_key = k[len("t5xxl.transformer."):-len(".weight")]
-                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
-                key_map[lora_key] = k
+                t5_index = 1
+                if clip_g_present:
+                    t5_index += 1
+                if clip_l_present:
+                    t5_index += 1
+                if t5_index == 2:
+                    key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
+                    t5_index += 1
+                key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
            elif k.startswith("hydit_clip.transformer.bert."): # HunyuanDiT Lora
                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))

View File

@@ -30,7 +30,7 @@ from . import utils
from .float import stochastic_rounding
from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements
-from .types import UnetWrapperFunction
+from .comfy_types import UnetWrapperFunction


def string_to_seed(data):
    crc = 0xFFFFFFFF

View File

@@ -504,6 +504,7 @@ class CheckpointLoader:
    FUNCTION = "load_checkpoint"

    CATEGORY = "advanced/loaders"
+    DEPRECATED = True

    def load_checkpoint(self, config_name, ckpt_name):
        config_path = folder_paths.get_full_path("configs", config_name)

View File

@@ -5,7 +5,7 @@ import collections
from . import model_management
import math
import logging
-import scipy
+import scipy.stats
import numpy
from . import sampler_helpers
@@ -573,7 +573,7 @@ class Sampler:
        return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma

KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
-                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
+                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis"]

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,6 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { C as ComfyDialog, $ as $el, a as ComfyApp, b as app, L as LGraphCanvas, c as LiteGraph, d as LGraphNode, e as applyTextReplacements, f as ComfyWidgets, g as addValueControlWidgets, D as DraggableList, h as api, u as useToastStore, i as LGraphGroup } from "./index-CI3N807S.js";
+import { C as ComfyDialog, $ as $el, a as ComfyApp, b as app, L as LGraphCanvas, c as LiteGraph, d as LGraphNode, e as applyTextReplacements, f as ComfyWidgets, g as addValueControlWidgets, D as DraggableList, h as api, i as LGraphGroup, u as useToastStore } from "./index-Dfv2aLsq.js";
class ClipspaceDialog extends ComfyDialog {
  static {
    __name(this, "ClipspaceDialog");
@@ -3650,7 +3650,7 @@ app.registerExtension({
      content: "Add Group For Selected Nodes",
      disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
      callback: /* @__PURE__ */ __name(() => {
-        var group2 = new LiteGraph.LGraphGroup();
+        const group2 = new LGraphGroup();
        addNodesToGroup(group2, this.selected_nodes);
        app.canvas.graph.add(group2);
        this.graph.change();
@@ -5088,7 +5088,7 @@ app.registerExtension({
      data = JSON.parse(data);
      const nodeIds = Object.keys(app.canvas.selected_nodes);
      for (let i = 0; i < nodeIds.length; i++) {
-        const node = app.graph.getNodeById(Number.parseInt(nodeIds[i]));
+        const node = app.graph.getNodeById(nodeIds[i]);
        const nodeData = node?.constructor.nodeData;
        let groupData = GroupNodeHandler.getGroupData(node);
        if (groupData) {
@@ -5955,7 +5955,7 @@ app.registerExtension({
  },
  onNodeOutputsUpdated(nodeOutputs) {
    for (const [nodeId, output] of Object.entries(nodeOutputs)) {
-      const node = app.graph.getNodeById(Number.parseInt(nodeId));
+      const node = app.graph.getNodeById(nodeId);
      if ("audio" in output) {
        const audioUIWidget = node.widgets.find(
          (w) => w.name === "audioUI"
@@ -6026,4 +6026,4 @@ app.registerExtension({
  };
}
});
-//# sourceMappingURL=index-BD-Ia1C4.js.map
+//# sourceMappingURL=index-CrROdkG4.js.map

1
comfy/web/assets/index-CrROdkG4.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1475,21 +1475,21 @@
width: 5rem !important; width: 5rem !important;
} }
.info-chip[data-v-25bd5f50] { .info-chip[data-v-ffbfdf57] {
background: transparent; background: transparent;
} }
.setting-item[data-v-25bd5f50] { .setting-item[data-v-ffbfdf57] {
display: flex; display: flex;
justify-content: space-between; justify-content: space-between;
align-items: center; align-items: center;
margin-bottom: 1rem; margin-bottom: 1rem;
} }
.setting-label[data-v-25bd5f50] { .setting-label[data-v-ffbfdf57] {
display: flex; display: flex;
align-items: center; align-items: center;
flex: 1; flex: 1;
} }
.setting-input[data-v-25bd5f50] { .setting-input[data-v-ffbfdf57] {
flex: 1; flex: 1;
display: flex; display: flex;
justify-content: flex-end; justify-content: flex-end;
@ -1497,19 +1497,19 @@
} }
/* Ensure PrimeVue components take full width of their container */ /* Ensure PrimeVue components take full width of their container */
.setting-input[data-v-25bd5f50] .p-inputtext, .setting-input[data-v-ffbfdf57] .p-inputtext,
.setting-input[data-v-25bd5f50] .input-slider, .setting-input[data-v-ffbfdf57] .input-slider,
.setting-input[data-v-25bd5f50] .p-select, .setting-input[data-v-ffbfdf57] .p-select,
.setting-input[data-v-25bd5f50] .p-togglebutton { .setting-input[data-v-ffbfdf57] .p-togglebutton {
width: 100%; width: 100%;
max-width: 200px; max-width: 200px;
} }
.setting-input[data-v-25bd5f50] .p-inputtext { .setting-input[data-v-ffbfdf57] .p-inputtext {
max-width: unset; max-width: unset;
} }
/* Special case for ToggleSwitch to align it to the right */ /* Special case for ToggleSwitch to align it to the right */
.setting-input[data-v-25bd5f50] .p-toggleswitch { .setting-input[data-v-ffbfdf57] .p-toggleswitch {
margin-left: auto; margin-left: auto;
} }
@ -1655,21 +1655,21 @@
margin-left: 0.5rem; margin-left: 0.5rem;
} }
.comfy-error-report[data-v-12539d86] { .comfy-error-report[data-v-a103fd62] {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
gap: 1rem; gap: 1rem;
} }
.action-container[data-v-12539d86] { .action-container[data-v-a103fd62] {
display: flex; display: flex;
gap: 1rem; gap: 1rem;
justify-content: flex-end; justify-content: flex-end;
} }
.wrapper-pre[data-v-12539d86] { .wrapper-pre[data-v-a103fd62] {
white-space: pre-wrap; white-space: pre-wrap;
word-wrap: break-word; word-wrap: break-word;
} }
.no-results-placeholder[data-v-12539d86] { .no-results-placeholder[data-v-a103fd62] {
padding-top: 0; padding-top: 0;
} }
.lds-ring { .lds-ring {
@ -3158,7 +3158,7 @@ body {
overflow: hidden; overflow: hidden;
grid-template-columns: auto 1fr auto; grid-template-columns: auto 1fr auto;
grid-template-rows: auto 1fr auto; grid-template-rows: auto 1fr auto;
background-color: var(--bg-color); background: var(--bg-color) var(--bg-img);
color: var(--fg-color); color: var(--fg-color);
min-height: -webkit-fill-available; min-height: -webkit-fill-available;
max-height: -webkit-fill-available; max-height: -webkit-fill-available;
@ -3833,17 +3833,19 @@ audio.comfy-audio.empty-audio-widget {
box-sizing: border-box; box-sizing: border-box;
} }
.node-title-editor[data-v-77799b26] { .group-title-editor.node-title-editor[data-v-f0cbabc5] {
z-index: 9999; z-index: 9999;
padding: 0.25rem; padding: 0.25rem;
} }
[data-v-77799b26] .editable-text { [data-v-f0cbabc5] .editable-text {
width: 100%; width: 100%;
height: 100%; height: 100%;
} }
[data-v-77799b26] .editable-text input { [data-v-f0cbabc5] .editable-text input {
width: 100%; width: 100%;
height: 100%; height: 100%;
/* Override the default font size */
font-size: inherit;
} }
.side-bar-button-icon { .side-bar-button-icon {
@ -4086,26 +4088,26 @@ audio.comfy-audio.empty-audio-widget {
color: var(--error-text); color: var(--error-text);
} }
.comfy-vue-node-search-container[data-v-077af1a9] { .comfy-vue-node-search-container[data-v-d28bffc4] {
display: flex; display: flex;
width: 100%; width: 100%;
min-width: 24rem; min-width: 24rem;
align-items: center; align-items: center;
justify-content: center; justify-content: center;
} }
.comfy-vue-node-search-container[data-v-077af1a9] * { .comfy-vue-node-search-container[data-v-d28bffc4] * {
pointer-events: auto; pointer-events: auto;
} }
.comfy-vue-node-preview-container[data-v-077af1a9] { .comfy-vue-node-preview-container[data-v-d28bffc4] {
position: absolute; position: absolute;
left: -350px; left: -350px;
top: 50px; top: 50px;
} }
.comfy-vue-node-search-box[data-v-077af1a9] { .comfy-vue-node-search-box[data-v-d28bffc4] {
z-index: 10; z-index: 10;
flex-grow: 1; flex-grow: 1;
} }
.option-container[data-v-077af1a9] { .option-container[data-v-d28bffc4] {
display: flex; display: flex;
width: 100%; width: 100%;
cursor: pointer; cursor: pointer;
@ -4117,12 +4119,12 @@ audio.comfy-audio.empty-audio-widget {
padding-top: 0px; padding-top: 0px;
padding-bottom: 0px; padding-bottom: 0px;
} }
.option-display-name[data-v-077af1a9] { .option-display-name[data-v-d28bffc4] {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
font-weight: 600; font-weight: 600;
} }
.option-category[data-v-077af1a9] { .option-category[data-v-d28bffc4] {
overflow: hidden; overflow: hidden;
text-overflow: ellipsis; text-overflow: ellipsis;
font-size: 0.875rem; font-size: 0.875rem;
@ -4133,7 +4135,7 @@ audio.comfy-audio.empty-audio-widget {
/* Keeps the text on a single line by default */ /* Keeps the text on a single line by default */
white-space: nowrap; white-space: nowrap;
} }
[data-v-077af1a9] .highlight { [data-v-d28bffc4] .highlight {
background-color: var(--p-primary-color); background-color: var(--p-primary-color);
color: var(--p-primary-contrast-color); color: var(--p-primary-contrast-color);
font-weight: bold; font-weight: bold;
@ -4141,10 +4143,10 @@ audio.comfy-audio.empty-audio-widget {
padding: 0rem 0.125rem; padding: 0rem 0.125rem;
margin: -0.125rem 0.125rem; margin: -0.125rem 0.125rem;
} }
._filter-button[data-v-077af1a9] { ._filter-button[data-v-d28bffc4] {
z-index: 10; z-index: 10;
} }
._dialog[data-v-077af1a9] { ._dialog[data-v-d28bffc4] {
min-width: 24rem; min-width: 24rem;
} }
@ -4353,28 +4355,60 @@ img.galleria-image {
gap: 0.5rem; gap: 0.5rem;
} }
.node-tree-leaf[data-v-adf5f221] { .tree-node[data-v-d4b7b060] {
width: 100%; width: 100%;
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: space-between; justify-content: space-between;
} }
.node-content[data-v-adf5f221] { .leaf-count-badge[data-v-d4b7b060] {
margin-left: 0.5rem;
}
.node-content[data-v-d4b7b060] {
display: flex; display: flex;
align-items: center; align-items: center;
flex-grow: 1; flex-grow: 1;
} }
.node-label[data-v-adf5f221] { .leaf-label[data-v-d4b7b060] {
margin-left: 0.5rem; margin-left: 0.5rem;
} }
.bookmark-button[data-v-adf5f221] { [data-v-d4b7b060] .editable-text span {
width: unset; word-break: break-all;
padding: 0.25rem;
} }
.node-tree-folder[data-v-f2d72e9b] { [data-v-9d3310b9] .tree-explorer-node-label {
width: 100%;
display: flex; display: flex;
align-items: center; align-items: center;
margin-left: var(--p-tree-node-gap);
flex-grow: 1;
}
/*
* The following styles are necessary to avoid layout shift when dragging nodes over folders.
* By setting the position to relative on the parent and using an absolutely positioned pseudo-element,
* we can create a visual indicator for the drop target without affecting the layout of other elements.
*/
[data-v-9d3310b9] .p-tree-node-content:has(.tree-folder) {
position: relative;
}
[data-v-9d3310b9] .p-tree-node-content:has(.tree-folder.can-drop)::after {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
border: 1px solid var(--p-content-color);
pointer-events: none;
}
.node-lib-node-container[data-v-3238e135] {
height: 100%;
width: 100%;
}
.bookmark-button[data-v-3238e135] {
width: unset;
padding: 0.25rem;
} }
.p-selectbutton .p-button[data-v-91077f2a] { .p-selectbutton .p-button[data-v-91077f2a] {
@ -4394,45 +4428,33 @@ img.galleria-image {
gap: 0.5rem; gap: 0.5rem;
} }
.node-lib-tree-node-label {
display: flex;
align-items: center;
margin-left: var(--p-tree-node-gap);
flex-grow: 1;
}
.node-lib-filter-popup { .node-lib-filter-popup {
margin-left: -13px; margin-left: -13px;
} }
[data-v-87967891] .node-lib-search-box { [data-v-85688f44] .node-lib-search-box {
margin-left: 1rem; margin-left: 1rem;
margin-right: 1rem; margin-right: 1rem;
margin-top: 1rem; margin-top: 1rem;
} }
[data-v-87967891] .comfy-vue-side-bar-body { [data-v-85688f44] .comfy-vue-side-bar-body {
background: var(--p-tree-background); background: var(--p-tree-background);
} }
[data-v-85688f44] .node-lib-bookmark-tree-explorer {
/* padding-bottom: 2px;
* The following styles are necessary to avoid layout shift when dragging nodes over folders.
* By setting the position to relative on the parent and using an absolutely positioned pseudo-element,
* we can create a visual indicator for the drop target without affecting the layout of other elements.
*/
[data-v-87967891] .p-tree-node-content:has(.node-tree-folder) {
position: relative;
} }
[data-v-87967891] .p-tree-node-content:has(.node-tree-folder.can-drop)::after { [data-v-85688f44] .node-lib-tree-explorer {
content: ''; padding-top: 2px;
position: absolute; }
top: 0; [data-v-85688f44] .p-divider {
left: 0; margin: var(--comfy-tree-explorer-item-padding) 0px;
right: 0;
bottom: 0;
border: 1px solid var(--p-content-color);
pointer-events: none;
} }
.spinner[data-v-8616e7a1] { .p-tree-node-content {
padding: var(--comfy-tree-explorer-item-padding) !important;
}
.spinner[data-v-75e4840f] {
position: absolute; position: absolute;
inset: 0px; inset: 0px;
display: flex; display: flex;

View File

@ -1,6 +1,6 @@
var __defProp = Object.defineProperty; var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true }); var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { j as createSpinner, h as api, $ as $el } from "./index-CI3N807S.js"; import { j as createSpinner, h as api, $ as $el } from "./index-Dfv2aLsq.js";
class UserSelectionScreen { class UserSelectionScreen {
static { static {
__name(this, "UserSelectionScreen"); __name(this, "UserSelectionScreen");
@ -117,4 +117,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export { export {
UserSelectionScreen UserSelectionScreen
}; };
//# sourceMappingURL=userSelection-CyXKCVy3.js.map //# sourceMappingURL=userSelection-DSpF-zVD.js.map

File diff suppressed because one or more lines are too long

View File

@ -1,2 +1,2 @@
// Shim for extensions\core\clipspace.ts // Shim for extensions/core/clipspace.ts
export const ClipspaceDialog = window.comfyAPI.clipspace.ClipspaceDialog; export const ClipspaceDialog = window.comfyAPI.clipspace.ClipspaceDialog;

View File

@ -1,3 +1,3 @@
// Shim for extensions\core\groupNode.ts // Shim for extensions/core/groupNode.ts
export const GroupNodeConfig = window.comfyAPI.groupNode.GroupNodeConfig; export const GroupNodeConfig = window.comfyAPI.groupNode.GroupNodeConfig;
export const GroupNodeHandler = window.comfyAPI.groupNode.GroupNodeHandler; export const GroupNodeHandler = window.comfyAPI.groupNode.GroupNodeHandler;

View File

@ -1,2 +1,2 @@
// Shim for extensions\core\groupNodeManage.ts // Shim for extensions/core/groupNodeManage.ts
export const ManageGroupDialog = window.comfyAPI.groupNodeManage.ManageGroupDialog; export const ManageGroupDialog = window.comfyAPI.groupNodeManage.ManageGroupDialog;

View File

@ -1,4 +1,4 @@
// Shim for extensions\core\widgetInputs.ts // Shim for extensions/core/widgetInputs.ts
export const getWidgetConfig = window.comfyAPI.widgetInputs.getWidgetConfig; export const getWidgetConfig = window.comfyAPI.widgetInputs.getWidgetConfig;
export const setWidgetConfig = window.comfyAPI.widgetInputs.setWidgetConfig; export const setWidgetConfig = window.comfyAPI.widgetInputs.setWidgetConfig;
export const mergeIfValid = window.comfyAPI.widgetInputs.mergeIfValid; export const mergeIfValid = window.comfyAPI.widgetInputs.mergeIfValid;

100
comfy/web/index.html vendored
View File

@@ -1,50 +1,50 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>ComfyUI</title>
  <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">

  <!-- Browser Test Fonts -->
  <!-- <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
  <link href="https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
  <style>
    * {
      font-family: 'Roboto Mono', 'Noto Color Emoji';
    }
  </style> -->
  <link rel="stylesheet" type="text/css" href="user.css" />
  <link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
-  <script type="module" crossorigin src="./assets/index-CI3N807S.js"></script>
-  <link rel="stylesheet" crossorigin href="./assets/index-_5czGnTA.css">
+  <script type="module" crossorigin src="./assets/index-Dfv2aLsq.js"></script>
+  <link rel="stylesheet" crossorigin href="./assets/index-W4jP-SrU.css">
</head>
<body class="litegraph">
  <div id="vue-app"></div>
  <div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
    <main class="comfy-user-selection-inner">
      <h1>ComfyUI</h1>
      <form>
        <section>
          <label>New user:
            <input placeholder="Enter a username" />
          </label>
        </section>
        <div class="comfy-user-existing">
          <span class="or-separator">OR</span>
          <section>
            <label>
              Existing user:
              <select>
                <option hidden disabled selected value> Select a user </option>
              </select>
            </label>
          </section>
        </div>
        <footer>
          <span class="comfy-user-error">&nbsp;</span>
          <button class="comfy-btn comfy-user-button-next">Next</button>
        </footer>
      </form>
    </main>
  </div>
</body>
</html>

View File

@ -1,2 +1,2 @@
// Shim for scripts\api.ts // Shim for scripts/api.ts
export const api = window.comfyAPI.api.api; export const api = window.comfyAPI.api.api;

View File

@ -1,4 +1,4 @@
// Shim for scripts\app.ts // Shim for scripts/app.ts
export const ANIM_PREVIEW_WIDGET = window.comfyAPI.app.ANIM_PREVIEW_WIDGET; export const ANIM_PREVIEW_WIDGET = window.comfyAPI.app.ANIM_PREVIEW_WIDGET;
export const ComfyApp = window.comfyAPI.app.ComfyApp; export const ComfyApp = window.comfyAPI.app.ComfyApp;
export const app = window.comfyAPI.app.app; export const app = window.comfyAPI.app.app;

View File

@ -1,2 +1,2 @@
// Shim for scripts\changeTracker.ts // Shim for scripts/changeTracker.ts
export const ChangeTracker = window.comfyAPI.changeTracker.ChangeTracker; export const ChangeTracker = window.comfyAPI.changeTracker.ChangeTracker;

View File

@ -1,2 +1,2 @@
// Shim for scripts\defaultGraph.ts // Shim for scripts/defaultGraph.ts
export const defaultGraph = window.comfyAPI.defaultGraph.defaultGraph; export const defaultGraph = window.comfyAPI.defaultGraph.defaultGraph;

View File

@ -1,2 +1,2 @@
// Shim for scripts\domWidget.ts // Shim for scripts/domWidget.ts
export const addDomClippingSetting = window.comfyAPI.domWidget.addDomClippingSetting; export const addDomClippingSetting = window.comfyAPI.domWidget.addDomClippingSetting;

View File

@ -1,2 +1,2 @@
// Shim for scripts\logging.ts // Shim for scripts/logging.ts
export const ComfyLogging = window.comfyAPI.logging.ComfyLogging; export const ComfyLogging = window.comfyAPI.logging.ComfyLogging;

View File

@ -1,3 +1,3 @@
// Shim for scripts\metadata\flac.ts // Shim for scripts/metadata/flac.ts
export const getFromFlacBuffer = window.comfyAPI.flac.getFromFlacBuffer; export const getFromFlacBuffer = window.comfyAPI.flac.getFromFlacBuffer;
export const getFromFlacFile = window.comfyAPI.flac.getFromFlacFile; export const getFromFlacFile = window.comfyAPI.flac.getFromFlacFile;

View File

@ -1,3 +1,3 @@
// Shim for scripts\metadata\png.ts // Shim for scripts/metadata/png.ts
export const getFromPngBuffer = window.comfyAPI.png.getFromPngBuffer; export const getFromPngBuffer = window.comfyAPI.png.getFromPngBuffer;
export const getFromPngFile = window.comfyAPI.png.getFromPngFile; export const getFromPngFile = window.comfyAPI.png.getFromPngFile;

View File

@ -1,4 +1,4 @@
// Shim for scripts\pnginfo.ts // Shim for scripts/pnginfo.ts
export const getPngMetadata = window.comfyAPI.pnginfo.getPngMetadata; export const getPngMetadata = window.comfyAPI.pnginfo.getPngMetadata;
export const getFlacMetadata = window.comfyAPI.pnginfo.getFlacMetadata; export const getFlacMetadata = window.comfyAPI.pnginfo.getFlacMetadata;
export const getWebpMetadata = window.comfyAPI.pnginfo.getWebpMetadata; export const getWebpMetadata = window.comfyAPI.pnginfo.getWebpMetadata;

View File

@ -1,4 +1,4 @@
// Shim for scripts\ui.ts // Shim for scripts/ui.ts
export const ComfyDialog = window.comfyAPI.ui.ComfyDialog; export const ComfyDialog = window.comfyAPI.ui.ComfyDialog;
export const $el = window.comfyAPI.ui.$el; export const $el = window.comfyAPI.ui.$el;
export const ComfyUI = window.comfyAPI.ui.ComfyUI; export const ComfyUI = window.comfyAPI.ui.ComfyUI;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\components\asyncDialog.ts // Shim for scripts/ui/components/asyncDialog.ts
export const ComfyAsyncDialog = window.comfyAPI.asyncDialog.ComfyAsyncDialog; export const ComfyAsyncDialog = window.comfyAPI.asyncDialog.ComfyAsyncDialog;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\components\button.ts // Shim for scripts/ui/components/button.ts
export const ComfyButton = window.comfyAPI.button.ComfyButton; export const ComfyButton = window.comfyAPI.button.ComfyButton;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\components\buttonGroup.ts // Shim for scripts/ui/components/buttonGroup.ts
export const ComfyButtonGroup = window.comfyAPI.buttonGroup.ComfyButtonGroup; export const ComfyButtonGroup = window.comfyAPI.buttonGroup.ComfyButtonGroup;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\components\popup.ts // Shim for scripts/ui/components/popup.ts
export const ComfyPopup = window.comfyAPI.popup.ComfyPopup; export const ComfyPopup = window.comfyAPI.popup.ComfyPopup;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\components\splitButton.ts // Shim for scripts/ui/components/splitButton.ts
export const ComfySplitButton = window.comfyAPI.splitButton.ComfySplitButton; export const ComfySplitButton = window.comfyAPI.splitButton.ComfySplitButton;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\dialog.ts // Shim for scripts/ui/dialog.ts
export const ComfyDialog = window.comfyAPI.dialog.ComfyDialog; export const ComfyDialog = window.comfyAPI.dialog.ComfyDialog;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\draggableList.ts // Shim for scripts/ui/draggableList.ts
export const DraggableList = window.comfyAPI.draggableList.DraggableList; export const DraggableList = window.comfyAPI.draggableList.DraggableList;

View File

@ -1,3 +1,3 @@
// Shim for scripts\ui\imagePreview.ts // Shim for scripts/ui/imagePreview.ts
export const calculateImageGrid = window.comfyAPI.imagePreview.calculateImageGrid; export const calculateImageGrid = window.comfyAPI.imagePreview.calculateImageGrid;
export const createImageHost = window.comfyAPI.imagePreview.createImageHost; export const createImageHost = window.comfyAPI.imagePreview.createImageHost;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\menu\index.ts // Shim for scripts/ui/menu/index.ts
export const ComfyAppMenu = window.comfyAPI.index.ComfyAppMenu; export const ComfyAppMenu = window.comfyAPI.index.ComfyAppMenu;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\menu\interruptButton.ts // Shim for scripts/ui/menu/interruptButton.ts
export const getInterruptButton = window.comfyAPI.interruptButton.getInterruptButton; export const getInterruptButton = window.comfyAPI.interruptButton.getInterruptButton;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\menu\queueButton.ts // Shim for scripts/ui/menu/queueButton.ts
export const ComfyQueueButton = window.comfyAPI.queueButton.ComfyQueueButton; export const ComfyQueueButton = window.comfyAPI.queueButton.ComfyQueueButton;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\menu\queueOptions.ts // Shim for scripts/ui/menu/queueOptions.ts
export const ComfyQueueOptions = window.comfyAPI.queueOptions.ComfyQueueOptions; export const ComfyQueueOptions = window.comfyAPI.queueOptions.ComfyQueueOptions;

View File

@ -1,3 +1,3 @@
// Shim for scripts\ui\menu\workflows.ts // Shim for scripts/ui/menu/workflows.ts
export const ComfyWorkflowsMenu = window.comfyAPI.workflows.ComfyWorkflowsMenu; export const ComfyWorkflowsMenu = window.comfyAPI.workflows.ComfyWorkflowsMenu;
export const ComfyWorkflowsContent = window.comfyAPI.workflows.ComfyWorkflowsContent; export const ComfyWorkflowsContent = window.comfyAPI.workflows.ComfyWorkflowsContent;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\settings.ts // Shim for scripts/ui/settings.ts
export const ComfySettingsDialog = window.comfyAPI.settings.ComfySettingsDialog; export const ComfySettingsDialog = window.comfyAPI.settings.ComfySettingsDialog;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\spinner.ts // Shim for scripts/ui/spinner.ts
export const createSpinner = window.comfyAPI.spinner.createSpinner; export const createSpinner = window.comfyAPI.spinner.createSpinner;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\toggleSwitch.ts // Shim for scripts/ui/toggleSwitch.ts
export const toggleSwitch = window.comfyAPI.toggleSwitch.toggleSwitch; export const toggleSwitch = window.comfyAPI.toggleSwitch.toggleSwitch;

View File

@ -1,2 +1,2 @@
// Shim for scripts\ui\userSelection.ts // Shim for scripts/ui/userSelection.ts
export const UserSelectionScreen = window.comfyAPI.userSelection.UserSelectionScreen; export const UserSelectionScreen = window.comfyAPI.userSelection.UserSelectionScreen;

View File

@ -1,3 +1,3 @@
// Shim for scripts\ui\utils.ts // Shim for scripts/ui/utils.ts
export const applyClasses = window.comfyAPI.utils.applyClasses; export const applyClasses = window.comfyAPI.utils.applyClasses;
export const toggleElement = window.comfyAPI.utils.toggleElement; export const toggleElement = window.comfyAPI.utils.toggleElement;

View File

@ -1,4 +1,4 @@
// Shim for scripts\utils.ts // Shim for scripts/utils.ts
export const clone = window.comfyAPI.utils.clone; export const clone = window.comfyAPI.utils.clone;
export const applyTextReplacements = window.comfyAPI.utils.applyTextReplacements; export const applyTextReplacements = window.comfyAPI.utils.applyTextReplacements;
export const addStylesheet = window.comfyAPI.utils.addStylesheet; export const addStylesheet = window.comfyAPI.utils.addStylesheet;

View File

@ -1,4 +1,4 @@
// Shim for scripts\widgets.ts // Shim for scripts/widgets.ts
export const updateControlWidgetLabel = window.comfyAPI.widgets.updateControlWidgetLabel; export const updateControlWidgetLabel = window.comfyAPI.widgets.updateControlWidgetLabel;
export const addValueControlWidget = window.comfyAPI.widgets.addValueControlWidget; export const addValueControlWidget = window.comfyAPI.widgets.addValueControlWidget;
export const addValueControlWidgets = window.comfyAPI.widgets.addValueControlWidgets; export const addValueControlWidgets = window.comfyAPI.widgets.addValueControlWidgets;

View File

@ -1,4 +1,4 @@
// Shim for scripts\workflows.ts // Shim for scripts/workflows.ts
export const trimJsonExt = window.comfyAPI.workflows.trimJsonExt; export const trimJsonExt = window.comfyAPI.workflows.trimJsonExt;
export const ComfyWorkflowManager = window.comfyAPI.workflows.ComfyWorkflowManager; export const ComfyWorkflowManager = window.comfyAPI.workflows.ComfyWorkflowManager;
export const ComfyWorkflow = window.comfyAPI.workflows.ComfyWorkflow; export const ComfyWorkflow = window.comfyAPI.workflows.ComfyWorkflow;

View File

@ -206,17 +206,10 @@ class PreviewAudio(SaveAudio):
class LoadAudio: class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory() input_dir = folder_paths.get_input_directory()
files = [ files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
return {"required": {"audio": (sorted(files), {"audio_upload": True})}} return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "audio" CATEGORY = "audio"

View File

@ -1,5 +1,6 @@
import logging import logging
import os import os
from enum import Enum
import torch import torch
@ -41,6 +42,41 @@ def extract_lora(diff, rank):
return (U, Vh) return (U, Vh)
class LORAType(Enum):
STANDARD = 0
FULL_DIFF = 1
LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
except:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave: class LoraSave:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -48,9 +84,12 @@ class LoraSave:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}), return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 1024, "step": 1}), "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
}, },
"optional": {"model_diff": ("MODEL",), }, "optional": {"model_diff": ("MODEL",),
"text_encoder_diff": ("CLIP",)},
} }
RETURN_TYPES = () RETURN_TYPES = ()
@ -59,30 +98,18 @@ class LoraSave:
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, model_diff=None): def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None: if model_diff is None and text_encoder_diff is None:
return {} return {}
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_sd = {} output_sd = {}
prefix_key = "diffusion_model." if model_diff is not None:
stored = set() output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
if text_encoder_diff is not None:
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True) output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
sd = model_diff.model_state_dict(filter_prefix=prefix_key)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if weight_diff.ndim < 2:
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}.lora_up.weight".format(k[:-7])] = out[0].contiguous().half().cpu()
output_sd["{}.lora_down.weight".format(k[:-7])] = out[1].contiguous().half().cpu()
except:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint) output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
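Note: calc_lora_model relies on extract_lora(diff, rank), whose body is outside this hunk; the diff only shows that it returns a (U, Vh) pair that becomes lora_up / lora_down. A hedged sketch of the usual SVD-based low-rank extraction for a 2D weight difference follows; treat it as an illustration of the technique, not the shipped function (which also has to handle conv weights and dtype/device details).

# Sketch of SVD-based LoRA extraction for a 2D weight diff -- an assumption
# about what extract_lora does, not a copy of the shipped implementation.
import torch

def extract_lora_sketch(weight_diff, rank):
    out_dim, in_dim = weight_diff.shape
    rank = min(rank, in_dim, out_dim)
    U, S, Vh = torch.linalg.svd(weight_diff.float(), full_matrices=False)
    # Fold the singular values into the "up" matrix so that up @ down ~= diff.
    up = U[:, :rank] @ torch.diag(S[:rank])   # (out_dim, rank) -> lora_up.weight
    down = Vh[:rank, :]                       # (rank, in_dim)  -> lora_down.weight
    return up, down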

View File

@ -24,6 +24,7 @@ class PerpNeg:
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale): def patch(self, model, empty_conditioning, neg_scale):
m = model.clone() m = model.clone()

View File

@ -0,0 +1,21 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}
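Note: the new node works by cloning the patcher and swapping in a torch.compile-wrapped diffusion model via add_object_patch. Outside the node system, the same idea reduces to the plain PyTorch call sketched below; the module here is a placeholder, not a ComfyUI object.

# Plain-PyTorch equivalent of what TorchCompileModel does to the wrapped
# diffusion model; `net` is a placeholder nn.Module, not a ComfyUI object.
import torch

net = torch.nn.Linear(16, 16)
compiled_net = torch.compile(net)        # default inductor backend
out = compiled_net(torch.randn(2, 16))   # first call triggers compilation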

View File

@ -0,0 +1,66 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
import pytest
import os
import tempfile
from unittest.mock import patch
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
def test_get_directory_by_type():
test_dir = "/test/dir"
folder_paths.set_output_directory(test_dir)
assert folder_paths.get_directory_by_type("output") == test_dir
assert folder_paths.get_directory_by_type("invalid") is None
def test_annotated_filepath():
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
def test_get_annotated_filepath():
default_dir = "/default/dir"
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
def test_add_model_folder_path():
folder_paths.add_model_folder_path("test_folder", "/test/path")
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
def test_recursive_search(temp_dir):
os.makedirs(os.path.join(temp_dir, "subdir"))
open(os.path.join(temp_dir, "file1.txt"), "w").close()
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
files, dirs = folder_paths.recursive_search(temp_dir)
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
assert len(dirs) == 2 # temp_dir and subdir
def test_filter_files_extensions():
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
assert folder_paths.filter_files_extensions(files, []) == files
@patch("folder_paths.recursive_search")
@patch("folder_paths.folder_names_and_paths")
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_save_image_path(temp_dir):
with patch("folder_paths.output_directory", temp_dir):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
assert os.path.samefile(full_output_folder, temp_dir)
assert filename == "test"
assert counter == 1
assert subfolder == ""
assert filename_prefix == "test"
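Note: the annotated-filepath tests above pin down the "name [output]" / "[input]" / "[temp]" convention without showing the implementation. A minimal sketch consistent with those assertions follows; it illustrates the contract the tests encode, not the code in folder_paths itself.

# Sketch of the behaviour the annotated_filepath tests expect -- assumes the
# annotation is a trailing " [output]" / " [input]" / " [temp]" suffix.
def annotated_filepath_sketch(name, get_output_dir, get_input_dir, get_temp_dir):
    suffixes = {
        " [output]": get_output_dir,
        " [input]": get_input_dir,
        " [temp]": get_temp_dir,
    }
    for suffix, get_dir in suffixes.items():
        if name.endswith(suffix):
            return name[:-len(suffix)], get_dir()
    return name, None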

View File


@ -0,0 +1,52 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}
@pytest.fixture(scope="module")
def mock_dir(file_extensions):
with tempfile.TemporaryDirectory() as directory:
for content_type, extensions in file_extensions.items():
for extension in extensions:
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
f.write(f"Sample {content_type} file in {extension} format")
yield directory
def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
for extension in extensions:
assert f"sample_{content_type}.{extension}" in filtered_files
def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
assert len(filtered_files) == len(extensions)
def test_handles_bad_extensions():
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_extension():
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []

View File

@ -0,0 +1,124 @@
import pytest
import yaml
import os
from unittest.mock import Mock, patch, mock_open
from utils.extra_config import load_extra_path_config
import folder_paths
@pytest.fixture
def mock_yaml_content():
return {
'test_config': {
'base_path': '~/App/',
'checkpoints': 'subfolder1',
}
}
@pytest.fixture
def mock_expanded_home():
return '/home/user'
@pytest.fixture
def yaml_config_with_appdata():
return """
test_config:
base_path: '%APPDATA%/ComfyUI'
checkpoints: 'models/checkpoints'
"""
@pytest.fixture
def mock_yaml_content_appdata(yaml_config_with_appdata):
return yaml.safe_load(yaml_config_with_appdata)
@pytest.fixture
def mock_expandvars_appdata():
mock = Mock()
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
return mock
@pytest.fixture
def mock_add_model_folder_path():
return Mock()
@pytest.fixture
def mock_expanduser(mock_expanded_home):
def _expanduser(path):
if path.startswith('~/'):
return os.path.join(mock_expanded_home, path[2:])
return path
return _expanduser
@pytest.fixture
def mock_yaml_safe_load(mock_yaml_content):
return Mock(return_value=mock_yaml_content)
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
def test_load_extra_model_paths_expands_userpath(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expanduser,
mock_yaml_safe_load,
mock_expanded_home
):
# Attach mocks used by load_extra_path_config
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
dummy_yaml_file_name = 'dummy_path.yaml'
load_extra_path_config(dummy_yaml_file_name)
expected_calls = [
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1')),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call
# Check if yaml.safe_load was called
mock_yaml_safe_load.assert_called_once()
# Check if open was called with the correct file path
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
@patch('builtins.open', new_callable=mock_open)
def test_load_extra_model_paths_expands_appdata(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expandvars_appdata,
yaml_config_with_appdata,
mock_yaml_content_appdata
):
# Set the mock_file to return yaml with appdata as a variable
mock_file.return_value.read.return_value = yaml_config_with_appdata
# Attach mocks
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
# Mock expanduser to do nothing (since we're not testing it here)
monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
dummy_yaml_file_name = 'dummy_path.yaml'
load_extra_path_config(dummy_yaml_file_name)
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
expected_calls = [
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints')),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check the base path variable was expanded
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call
# Verify that expandvars was called
assert mock_expandvars_appdata.called

View File

@ -429,3 +429,29 @@ class TestExecution:
assert len(images) == 1, "Should have 1 image" assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached" assert not result.did_run(test_node), "The execution should have been cached"
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked.
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
int1 = g.node("StubInt", value=1)
int2 = g.node("StubInt", value=2)
int3 = g.node("StubInt", value=3)
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
output = g.node("PreviewImage", images=list_output.out(0))
result = client.run(g)
assert result.did_run(output), "The execution should have run"
images = result.get_images(output)
assert len(images) == 2, "Should have 2 images"
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
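Note: the test above exercises per-entry blocking: when a list output contains an ExecutionBlocker, only that element is dropped downstream while the rest of the list still executes. The toy sketch below illustrates that semantic with a local sentinel class; it is not the real ExecutionBlocker or the execution engine.

# Toy illustration of per-entry blocking semantics -- a local sentinel, not
# ComfyUI's ExecutionBlocker or its scheduler.
class BlockerSketch:
    pass

def run_list_node_sketch(values, node_fn):
    # Entries carrying a blocker are skipped; all other entries still run.
    return [node_fn(v) for v in values if not isinstance(v, BlockerSketch)]

# Mirrors the test: the middle image is blocked, so two entries reach the output.
results = run_list_node_sketch(["black1", BlockerSketch(), "black3"], lambda v: v)
assert len(results) == 2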

View File

@ -0,0 +1,122 @@
import os
from unittest.mock import patch
import pytest
from aiohttp import web
from comfy.app.user_manager import UserManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 404
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 200
assert await resp.json() == ["file1.txt"]
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&full_info=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert result[0]["path"] == "file1.txt"
assert "size" in result[0]
assert "modified" in result[0]
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
["subdir/file1.txt", "subdir", "file1.txt"]
]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=")
assert resp.status == 400
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
os_sep = "\\"
with patch("os.sep", os_sep):
with patch("os.path.sep", os_sep):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0] # Ensure forward slash is used
assert "\\" not in result[0] # Ensure backslash is not present
assert result[0] == "subdir/file1.txt"
# Test with full_info
resp = await client.get(
"/userdata?dir=test_dir&recurse=true&full_info=true"
)
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"

0
utils/__init__.py Normal file
View File

25
utils/extra_config.py Normal file
View File

@ -0,0 +1,25 @@
import os
import yaml
import folder_paths
import logging
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
base_path = os.path.expandvars(os.path.expanduser(base_path))
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path)
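Note: a usage sketch for load_extra_path_config follows, using a throwaway YAML file shaped like the configs the new unit tests cover (a base_path plus newline-separated folder entries); the folder names and paths are examples only.

# Usage sketch for load_extra_path_config -- the YAML content and folder names
# are examples matching the shape exercised by the tests above.
import tempfile
from utils.extra_config import load_extra_path_config

example_yaml = """
my_paths:
  base_path: ~/models-drive
  checkpoints: |
    checkpoints
    more_checkpoints
  loras: loras
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(example_yaml)
    path = f.name

# Each entry is registered via folder_paths.add_model_folder_path(kind, full_path),
# with base_path expanded through expanduser/expandvars first.
load_extra_path_config(path)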