diff --git a/.gitignore b/.gitignore index b1e3be697..59a9ea74a 100644 --- a/.gitignore +++ b/.gitignore @@ -172,4 +172,5 @@ dmypy.json cython_debug/ .openapi-generator/ -/tests-ui/data/object_info.json \ No newline at end of file +/tests-ui/data/object_info.json +/user/ \ No newline at end of file diff --git a/README.md b/README.md index 78b501711..53c6d9677 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,8 @@ There is a portable standalone build for Windows that should work for running on Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints +If you have trouble extracting it, right click the file -> properties -> unblock + #### How do I share models between another UI and ComfyUI? See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. @@ -192,7 +194,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from To run tests: ```shell pytest tests/inference - (cd tests-ui && npm ci && npm test:generate && npm test) + (cd tests-ui && npm ci && npm run test:generate && npm test) ``` You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language. diff --git a/comfy/app/__init__.py b/comfy/app/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/app/app_settings.py b/comfy/app/app_settings.py new file mode 100644 index 000000000..8c6edc56c --- /dev/null +++ b/comfy/app/app_settings.py @@ -0,0 +1,54 @@ +import os +import json +from aiohttp import web + + +class AppSettings(): + def __init__(self, user_manager): + self.user_manager = user_manager + + def get_settings(self, request): + file = self.user_manager.get_request_user_filepath( + request, "comfy.settings.json") + if os.path.isfile(file): + with open(file) as f: + return json.load(f) + else: + return {} + + def save_settings(self, request, settings): + file = self.user_manager.get_request_user_filepath( + request, "comfy.settings.json") + with open(file, "w") as f: + f.write(json.dumps(settings, indent=4)) + + def add_routes(self, routes): + @routes.get("/settings") + async def get_settings(request): + return web.json_response(self.get_settings(request)) + + @routes.get("/settings/{id}") + async def get_setting(request): + value = None + settings = self.get_settings(request) + setting_id = request.match_info.get("id", None) + if setting_id and setting_id in settings: + value = settings[setting_id] + return web.json_response(value) + + @routes.post("/settings") + async def post_settings(request): + settings = self.get_settings(request) + new_settings = await request.json() + self.save_settings(request, {**settings, **new_settings}) + return web.Response(status=200) + + @routes.post("/settings/{id}") + async def post_setting(request): + setting_id = request.match_info.get("id", None) + if not setting_id: + return web.Response(status=400) + settings = self.get_settings(request) + settings[setting_id] = await request.json() + self.save_settings(request, settings) + return web.Response(status=200) \ No newline at end of file diff --git a/comfy/app/user_manager.py b/comfy/app/user_manager.py new file mode 100644 index 000000000..b5cf84c89 --- /dev/null +++ b/comfy/app/user_manager.py @@ -0,0 +1,135 @@ +import json +import os +import re +import uuid +from aiohttp import web +from ..cli_args import args +from ..cmd.folder_paths import user_directory +from .app_settings import AppSettings + + +class UserManager(): + def __init__(self): + self.default_user = "default" + self.users_file = os.path.join(user_directory, "users.json") + self.settings = AppSettings(self) + if not os.path.exists(user_directory): + os.mkdir(user_directory) + if not args.multi_user: + print("****** User settings have been changed to be stored on the server instead of browser storage. ******") + print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") + + if args.multi_user: + if os.path.isfile(self.users_file): + with open(self.users_file) as f: + self.users = json.load(f) + else: + self.users = {} + else: + self.users = {"default": "default"} + + def get_request_user_id(self, request): + user = "default" + if args.multi_user and "comfy-user" in request.headers: + user = request.headers["comfy-user"] + + if user not in self.users: + raise KeyError("Unknown user: " + user) + + return user + + def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): + + if type == "userdata": + root_dir = user_directory + else: + raise KeyError("Unknown filepath type:" + type) + + user = self.get_request_user_id(request) + path = user_root = os.path.abspath(os.path.join(root_dir, user)) + + # prevent leaving /{type} + if os.path.commonpath((root_dir, user_root)) != root_dir: + return None + + parent = user_root + + if file is not None: + # prevent leaving /{type}/{user} + path = os.path.abspath(os.path.join(user_root, file)) + if os.path.commonpath((user_root, path)) != user_root: + return None + + if create_dir and not os.path.exists(parent): + os.mkdir(parent) + + return path + + def add_user(self, name): + name = name.strip() + if not name: + raise ValueError("username not provided") + user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + user_id = user_id + "_" + str(uuid.uuid4()) + + self.users[user_id] = name + + with open(self.users_file, "w") as f: + json.dump(self.users, f) + + return user_id + + def add_routes(self, routes): + self.settings.add_routes(routes) + + @routes.get("/users") + async def get_users(request): + if args.multi_user: + return web.json_response({"storage": "server", "users": self.users}) + else: + user_dir = self.get_request_user_filepath(request, None, create_dir=False) + return web.json_response({ + "storage": "server", + "migrated": os.path.exists(user_dir) + }) + + @routes.post("/users") + async def post_users(request): + body = await request.json() + username = body["username"] + if username in self.users.values(): + return web.json_response({"error": "Duplicate username."}, status=400) + + user_id = self.add_user(username) + return web.json_response(user_id) + + @routes.get("/userdata/{file}") + async def getuserdata(request): + file = request.match_info.get("file", None) + if not file: + return web.Response(status=400) + + path = self.get_request_user_filepath(request, file) + if not path: + return web.Response(status=403) + + if not os.path.exists(path): + return web.Response(status=404) + + return web.FileResponse(path) + + @routes.post("/userdata/{file}") + async def post_userdata(request): + file = request.match_info.get("file", None) + if not file: + return web.Response(status=400) + + path = self.get_request_user_filepath(request, file) + if not path: + return web.Response(status=403) + + body = await request.read() + with open(path, "wb") as f: + f.write(body) + + return web.Response(status=200) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a9d23b446..5560f0511 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -112,6 +112,8 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") +parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") + if options.args_parsing: args = parser.parse_args() else: diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7397b7a26..09e7bbca1 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module): self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) def forward(self, x, mask=None, intermediate_output=None): - optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) if intermediate_output is not None: if intermediate_output < 0: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index d3fdd5223..84c3699c8 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,4 +1,4 @@ -from .utils import load_torch_file, transformers_convert +from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os import torch import json @@ -43,6 +43,9 @@ class ClipVisionModel(): def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) + def get_sd(self): + return self.model.state_dict() + def encode_image(self, image): model_management.load_model_gpu(self.patcher) pixel_values = clip_preprocess(image.to(self.load_device)).float() @@ -75,6 +78,9 @@ def convert_to_transformers(sd, prefix): sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) sd = transformers_convert(sd, prefix, "vision_model.", 48) + else: + replace_prefix = {prefix: ""} + sd = state_dict_prefix_replace(sd, replace_prefix) return sd def load_clipvision_from_sd(sd, prefix="", convert_keys=False): diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 59ba52cf9..4b62a14e8 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import sys import copy import datetime import heapq @@ -15,6 +16,7 @@ from typing import Tuple import sys import gc import inspect +from typing import List, Literal, NamedTuple, Optional import torch @@ -325,11 +327,22 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): + self.success = None + self.server = server + self.reset() + + def reset(self): self.outputs = {} self.object_storage = {} self.outputs_ui = {} + self.status_messages = [] + self.success = True self.old_prompt = {} - self.server = server + + def add_message(self, event, data, broadcast: bool): + self.status_messages.append((event, data)) + if self.server.client_id is not None or broadcast: + self.server.send_sync(event, data, self.server.client_id) def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] @@ -344,23 +357,22 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), } - self.server.send_sync("execution_interrupted", mes, self.server.client_id) + self.add_message("execution_interrupted", mes, broadcast=True) else: - if self.server.client_id is not None: - mes = { - "prompt_id": prompt_id, - "node_id": node_id, - "node_type": class_type, - "executed": list(executed), - - "exception_message": error["exception_message"], - "exception_type": error["exception_type"], - "traceback": error["traceback"], - "current_inputs": error["current_inputs"], - "current_outputs": error["current_outputs"], - } - self.server.send_sync("execution_error", mes, self.server.client_id) + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + "exception_message": error["exception_message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.add_message("execution_error", mes, broadcast=False) + # Next, remove the subsequent outputs since they will not be executed to_delete = [] for o in self.outputs: @@ -381,8 +393,8 @@ class PromptExecutor: else: self.server.client_id = None - if self.server.client_id is not None: - self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id) + self.status_messages = [] + self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): # delete cached outputs if nodes don't exist for them @@ -415,9 +427,9 @@ class PromptExecutor: del d model_management.cleanup_models() - if self.server.client_id is not None: - self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id}, - self.server.client_id) + self.add_message("execution_cached", + {"nodes": list(current_outputs), "prompt_id": prompt_id}, + broadcast=False) executed = set() output_node_id = None to_execute = [] @@ -434,11 +446,9 @@ class PromptExecutor: # This call shouldn't raise anything if there's an error deep in # the actual SD code, instead it will report the node where the # error was raised - success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, - self.outputs_ui, self.object_storage) - if success is not True: - self.handle_execution_error( prompt_id, - prompt, current_outputs, executed, error, ex) + self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) + if self.success is not True: + self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) break for x in executed: @@ -779,6 +789,7 @@ class PromptQueue: self.queue = [] self.currently_running = {} self.history = {} + self.flags = {} server.prompt_queue = self def size(self) -> int: @@ -803,15 +814,28 @@ class PromptQueue: self.server.queue_updated() return copy.deepcopy(item_with_future.queue_tuple), task_id - def task_done(self, item_id, outputs: dict): + class ExecutionStatus(NamedTuple): + status_str: Literal['success', 'error'] + completed: bool + messages: List[str] + + def task_done(self, item_id, outputs: dict, + status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: queue_item = self.currently_running.pop(item_id) prompt = queue_item.queue_tuple if len(self.history) > MAXIMUM_HISTORY_SIZE: self.history.pop(next(iter(self.history))) - self.history[prompt[1]] = {"prompt": prompt, "outputs": {}, "timestamp": time.time()} - for o in outputs: - self.history[prompt[1]]["outputs"][o] = outputs[o] + + status_dict: Optional[dict] = None + if status is not None: + status_dict = copy.deepcopy(status._asdict()) + + self.history[prompt[1]] = { + "prompt": prompt, + "outputs": copy.deepcopy(outputs), + 'status': status_dict, + } self.server.queue_updated() if queue_item.completed: queue_item.completed.set_result(outputs) @@ -880,3 +904,17 @@ class PromptQueue: def delete_history_item(self, id_to_delete: int): with self.mutex: self.history.pop(id_to_delete, None) + + def set_flag(self, name, data): + with self.mutex: + self.flags[name] = data + self.not_empty.notify() + + def get_flags(self, reset=True): + with self.mutex: + if reset: + ret = self.flags + self.flags = {} + return ret + else: + return self.flags.copy() diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 678d6d080..e854b5d08 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -31,11 +31,14 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) +folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions) + folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") +user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") filename_list_cache = {} @@ -139,15 +142,27 @@ def recursive_search(directory, excluded_dir_names=None): excluded_dir_names = [] result = [] - dirs = {directory: os.path.getmtime(directory)} + dirs = {} + + # Attempt to add the initial directory to dirs with error handling + try: + dirs[directory] = os.path.getmtime(directory) + except FileNotFoundError: + print(f"Warning: Unable to access {directory}. Skipping this path.") + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) result.append(relative_path) + for d in subdirs: path = os.path.join(dirpath, d) - dirs[path] = os.path.getmtime(path) + try: + dirs[path] = os.path.getmtime(path) + except FileNotFoundError: + print(f"Warning: Unable to access {path}. Skipping this path.") + continue return result, dirs def filter_files_extensions(files, extensions): diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 38abd96fe..3d3d89f81 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -89,7 +89,7 @@ def prompt_worker(q, _server): gc_collect_interval = 10.0 current_time = 0.0 while True: - timeout = None + timeout = 1000.0 if need_gc: timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) @@ -102,14 +102,32 @@ def prompt_worker(q, _server): e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True - q.task_done(item_id, e.outputs_ui) + q.task_done(item_id, + e.outputs_ui, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + messages=e.status_messages)) if _server.client_id is not None: - _server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id) + _server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time print("Prompt executed in {:.2f} seconds".format(execution_time)) + flags = q.get_flags() + free_memory = flags.get("free_memory", False) + + if flags.get("unload_models", free_memory): + model_management.unload_all_models() + need_gc = True + last_gc_collect = 0 + + if free_memory: + e.reset() + need_gc = True + last_gc_collect = 0 + if need_gc: current_time = time.perf_counter() if (current_time - last_gc_collect) > gc_collect_interval: diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index ea133e6aa..eea89cc5f 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -35,6 +35,7 @@ from ..vendor.appdirs import user_data_dir nodes = import_all_nodes_in_workspace() +from ..app.user_manager import UserManager class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -91,6 +92,7 @@ class PromptServer(): mimetypes.init() mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' + self.user_manager = UserManager() self.supports = ["custom_nodes_from_web"] self.prompt_queue = None self.loop = loop @@ -532,6 +534,17 @@ class PromptServer(): model_management.interrupt_current_processing() return web.Response(status=200) + @routes.post("/free") + async def post_free(request): + json_data = await request.json() + unload_models = json_data.get("unload_models", False) + free_memory = json_data.get("free_memory", False) + if unload_models: + self.prompt_queue.set_flag("unload_models", unload_models) + if free_memory: + self.prompt_queue.set_flag("free_memory", free_memory) + return web.Response(status=200) + @routes.post("/history") async def post_history(request): json_data = await request.json() @@ -671,6 +684,7 @@ class PromptServer(): return web.json_response(prompt, status=200) def add_routes(self): + self.user_manager.add_routes(self.routes) self.app.add_routes(self.routes) for name, dir in nodes.EXTENSION_WEB_DIRS.items(): @@ -768,14 +782,9 @@ class PromptServer(): site = web.TCPSite(runner, address, port) await site.start() - address_to_print = 'localhost' - if address == '' or address == '0.0.0.0': - address = '0.0.0.0' - else: - address_to_print = address if verbose: print("Starting server\n") - print("To see the GUI go to: http://{}:{}".format(address_to_print, port)) + print("To see the GUI go to: http://{}:{}".format(address, port)) if call_on_start is not None: call_on_start(address, port) diff --git a/comfy/conds.py b/comfy/conds.py index 9677ff49f..76ca1cfd1 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,4 +1,3 @@ -import enum import torch import math from . import utils diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b11ad1f7..f1cd2caf5 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,7 +1,6 @@ import torch import math import os -import contextlib from . import utils from . import model_management @@ -127,7 +126,10 @@ class ControlBase: if o[i] is None: o[i] = prev_val else: - o[i] += prev_val + if o[i].shape[0] < prev_val.shape[0]: + o[i] = prev_val + o[i] + else: + o[i] += prev_val return out class ControlNet(ControlBase): diff --git a/comfy/gligen.py b/comfy/gligen.py index f47baa0f1..7701af542 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -1,7 +1,7 @@ import math import torch -from torch import nn, einsum +from torch import nn from .ldm.modules.attention import CrossAttention from inspect import isfunction diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 17dda7c58..eba49c7ef 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1,12 +1,9 @@ -from inspect import isfunction import math import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any -from functools import partial - from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -176,6 +173,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): kv_chunk_size_min=kv_chunk_size_min, use_checkpoint=False, upcast_attention=upcast_attention, + mask=mask, ) hidden_states = hidden_states.to(dtype) @@ -238,6 +236,12 @@ def attention_split(q, k, v, heads, mask=None): else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale + if mask is not None: + if len(mask.shape) == 2: + s1 += mask[i:end] + else: + s1 += mask[:, i:end] + s2 = s1.softmax(dim=-1).to(v.dtype) del s1 first_op_done = True @@ -293,11 +297,14 @@ def attention_xformers(q, k, v, heads, mask=None): (q, k, v), ) - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + if mask is not None: + pad = 8 - q.shape[1] % 8 + mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device) + mask_out[:, :, :mask.shape[-1]] = mask + mask = mask_out[:, :, :mask.shape[-1]] + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) - if exists(mask): - raise NotImplementedError out = ( out.unsqueeze(0) .reshape(b, heads, -1, dim_head) @@ -322,7 +329,6 @@ def attention_pytorch(q, k, v, heads, mask=None): optimized_attention = attention_basic -optimized_attention_masked = attention_basic if model_management.xformers_enabled(): print("Using xformers cross attention") @@ -338,15 +344,18 @@ else: print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -if model_management.pytorch_attention_enabled(): - optimized_attention_masked = attention_pytorch +optimized_attention_masked = optimized_attention -def optimized_attention_for_device(device, mask=False): - if device == torch.device("cpu"): #TODO +def optimized_attention_for_device(device, mask=False, small_input=False): + if small_input: if model_management.pytorch_attention_enabled(): - return attention_pytorch + return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases else: return attention_basic + + if device == torch.device("cpu"): + return attention_sub_quad + if mask: return optimized_attention_masked diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 903dc2801..2c9ae3a0e 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -1,12 +1,9 @@ from abc import abstractmethod -import math -import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from functools import partial from .util import ( checkpoint, diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index 66767b587..a5d866030 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -23,13 +23,13 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) return x - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) x = self.scale(x) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) z = self.unscale(z) noise_level = self.time_embed(noise_level) return z, noise_level diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 37bc79948..9e1d2f165 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -61,6 +61,7 @@ def _summarize_chunk( value: Tensor, scale: float, upcast_attention: bool, + mask, ) -> AttnChunk: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -84,6 +85,8 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() attn_weights -= max_score + if mask is not None: + attn_weights += mask torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) @@ -96,11 +99,12 @@ def _query_chunk_attention( value: Tensor, summarize_chunk: SummarizeChunk, kv_chunk_size: int, + mask, ) -> Tensor: batch_x_heads, k_channels_per_head, k_tokens = key_t.shape _, _, v_channels_per_head = value.shape - def chunk_scanner(chunk_idx: int) -> AttnChunk: + def chunk_scanner(chunk_idx: int, mask) -> AttnChunk: key_chunk = dynamic_slice( key_t, (0, 0, chunk_idx), @@ -111,10 +115,13 @@ def _query_chunk_attention( (0, chunk_idx, 0), (batch_x_heads, kv_chunk_size, v_channels_per_head) ) - return summarize_chunk(query, key_chunk, value_chunk) + if mask is not None: + mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size] + + return summarize_chunk(query, key_chunk, value_chunk, mask=mask) chunks: List[AttnChunk] = [ - chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size) ] acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) chunk_values, chunk_weights, chunk_max = acc_chunk @@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking( value: Tensor, scale: float, upcast_attention: bool, + mask, ) -> Tensor: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking( beta=0, ) + if mask is not None: + attn_scores += mask try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores @@ -183,6 +193,7 @@ def efficient_dot_product_attention( kv_chunk_size_min: Optional[int] = None, use_checkpoint=True, upcast_attention=False, + mask = None, ): """Computes efficient dot-product attention given query, transposed key, and value. This is efficient version of attention presented in @@ -209,13 +220,22 @@ def efficient_dot_product_attention( if kv_chunk_size_min is not None: kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + if mask is not None and len(mask.shape) == 2: + mask = mask.unsqueeze(0) + def get_query_chunk(chunk_idx: int) -> Tensor: return dynamic_slice( query, (0, chunk_idx, 0), (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) ) - + + def get_mask_chunk(chunk_idx: int) -> Tensor: + if mask is None: + return None + chunk = min(query_chunk_size, q_tokens) + return mask[:,chunk_idx:chunk_idx + chunk] + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention) summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk compute_query_chunk_attn: ComputeQueryChunkAttn = partial( @@ -237,6 +257,7 @@ def efficient_dot_product_attention( query=query, key_t=key_t, value=value, + mask=mask, ) # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, @@ -246,6 +267,7 @@ def efficient_dot_product_attention( query=get_query_chunk(i * query_chunk_size), key_t=key_t, value=value, + mask=get_mask_chunk(i * query_chunk_size) ) for i in range(math.ceil(q_tokens / query_chunk_size)) ], dim=1) return res diff --git a/comfy/model_base.py b/comfy/model_base.py index 7a0802e2b..ded8fd198 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -6,7 +6,6 @@ from . import model_management from . import conds from . import ops from enum import Enum -import contextlib from . import utils class ModelType(Enum): @@ -100,11 +99,29 @@ class BaseModel(torch.nn.Module): if self.inpaint_model: concat_keys = ("mask", "masked_image") cond_concat = [] - denoise_mask = kwargs.get("denoise_mask", None) - latent_image = kwargs.get("latent_image", None) + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + concat_latent_image = kwargs.get("concat_latent_image", None) + if concat_latent_image is None: + concat_latent_image = kwargs.get("latent_image", None) + else: + concat_latent_image = self.process_latent_in(concat_latent_image) + noise = kwargs.get("noise", None) device = kwargs["device"] + if concat_latent_image.shape[1:] != noise.shape[1:]: + concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) + + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] + + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space @@ -117,9 +134,9 @@ class BaseModel(torch.nn.Module): for ck in concat_keys: if denoise_mask is not None: if ck == "mask": - cond_concat.append(denoise_mask[:,:1].to(device)) + cond_concat.append(denoise_mask.to(device)) elif ck == "masked_image": - cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space + cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space else: if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) @@ -159,19 +176,28 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict, vae_state_dict): - clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + extra_sds = [] + if clip_state_dict is not None: + extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) + if vae_state_dict is not None: + extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) + if clip_vision_state_dict is not None: + extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) + unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) + if self.get_dtype() == torch.float16: - clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) - vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16) + extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds) if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) - return {**unet_state_dict, **vae_state_dict, **clip_state_dict} + for sd in extra_sds: + unet_state_dict.update(sd) + + return unet_state_dict def set_inpaint(self): self.inpaint_model = True @@ -190,7 +216,7 @@ class BaseModel(torch.nn.Module): return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) -def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): +def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): adm_inputs = [] weights = [] noise_aug = [] @@ -199,7 +225,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge weight = unclip_cond["strength"] noise_augment = unclip_cond["noise_augmentation"] noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed) adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight weights.append(weight) noise_aug.append(noise_augment) @@ -225,11 +251,11 @@ class SD21UNCLIP(BaseModel): if unclip_conditioning is None: return torch.zeros((1, self.adm_channels)) else: - return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05)) + return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: - return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280] + return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: return args["pooled_output"] diff --git a/comfy/model_management.py b/comfy/model_management.py index 6d6800e86..38bccd2df 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -176,7 +176,7 @@ try: if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if torch.cuda.is_bf16_supported(): + if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: VAE_DTYPE = torch.bfloat16 if is_intel_xpu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index ae687ae39..6b99cec9f 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -166,6 +166,26 @@ class ConditioningSetAreaPercentage: c.append(n) return (c, ) +class ConditioningSetAreaStrength: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, strength): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['strength'] = strength + c.append(n) + return (c, ) + + class ConditioningSetMask: @classmethod def INPUT_TYPES(s): @@ -341,6 +361,62 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class InpaintModelConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + "mask": ("MASK", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, positive, negative, pixels, vae, mask): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") + + orig_pixels = pixels + pixels = orig_pixels.clone() + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] + + m = (1.0 - mask.round()).squeeze(1) + for i in range(3): + pixels[:,:,:,i] -= 0.5 + pixels[:,:,:,i] *= m + pixels[:,:,:,i] += 0.5 + concat_latent = vae.encode(pixels) + orig_latent = vae.encode(orig_pixels) + + out_latent = {} + + out_latent["samples"] = orig_latent + out_latent["noise_mask"] = mask + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + d["concat_mask"] = mask + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + + class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -1400,6 +1476,8 @@ class LoadImage: output_masks = [] for i in ImageSequence.Iterator(img): i = ImageOps.exif_transpose(i) + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] @@ -1455,6 +1533,8 @@ class LoadImageMask: i = Image.open(image_path) i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) i = i.convert("RGBA") mask = None c = channel[0].upper() @@ -1609,10 +1689,11 @@ class ImagePadForOutpaint: def expand_image(self, image, left, top, right, bottom, feathering): d1, d2, d3, d4 = image.size() - new_image = torch.zeros( + new_image = torch.ones( (d1, d2 + top + bottom, d3 + left + right, d4), dtype=torch.float32, - ) + ) * 0.5 + new_image[:, top:top + d2, left:left + d3, :] = image mask = torch.ones( @@ -1678,6 +1759,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, + "ConditioningSetAreaStrength": ConditioningSetAreaStrength, "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, @@ -1704,6 +1786,7 @@ NODE_CLASS_MAPPINGS = { "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "GLIGENLoader": GLIGENLoader, "GLIGENTextBoxApply": GLIGENTextBoxApply, + "InpaintModelConditioning": InpaintModelConditioning, "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, diff --git a/comfy/ops.py b/comfy/ops.py index f6f85de60..74d04f183 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,10 +1,9 @@ import torch -from contextlib import contextmanager -import comfy.model_management +from . import model_management def cast_bias_weight(s, input): bias = None - non_blocking = comfy.model_management.device_supports_non_blocking(input.device) + non_blocking = model_management.device_supports_non_blocking(input.device) if s.bias is not None: bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) diff --git a/comfy/sample.py b/comfy/sample.py index 565ea662d..158607075 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None): def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") - noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = noise_mask.to(device) diff --git a/comfy/sd.py b/comfy/sd.py index c51172a46..67fffede8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -531,7 +531,14 @@ def load_unet(unet_path): raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model -def save_checkpoint(output_path, model, clip, vae, metadata=None): - model_management.load_models_gpu([model, clip.load_model()]) - sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) +def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): + clip_sd = None + load_models = [model] + if clip is not None: + load_models.append(clip.load_model()) + clip_sd = clip.get_sd() + + model_management.load_models_gpu(load_models) + clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None + sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 6722eb83f..1fcdef4d5 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -7,7 +7,6 @@ import traceback import zipfile from . import model_management from pkg_resources import resource_filename -import contextlib from . import clip_model import json diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 49087d23e..5baf4bca6 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -65,6 +65,12 @@ class BASE: replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def process_clip_vision_state_dict_for_saving(self, state_dict): + replace_prefix = {} + if self.clip_vision_prefix is not None: + replace_prefix[""] = self.clip_vision_prefix + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def process_unet_state_dict_for_saving(self, state_dict): replace_prefix = {"": "model.diffusion_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) diff --git a/comfy/web/extensions/core/groupNode.js b/comfy/web/extensions/core/groupNode.js index 4cf1f7621..0f041fcd2 100644 --- a/comfy/web/extensions/core/groupNode.js +++ b/comfy/web/extensions/core/groupNode.js @@ -1,6 +1,7 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; import { mergeIfValid } from "./widgetInputs.js"; +import { ManageGroupDialog } from "./groupNodeManage.js"; const GROUP = Symbol(); @@ -61,11 +62,7 @@ class GroupNodeBuilder { ); return; case Workflow.InUse.Registered: - if ( - !confirm( - "An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?" - ) - ) { + if (!confirm("A group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?")) { return; } break; @@ -151,6 +148,8 @@ export class GroupNodeConfig { this.primitiveDefs = {}; this.widgetToPrimitive = {}; this.primitiveToWidget = {}; + this.nodeInputs = {}; + this.outputVisibility = []; } async registerType(source = "workflow") { @@ -158,6 +157,7 @@ export class GroupNodeConfig { output: [], output_name: [], output_is_list: [], + output_is_hidden: [], name: source + "/" + this.name, display_name: this.name, category: "group nodes" + ("/" + source), @@ -277,8 +277,7 @@ export class GroupNodeConfig { } if (input.widget) { const targetDef = globalDefs[node.type]; - const targetWidget = - targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; + const targetWidget = targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; const widget = [targetWidget[0], config]; const res = mergeIfValid( @@ -330,7 +329,8 @@ export class GroupNodeConfig { } getInputConfig(node, inputName, seenInputs, config, extra) { - let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; + const customConfig = this.nodeData.config?.[node.index]?.input?.[inputName]; + let name = customConfig?.name ?? node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; let key = name; let prefix = ""; // Special handling for primitive to include the title if it is set rather than just "value" @@ -349,14 +349,14 @@ export class GroupNodeConfig { } if (config[0] === "IMAGEUPLOAD") { if (!extra) extra = {}; - extra.widget = `${prefix}${config[1]?.widget ?? "image"}`; + extra.widget = this.oldToNewWidgetMap[node.index]?.[config[1]?.widget ?? "image"] ?? "image"; } if (extra) { config = [config[0], { ...config[1], ...extra }]; } - return { name, config }; + return { name, config, customConfig }; } processWidgetInputs(inputs, node, inputNames, seenInputs) { @@ -366,9 +366,7 @@ export class GroupNodeConfig { for (const inputName of inputNames) { let widgetType = app.getWidgetType(inputs[inputName], inputName); if (widgetType) { - const convertedIndex = node.inputs?.findIndex( - (inp) => inp.name === inputName && inp.widget?.name === inputName - ); + const convertedIndex = node.inputs?.findIndex((inp) => inp.name === inputName && inp.widget?.name === inputName); if (convertedIndex > -1) { // This widget has been converted to a widget // We need to store this in the correct position so link ids line up @@ -424,6 +422,7 @@ export class GroupNodeConfig { } processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) { + this.nodeInputs[node.index] = {}; for (let i = 0; i < slots.length; i++) { const inputName = slots[i]; if (linksTo[i]) { @@ -432,7 +431,11 @@ export class GroupNodeConfig { continue; } - const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + const { name, config, customConfig } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + + this.nodeInputs[node.index][inputName] = name; + if(customConfig?.visible === false) continue; + this.nodeDef.input.required[name] = config; inputMap[i] = this.inputCount++; } @@ -452,6 +455,7 @@ export class GroupNodeConfig { const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], { defaultInput: true, }); + this.nodeDef.input.required[name] = config; this.newToOldWidgetMap[name] = { node, inputName }; @@ -477,9 +481,7 @@ export class GroupNodeConfig { this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); // Converted inputs have to be processed after all other nodes as they'll be at the end of the list - this.#convertedToProcess.push(() => - this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs) - ); + this.#convertedToProcess.push(() => this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)); return inputMapping; } @@ -490,8 +492,12 @@ export class GroupNodeConfig { // Add outputs for (let outputId = 0; outputId < def.output.length; outputId++) { const linksFrom = this.linksFrom[node.index]; - if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) { - // This output is linked internally so we can skip it + // If this output is linked internally we flag it to hide + const hasLink = linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]; + const customConfig = this.nodeData.config?.[node.index]?.output?.[outputId]; + const visible = customConfig?.visible ?? !hasLink; + this.outputVisibility.push(visible); + if (!visible) { continue; } @@ -500,11 +506,15 @@ export class GroupNodeConfig { this.nodeDef.output.push(def.output[outputId]); this.nodeDef.output_is_list.push(def.output_is_list[outputId]); - let label = def.output_name?.[outputId] ?? def.output[outputId]; - const output = node.outputs.find((o) => o.name === label); - if (output?.label) { - label = output.label; + let label = customConfig?.name; + if (!label) { + label = def.output_name?.[outputId] ?? def.output[outputId]; + const output = node.outputs.find((o) => o.name === label); + if (output?.label) { + label = output.label; + } } + let name = label; if (name in seenOutputs) { const prefix = `${node.title ?? node.type} `; @@ -677,6 +687,25 @@ export class GroupNodeHandler { return this.innerNodes; }; + this.node.recreate = async () => { + const id = this.node.id; + const sz = this.node.size; + const nodes = this.node.convertToNodes(); + + const groupNode = LiteGraph.createNode(this.node.type); + groupNode.id = id; + + // Reuse the existing nodes for this instance + groupNode.setInnerNodes(nodes); + groupNode[GROUP].populateWidgets(); + app.graph.add(groupNode); + groupNode.size = [Math.max(groupNode.size[0], sz[0]), Math.max(groupNode.size[1], sz[1])]; + + // Remove all converted nodes and relink them + groupNode[GROUP].replaceNodes(nodes); + return groupNode; + }; + this.node.convertToNodes = () => { const addInnerNodes = () => { const backup = localStorage.getItem("litegrapheditor_clipboard"); @@ -769,6 +798,7 @@ export class GroupNodeHandler { const slot = node.inputs[groupSlotId]; if (slot.link == null) continue; const link = app.graph.links[slot.link]; + if (!link) continue; // connect this node output to the input of another node const originNode = app.graph.getNodeById(link.origin_id); originNode.connect(link.origin_slot, newNode, +innerInputId); @@ -806,12 +836,23 @@ export class GroupNodeHandler { let optionIndex = options.findIndex((o) => o.content === "Outputs"); if (optionIndex === -1) optionIndex = options.length; else optionIndex++; - options.splice(optionIndex, 0, null, { - content: "Convert to nodes", - callback: () => { - return this.convertToNodes(); + options.splice( + optionIndex, + 0, + null, + { + content: "Convert to nodes", + callback: () => { + return this.convertToNodes(); + }, }, - }); + { + content: "Manage Group Node", + callback: () => { + new ManageGroupDialog(app).show(this.type); + }, + } + ); }; // Draw custom collapse icon to identity this as a group @@ -843,6 +884,7 @@ export class GroupNodeHandler { const r = onDrawForeground?.apply?.(this, arguments); if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) { const n = groupData.nodes[this.runningInternalNodeId]; + if(!n) return; const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`; ctx.save(); ctx.font = "12px sans-serif"; @@ -865,6 +907,28 @@ export class GroupNodeHandler { return onExecutionStart?.apply(this, arguments); }; + const self = this; + const onNodeCreated = this.node.onNodeCreated; + this.node.onNodeCreated = function () { + const config = self.groupData.nodeData.config; + if (config) { + for (const n in config) { + const inputs = config[n]?.input; + for (const w in inputs) { + if (inputs[w].visible !== false) continue; + const widgetName = self.groupData.oldToNewWidgetMap[n][w]; + const widget = this.widgets.find((w) => w.name === widgetName); + if (widget) { + widget.type = "hidden"; + widget.computeSize = () => [0, -4]; + } + } + } + } + + return onNodeCreated?.apply(this, arguments); + }; + function handleEvent(type, getId, getEvent) { const handler = ({ detail }) => { const id = getId(detail); @@ -902,6 +966,26 @@ export class GroupNodeHandler { api.removeEventListener("executing", executing); api.removeEventListener("executed", executed); }; + + this.node.refreshComboInNode = (defs) => { + // Update combo widget options + for (const widgetName in this.groupData.newToOldWidgetMap) { + const widget = this.node.widgets.find((w) => w.name === widgetName); + if (widget?.type === "combo") { + const old = this.groupData.newToOldWidgetMap[widgetName]; + const def = defs[old.node.type]; + const input = def?.input?.required?.[old.inputName] ?? def?.input?.optional?.[old.inputName]; + if (!input) continue; + + widget.options.values = input[0]; + + if (old.inputName !== "image" && !widget.options.values.includes(widget.value)) { + widget.value = widget.options.values[0]; + widget.callback(widget.value); + } + } + } + }; } updateInnerWidgets() { @@ -927,13 +1011,15 @@ export class GroupNodeHandler { continue; } else if (innerNode.type === "Reroute") { const rerouteLinks = this.groupData.linksFrom[old.node.index]; - for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { - const node = this.innerNodes[targetNodeId]; - const input = node.inputs[targetSlot]; - if (input.widget) { - const widget = node.widgets?.find((w) => w.name === input.widget.name); - if (widget) { - widget.value = newValue; + if (rerouteLinks) { + for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { + const node = this.innerNodes[targetNodeId]; + const input = node.inputs[targetSlot]; + if (input.widget) { + const widget = node.widgets?.find((w) => w.name === input.widget.name); + if (widget) { + widget.value = newValue; + } } } } @@ -975,7 +1061,7 @@ export class GroupNodeHandler { const [, , targetNodeId, targetNodeSlot] = link; const targetNode = this.groupData.nodeData.nodes[targetNodeId]; const inputs = targetNode.inputs; - const targetWidget = inputs?.[targetNodeSlot].widget; + const targetWidget = inputs?.[targetNodeSlot]?.widget; if (!targetWidget) return; const offset = inputs.length - (targetNode.widgets_values?.length ?? 0); @@ -983,13 +1069,12 @@ export class GroupNodeHandler { if (v == null) return; const widgetName = Object.values(map)[0]; - const widget = this.node.widgets.find(w => w.name === widgetName); - if(widget) { + const widget = this.node.widgets.find((w) => w.name === widgetName); + if (widget) { widget.value = v; } } - populateWidgets() { if (!this.node.widgets) return; @@ -1080,7 +1165,7 @@ export class GroupNodeHandler { } static getGroupData(node) { - return node.constructor?.nodeData?.[GROUP]; + return (node.nodeData ?? node.constructor?.nodeData)?.[GROUP]; } static isGroupNode(node) { @@ -1112,7 +1197,7 @@ export class GroupNodeHandler { } function addConvertToGroupOptions() { - function addOption(options, index) { + function addConvertOption(options, index) { const selected = Object.values(app.canvas.selected_nodes ?? {}); const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n)); options.splice(index + 1, null, { @@ -1124,12 +1209,25 @@ function addConvertToGroupOptions() { }); } + function addManageOption(options, index) { + const groups = app.graph.extra?.groupNodes; + const disabled = !groups || !Object.keys(groups).length; + options.splice(index + 1, null, { + content: `Manage Group Nodes`, + disabled, + callback: () => { + new ManageGroupDialog(app).show(); + }, + }); + } + // Add to canvas const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; LGraphCanvas.prototype.getCanvasMenuOptions = function () { const options = getCanvasMenuOptions.apply(this, arguments); const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length; - addOption(options, index); + addConvertOption(options, index); + addManageOption(options, index + 1); return options; }; @@ -1139,7 +1237,7 @@ function addConvertToGroupOptions() { const options = getNodeMenuOptions.apply(this, arguments); if (!GroupNodeHandler.isGroupNode(node)) { const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1; - addOption(options, index); + addConvertOption(options, index); } return options; }; @@ -1167,6 +1265,14 @@ const ext = { node[GROUP] = new GroupNodeHandler(node); } }, + async refreshComboInNodes(defs) { + // Re-register group nodes so new ones are created with the correct options + Object.assign(globalDefs, defs); + const nodes = app.graph.extra?.groupNodes; + if (nodes) { + await GroupNodeConfig.registerFromWorkflow(nodes, {}); + } + } }; app.registerExtension(ext); diff --git a/comfy/web/extensions/core/groupNodeManage.css b/comfy/web/extensions/core/groupNodeManage.css new file mode 100644 index 000000000..5ac89aee3 --- /dev/null +++ b/comfy/web/extensions/core/groupNodeManage.css @@ -0,0 +1,149 @@ +.comfy-group-manage { + background: var(--bg-color); + color: var(--fg-color); + padding: 0; + font-family: Arial, Helvetica, sans-serif; + border-color: black; + margin: 20vh auto; + max-height: 60vh; +} +.comfy-group-manage-outer { + max-height: 60vh; + min-width: 500px; + display: flex; + flex-direction: column; +} +.comfy-group-manage-outer > header { + display: flex; + align-items: center; + gap: 10px; + justify-content: space-between; + background: var(--comfy-menu-bg); + padding: 15px 20px; +} +.comfy-group-manage-outer > header select { + background: var(--comfy-input-bg); + border: 1px solid var(--border-color); + color: var(--input-text); + padding: 5px 10px; + border-radius: 5px; +} +.comfy-group-manage h2 { + margin: 0; + font-weight: normal; +} +.comfy-group-manage main { + display: flex; + overflow: hidden; +} +.comfy-group-manage .drag-handle { + font-weight: bold; +} +.comfy-group-manage-list { + border-right: 1px solid var(--comfy-menu-bg); +} +.comfy-group-manage-list ul { + margin: 40px 0 0; + padding: 0; + list-style: none; +} +.comfy-group-manage-list-items { + max-height: 70vh; + overflow-y: scroll; + overflow-x: hidden; +} +.comfy-group-manage-list li { + display: flex; + padding: 10px 20px 10px 10px; + cursor: pointer; + align-items: center; + gap: 5px; +} +.comfy-group-manage-list div { + display: flex; + flex-direction: column; +} +.comfy-group-manage-list li:not(.selected):hover div { + text-decoration: underline; +} +.comfy-group-manage-list li.selected { + background: var(--border-color); +} +.comfy-group-manage-list li span { + opacity: 0.7; + font-size: smaller; +} +.comfy-group-manage-node { + flex: auto; + background: var(--border-color); + display: flex; + flex-direction: column; +} +.comfy-group-manage-node > div { + overflow: auto; +} +.comfy-group-manage-node header { + display: flex; + background: var(--bg-color); + height: 40px; +} +.comfy-group-manage-node header a { + text-align: center; + flex: auto; + border-right: 1px solid var(--comfy-menu-bg); + border-bottom: 1px solid var(--comfy-menu-bg); + padding: 10px; + cursor: pointer; + font-size: 15px; +} +.comfy-group-manage-node header a:last-child { + border-right: none; +} +.comfy-group-manage-node header a:not(.active):hover { + text-decoration: underline; +} +.comfy-group-manage-node header a.active { + background: var(--border-color); + border-bottom: none; +} +.comfy-group-manage-node-page { + display: none; + overflow: auto; +} +.comfy-group-manage-node-page.active { + display: block; +} +.comfy-group-manage-node-page div { + padding: 10px; + display: flex; + align-items: center; + gap: 10px; +} +.comfy-group-manage-node-page input { + border: none; + color: var(--input-text); + background: var(--comfy-input-bg); + padding: 5px 10px; +} +.comfy-group-manage-node-page input[type="text"] { + flex: auto; +} +.comfy-group-manage-node-page label { + display: flex; + gap: 5px; + align-items: center; +} +.comfy-group-manage footer { + border-top: 1px solid var(--comfy-menu-bg); + padding: 10px; + display: flex; + gap: 10px; +} +.comfy-group-manage footer button { + font-size: 14px; + padding: 5px 10px; + border-radius: 0; +} +.comfy-group-manage footer button:first-child { + margin-right: auto; +} diff --git a/comfy/web/extensions/core/groupNodeManage.js b/comfy/web/extensions/core/groupNodeManage.js new file mode 100644 index 000000000..1ab338386 --- /dev/null +++ b/comfy/web/extensions/core/groupNodeManage.js @@ -0,0 +1,422 @@ +import { $el, ComfyDialog } from "../../scripts/ui.js"; +import { DraggableList } from "../../scripts/ui/draggableList.js"; +import { addStylesheet } from "../../scripts/utils.js"; +import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; + +addStylesheet(import.meta.url); + +const ORDER = Symbol(); + +function merge(target, source) { + if (typeof target === "object" && typeof source === "object") { + for (const key in source) { + const sv = source[key]; + if (typeof sv === "object") { + let tv = target[key]; + if (!tv) tv = target[key] = {}; + merge(tv, source[key]); + } else { + target[key] = sv; + } + } + } + + return target; +} + +export class ManageGroupDialog extends ComfyDialog { + /** @type { Record<"Inputs" | "Outputs" | "Widgets", {tab: HTMLAnchorElement, page: HTMLElement}> } */ + tabs = {}; + /** @type { number | null | undefined } */ + selectedNodeIndex; + /** @type { keyof ManageGroupDialog["tabs"] } */ + selectedTab = "Inputs"; + /** @type { string | undefined } */ + selectedGroup; + + /** @type { Record>> } */ + modifications = {}; + + get selectedNodeInnerIndex() { + return +this.nodeItems[this.selectedNodeIndex].dataset.nodeindex; + } + + constructor(app) { + super(); + this.app = app; + this.element = $el("dialog.comfy-group-manage", { + parent: document.body, + }); + } + + changeTab(tab) { + this.tabs[this.selectedTab].tab.classList.remove("active"); + this.tabs[this.selectedTab].page.classList.remove("active"); + this.tabs[tab].tab.classList.add("active"); + this.tabs[tab].page.classList.add("active"); + this.selectedTab = tab; + } + + changeNode(index, force) { + if (!force && this.selectedNodeIndex === index) return; + + if (this.selectedNodeIndex != null) { + this.nodeItems[this.selectedNodeIndex].classList.remove("selected"); + } + this.nodeItems[index].classList.add("selected"); + this.selectedNodeIndex = index; + + if (!this.buildInputsPage() && this.selectedTab === "Inputs") { + this.changeTab("Widgets"); + } + if (!this.buildWidgetsPage() && this.selectedTab === "Widgets") { + this.changeTab("Outputs"); + } + if (!this.buildOutputsPage() && this.selectedTab === "Outputs") { + this.changeTab("Inputs"); + } + + this.changeTab(this.selectedTab); + } + + getGroupData() { + this.groupNodeType = LiteGraph.registered_node_types["workflow/" + this.selectedGroup]; + this.groupNodeDef = this.groupNodeType.nodeData; + this.groupData = GroupNodeHandler.getGroupData(this.groupNodeType); + } + + changeGroup(group, reset = true) { + this.selectedGroup = group; + this.getGroupData(); + + const nodes = this.groupData.nodeData.nodes; + this.nodeItems = nodes.map((n, i) => + $el( + "li.draggable-item", + { + dataset: { + nodeindex: n.index + "", + }, + onclick: () => { + this.changeNode(i); + }, + }, + [ + $el("span.drag-handle"), + $el( + "div", + { + textContent: n.title ?? n.type, + }, + n.title + ? $el("span", { + textContent: n.type, + }) + : [] + ), + ] + ) + ); + + this.innerNodesList.replaceChildren(...this.nodeItems); + + if (reset) { + this.selectedNodeIndex = null; + this.changeNode(0); + } else { + const items = this.draggable.getAllItems(); + let index = items.findIndex(item => item.classList.contains("selected")); + if(index === -1) index = this.selectedNodeIndex; + this.changeNode(index, true); + } + + const ordered = [...nodes]; + this.draggable?.dispose(); + this.draggable = new DraggableList(this.innerNodesList, "li"); + this.draggable.addEventListener("dragend", ({ detail: { oldPosition, newPosition } }) => { + if (oldPosition === newPosition) return; + ordered.splice(newPosition, 0, ordered.splice(oldPosition, 1)[0]); + for (let i = 0; i < ordered.length; i++) { + this.storeModification({ nodeIndex: ordered[i].index, section: ORDER, prop: "order", value: i }); + } + }); + } + + storeModification({ nodeIndex, section, prop, value }) { + const groupMod = (this.modifications[this.selectedGroup] ??= {}); + const nodesMod = (groupMod.nodes ??= {}); + const nodeMod = (nodesMod[nodeIndex ?? this.selectedNodeInnerIndex] ??= {}); + const typeMod = (nodeMod[section] ??= {}); + if (typeof value === "object") { + const objMod = (typeMod[prop] ??= {}); + Object.assign(objMod, value); + } else { + typeMod[prop] = value; + } + } + + getEditElement(section, prop, value, placeholder, checked, checkable = true) { + if (value === placeholder) value = ""; + + const mods = this.modifications[this.selectedGroup]?.nodes?.[this.selectedNodeInnerIndex]?.[section]?.[prop]; + if (mods) { + if (mods.name != null) { + value = mods.name; + } + if (mods.visible != null) { + checked = mods.visible; + } + } + + return $el("div", [ + $el("input", { + value, + placeholder, + type: "text", + onchange: (e) => { + this.storeModification({ section, prop, value: { name: e.target.value } }); + }, + }), + $el("label", { textContent: "Visible" }, [ + $el("input", { + type: "checkbox", + checked, + disabled: !checkable, + onchange: (e) => { + this.storeModification({ section, prop, value: { visible: !!e.target.checked } }); + }, + }), + ]), + ]); + } + + buildWidgetsPage() { + const widgets = this.groupData.oldToNewWidgetMap[this.selectedNodeInnerIndex]; + const items = Object.keys(widgets ?? {}); + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.input; + this.widgetsPage.replaceChildren( + ...items.map((oldName) => { + return this.getEditElement("input", oldName, widgets[oldName], oldName, config?.[oldName]?.visible !== false); + }) + ); + return !!items.length; + } + + buildInputsPage() { + const inputs = this.groupData.nodeInputs[this.selectedNodeInnerIndex]; + const items = Object.keys(inputs ?? {}); + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.input; + this.inputsPage.replaceChildren( + ...items + .map((oldName) => { + let value = inputs[oldName]; + if (!value) { + return; + } + + return this.getEditElement("input", oldName, value, oldName, config?.[oldName]?.visible !== false); + }) + .filter(Boolean) + ); + return !!items.length; + } + + buildOutputsPage() { + const nodes = this.groupData.nodeData.nodes; + const innerNodeDef = this.groupData.getNodeDef(nodes[this.selectedNodeInnerIndex]); + const outputs = innerNodeDef?.output ?? []; + const groupOutputs = this.groupData.oldToNewOutputMap[this.selectedNodeInnerIndex]; + + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.output; + const node = this.groupData.nodeData.nodes[this.selectedNodeInnerIndex]; + const checkable = node.type !== "PrimitiveNode"; + this.outputsPage.replaceChildren( + ...outputs + .map((type, slot) => { + const groupOutputIndex = groupOutputs?.[slot]; + const oldName = innerNodeDef.output_name?.[slot] ?? type; + let value = config?.[slot]?.name; + const visible = config?.[slot]?.visible || groupOutputIndex != null; + if (!value || value === oldName) { + value = ""; + } + return this.getEditElement("output", slot, value, oldName, visible, checkable); + }) + .filter(Boolean) + ); + return !!outputs.length; + } + + show(type) { + const groupNodes = Object.keys(app.graph.extra?.groupNodes ?? {}).sort((a, b) => a.localeCompare(b)); + + this.innerNodesList = $el("ul.comfy-group-manage-list-items"); + this.widgetsPage = $el("section.comfy-group-manage-node-page"); + this.inputsPage = $el("section.comfy-group-manage-node-page"); + this.outputsPage = $el("section.comfy-group-manage-node-page"); + const pages = $el("div", [this.widgetsPage, this.inputsPage, this.outputsPage]); + + this.tabs = [ + ["Inputs", this.inputsPage], + ["Widgets", this.widgetsPage], + ["Outputs", this.outputsPage], + ].reduce((p, [name, page]) => { + p[name] = { + tab: $el("a", { + onclick: () => { + this.changeTab(name); + }, + textContent: name, + }), + page, + }; + return p; + }, {}); + + const outer = $el("div.comfy-group-manage-outer", [ + $el("header", [ + $el("h2", "Group Nodes"), + $el( + "select", + { + onchange: (e) => { + this.changeGroup(e.target.value); + }, + }, + groupNodes.map((g) => + $el("option", { + textContent: g, + selected: "workflow/" + g === type, + value: g, + }) + ) + ), + ]), + $el("main", [ + $el("section.comfy-group-manage-list", this.innerNodesList), + $el("section.comfy-group-manage-node", [ + $el( + "header", + Object.values(this.tabs).map((t) => t.tab) + ), + pages, + ]), + ]), + $el("footer", [ + $el( + "button.comfy-btn", + { + onclick: (e) => { + const node = app.graph._nodes.find((n) => n.type === "workflow/" + this.selectedGroup); + if (node) { + alert("This group node is in use in the current workflow, please first remove these."); + return; + } + if (confirm(`Are you sure you want to remove the node: "${this.selectedGroup}"`)) { + delete app.graph.extra.groupNodes[this.selectedGroup]; + LiteGraph.unregisterNodeType("workflow/" + this.selectedGroup); + } + this.show(); + }, + }, + "Delete Group Node" + ), + $el( + "button.comfy-btn", + { + onclick: async () => { + let nodesByType; + let recreateNodes = []; + const types = {}; + for (const g in this.modifications) { + const type = app.graph.extra.groupNodes[g]; + let config = (type.config ??= {}); + + let nodeMods = this.modifications[g]?.nodes; + if (nodeMods) { + const keys = Object.keys(nodeMods); + if (nodeMods[keys[0]][ORDER]) { + // If any node is reordered, they will all need sequencing + const orderedNodes = []; + const orderedMods = {}; + const orderedConfig = {}; + + for (const n of keys) { + const order = nodeMods[n][ORDER].order; + orderedNodes[order] = type.nodes[+n]; + orderedMods[order] = nodeMods[n]; + orderedNodes[order].index = order; + } + + // Rewrite links + for (const l of type.links) { + if (l[0] != null) l[0] = type.nodes[l[0]].index; + if (l[2] != null) l[2] = type.nodes[l[2]].index; + } + + // Rewrite externals + if (type.external) { + for (const ext of type.external) { + ext[0] = type.nodes[ext[0]]; + } + } + + // Rewrite modifications + for (const id of keys) { + if (config[id]) { + orderedConfig[type.nodes[id].index] = config[id]; + } + delete config[id]; + } + + type.nodes = orderedNodes; + nodeMods = orderedMods; + type.config = config = orderedConfig; + } + + merge(config, nodeMods); + } + + types[g] = type; + + if (!nodesByType) { + nodesByType = app.graph._nodes.reduce((p, n) => { + p[n.type] ??= []; + p[n.type].push(n); + return p; + }, {}); + } + + const nodes = nodesByType["workflow/" + g]; + if (nodes) recreateNodes.push(...nodes); + } + + await GroupNodeConfig.registerFromWorkflow(types, {}); + + for (const node of recreateNodes) { + node.recreate(); + } + + this.modifications = {}; + this.app.graph.setDirtyCanvas(true, true); + this.changeGroup(this.selectedGroup, false); + }, + }, + "Save" + ), + $el("button.comfy-btn", { onclick: () => this.element.close() }, "Close"), + ]), + ]); + + this.element.replaceChildren(outer); + this.changeGroup(type ? groupNodes.find((g) => "workflow/" + g === type) : groupNodes[0]); + this.element.showModal(); + + this.element.addEventListener("close", () => { + this.draggable?.dispose(); + }); + } +} \ No newline at end of file diff --git a/comfy/web/extensions/core/nodeTemplates.js b/comfy/web/extensions/core/nodeTemplates.js index bc9a10864..9350ba654 100644 --- a/comfy/web/extensions/core/nodeTemplates.js +++ b/comfy/web/extensions/core/nodeTemplates.js @@ -1,4 +1,5 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; import { ComfyDialog, $el } from "../../scripts/ui.js"; import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; @@ -20,16 +21,20 @@ import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; // Open the manage dialog and Drag and drop elements using the "Name:" label as handle const id = "Comfy.NodeTemplates"; +const file = "comfy.templates.json"; class ManageTemplates extends ComfyDialog { constructor() { super(); + this.load().then((v) => { + this.templates = v; + }); + this.element.classList.add("comfy-manage-templates"); - this.templates = this.load(); this.draggedEl = null; this.saveVisualCue = null; this.emptyImg = new Image(); - this.emptyImg.src = ''; + this.emptyImg.src = ""; this.importInput = $el("input", { type: "file", @@ -67,17 +72,50 @@ class ManageTemplates extends ComfyDialog { return btns; } - load() { - const templates = localStorage.getItem(id); - if (templates) { - return JSON.parse(templates); + async load() { + let templates = []; + if (app.storageLocation === "server") { + if (app.isNewUserSession) { + // New user so migrate existing templates + const json = localStorage.getItem(id); + if (json) { + templates = JSON.parse(json); + } + await api.storeUserData(file, json, { stringify: false }); + } else { + const res = await api.getUserData(file); + if (res.status === 200) { + try { + templates = await res.json(); + } catch (error) { + } + } else if (res.status !== 404) { + console.error(res.status + " " + res.statusText); + } + } } else { - return []; + const json = localStorage.getItem(id); + if (json) { + templates = JSON.parse(json); + } } + + return templates ?? []; } - store() { - localStorage.setItem(id, JSON.stringify(this.templates)); + async store() { + if(app.storageLocation === "server") { + const templates = JSON.stringify(this.templates, undefined, 4); + localStorage.setItem(id, templates); // Backwards compatibility + try { + await api.storeUserData(file, templates, { stringify: false }); + } catch (error) { + console.error(error); + alert(error.message); + } + } else { + localStorage.setItem(id, JSON.stringify(this.templates)); + } } async importAll() { @@ -85,14 +123,14 @@ class ManageTemplates extends ComfyDialog { if (file.type === "application/json" || file.name.endsWith(".json")) { const reader = new FileReader(); reader.onload = async () => { - var importFile = JSON.parse(reader.result); - if (importFile && importFile?.templates) { + const importFile = JSON.parse(reader.result); + if (importFile?.templates) { for (const template of importFile.templates) { if (template?.name && template?.data) { this.templates.push(template); } } - this.store(); + await this.store(); } }; await reader.readAsText(file); @@ -159,7 +197,7 @@ class ManageTemplates extends ComfyDialog { e.currentTarget.style.border = "1px dashed transparent"; e.currentTarget.removeAttribute("draggable"); - // rearrange the elements in the localStorage + // rearrange the elements this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { var prev_i = el.dataset.id; diff --git a/comfy/web/extensions/core/undoRedo.js b/comfy/web/extensions/core/undoRedo.js index 3cb137520..900eed2a7 100644 --- a/comfy/web/extensions/core/undoRedo.js +++ b/comfy/web/extensions/core/undoRedo.js @@ -1,4 +1,5 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js" const MAX_HISTORY = 50; @@ -15,6 +16,7 @@ function checkState() { } activeState = clone(currentState); redo.length = 0; + api.dispatchEvent(new CustomEvent("graphChanged", { detail: activeState })); } } @@ -92,7 +94,7 @@ const undoRedo = async (e) => { }; const bindInput = (activeEl) => { - if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") { + if (activeEl && activeEl.tagName !== "CANVAS" && activeEl.tagName !== "BODY") { for (const evt of ["change", "input", "blur"]) { if (`on${evt}` in activeEl) { const listener = () => { @@ -106,15 +108,23 @@ const bindInput = (activeEl) => { } }; +let keyIgnored = false; window.addEventListener( "keydown", (e) => { requestAnimationFrame(async () => { - const activeEl = document.activeElement; - if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { - // Ignore events on inputs, they have their native history - return; + let activeEl; + // If we are auto queue in change mode then we do want to trigger on inputs + if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") { + activeEl = document.activeElement; + if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { + // Ignore events on inputs, they have their native history + return; + } } + + keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta"; + if (keyIgnored) return; // Check if this is a ctrl+z ctrl+y if (await undoRedo(e)) return; @@ -127,11 +137,23 @@ window.addEventListener( true ); +window.addEventListener("keyup", (e) => { + if (keyIgnored) { + keyIgnored = false; + checkState(); + } +}); + // Handle clicking DOM elements (e.g. widgets) window.addEventListener("mouseup", () => { checkState(); }); +// Handle prompt queue event for dynamic widget changes +api.addEventListener("promptQueued", () => { + checkState(); +}); + // Handle litegraph clicks const processMouseUp = LGraphCanvas.prototype.processMouseUp; LGraphCanvas.prototype.processMouseUp = function (e) { @@ -145,3 +167,11 @@ LGraphCanvas.prototype.processMouseDown = function (e) { checkState(); return v; }; + +// Handle litegraph context menu for COMBO widgets +const close = LiteGraph.ContextMenu.prototype.close; +LiteGraph.ContextMenu.prototype.close = function(e) { + const v = close.apply(this, arguments); + checkState(); + return v; +} \ No newline at end of file diff --git a/comfy/web/index.html b/comfy/web/index.html index 41bc246c0..094db9d15 100644 --- a/comfy/web/index.html +++ b/comfy/web/index.html @@ -16,5 +16,33 @@ window.graph = app.graph; - + + + diff --git a/comfy/web/jsconfig.json b/comfy/web/jsconfig.json index 57403d8cf..b65fa2746 100644 --- a/comfy/web/jsconfig.json +++ b/comfy/web/jsconfig.json @@ -3,7 +3,8 @@ "baseUrl": ".", "paths": { "/*": ["./*"] - } + }, + "lib": ["DOM", "ES2022"] }, "include": ["."] } diff --git a/comfy/web/lib/litegraph.core.js b/comfy/web/lib/litegraph.core.js index 434c4a83b..080e0ef47 100644 --- a/comfy/web/lib/litegraph.core.js +++ b/comfy/web/lib/litegraph.core.js @@ -11496,7 +11496,7 @@ LGraphNode.prototype.executeAction = function(action) } timeout_close = setTimeout(function() { dialog.close(); - }, 500); + }, typeof options.hide_on_mouse_leave === "number" ? options.hide_on_mouse_leave : 500); }); // if filtering, check focus changed to comboboxes and prevent closing if (options.do_type_filter){ diff --git a/comfy/web/scripts/api.js b/comfy/web/scripts/api.js index 9aa7528af..3a9bcc87a 100644 --- a/comfy/web/scripts/api.js +++ b/comfy/web/scripts/api.js @@ -12,6 +12,13 @@ class ComfyApi extends EventTarget { } fetchApi(route, options) { + if (!options) { + options = {}; + } + if (!options.headers) { + options.headers = {}; + } + options.headers["Comfy-User"] = this.user; return fetch(this.apiURL(route), options); } @@ -315,6 +322,99 @@ class ComfyApi extends EventTarget { async interrupt() { await this.#postItem("interrupt", null); } + + /** + * Gets user configuration data and where data should be stored + * @returns { Promise<{ storage: "server" | "browser", users?: Promise, migrated?: boolean }> } + */ + async getUserConfig() { + return (await this.fetchApi("/users")).json(); + } + + /** + * Creates a new user + * @param { string } username + * @returns The fetch response + */ + createUser(username) { + return this.fetchApi("/users", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ username }), + }); + } + + /** + * Gets all setting values for the current user + * @returns { Promise } A dictionary of id -> value + */ + async getSettings() { + return (await this.fetchApi("/settings")).json(); + } + + /** + * Gets a setting for the current user + * @param { string } id The id of the setting to fetch + * @returns { Promise } The setting value + */ + async getSetting(id) { + return (await this.fetchApi(`/settings/${encodeURIComponent(id)}`)).json(); + } + + /** + * Stores a dictionary of settings for the current user + * @param { Record } settings Dictionary of setting id -> value to save + * @returns { Promise } + */ + async storeSettings(settings) { + return this.fetchApi(`/settings`, { + method: "POST", + body: JSON.stringify(settings) + }); + } + + /** + * Stores a setting for the current user + * @param { string } id The id of the setting to update + * @param { unknown } value The value of the setting + * @returns { Promise } + */ + async storeSetting(id, value) { + return this.fetchApi(`/settings/${encodeURIComponent(id)}`, { + method: "POST", + body: JSON.stringify(value) + }); + } + + /** + * Gets a user data file for the current user + * @param { string } file The name of the userdata file to load + * @param { RequestInit } [options] + * @returns { Promise } The fetch response object + */ + async getUserData(file, options) { + return this.fetchApi(`/userdata/${encodeURIComponent(file)}`, options); + } + + /** + * Stores a user data file for the current user + * @param { string } file The name of the userdata file to save + * @param { unknown } data The data to save to the file + * @param { RequestInit & { stringify?: boolean, throwOnError?: boolean } } [options] + * @returns { Promise } + */ + async storeUserData(file, data, options = { stringify: true, throwOnError: true }) { + const resp = await this.fetchApi(`/userdata/${encodeURIComponent(file)}`, { + method: "POST", + body: options?.stringify ? JSON.stringify(data) : data, + ...options, + }); + if (resp.status !== 200) { + throw new Error(`Error storing user data file '${file}': ${resp.status} ${(await resp).statusText}`); + } + } } export const api = new ComfyApi(); diff --git a/comfy/web/scripts/app.js b/comfy/web/scripts/app.js index 7353f5a3b..6df393ba6 100644 --- a/comfy/web/scripts/app.js +++ b/comfy/web/scripts/app.js @@ -1,5 +1,5 @@ import { ComfyLogging } from "./logging.js"; -import { ComfyWidgets } from "./widgets.js"; +import { ComfyWidgets, initWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; @@ -269,6 +269,71 @@ export class ComfyApp { * @param {*} node The node to add the menu handler */ #addNodeContextMenuHandler(node) { + function getCopyImageOption(img) { + if (typeof window.ClipboardItem === "undefined") return []; + return [ + { + content: "Copy Image", + callback: async () => { + const url = new URL(img.src); + url.searchParams.delete("preview"); + + const writeImage = async (blob) => { + await navigator.clipboard.write([ + new ClipboardItem({ + [blob.type]: blob, + }), + ]); + }; + + try { + const data = await fetch(url); + const blob = await data.blob(); + try { + await writeImage(blob); + } catch (error) { + // Chrome seems to only support PNG on write, convert and try again + if (blob.type !== "image/png") { + const canvas = $el("canvas", { + width: img.naturalWidth, + height: img.naturalHeight, + }); + const ctx = canvas.getContext("2d"); + let image; + if (typeof window.createImageBitmap === "undefined") { + image = new Image(); + const p = new Promise((resolve, reject) => { + image.onload = resolve; + image.onerror = reject; + }).finally(() => { + URL.revokeObjectURL(image.src); + }); + image.src = URL.createObjectURL(blob); + await p; + } else { + image = await createImageBitmap(blob); + } + try { + ctx.drawImage(image, 0, 0); + canvas.toBlob(writeImage, "image/png"); + } finally { + if (typeof image.close === "function") { + image.close(); + } + } + + return; + } + throw error; + } + } catch (error) { + alert("Error copying image: " + (error.message ?? error)); + } + }, + }, + ]; + } + node.prototype.getExtraMenuOptions = function (_, options) { if (this.imgs) { // If this node has images then we add an open in new tab item @@ -286,16 +351,17 @@ export class ComfyApp { content: "Open Image", callback: () => { let url = new URL(img.src); - url.searchParams.delete('preview'); - window.open(url, "_blank") + url.searchParams.delete("preview"); + window.open(url, "_blank"); }, }, + ...getCopyImageOption(img), { content: "Save Image", callback: () => { const a = document.createElement("a"); let url = new URL(img.src); - url.searchParams.delete('preview'); + url.searchParams.delete("preview"); a.href = url; a.setAttribute("download", new URLSearchParams(url.search).get("filename")); document.body.append(a); @@ -308,33 +374,41 @@ export class ComfyApp { } options.push({ - content: "Bypass", - callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); } - }); + content: "Bypass", + callback: (obj) => { + if (this.mode === 4) this.mode = 0; + else this.mode = 4; + this.graph.change(); + }, + }); // prevent conflict of clipspace content - if(!ComfyApp.clipspace_return_node) { + if (!ComfyApp.clipspace_return_node) { options.push({ - content: "Copy (Clipspace)", - callback: (obj) => { ComfyApp.copyToClipspace(this); } - }); + content: "Copy (Clipspace)", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + }, + }); - if(ComfyApp.clipspace != null) { + if (ComfyApp.clipspace != null) { options.push({ - content: "Paste (Clipspace)", - callback: () => { ComfyApp.pasteFromClipspace(this); } - }); + content: "Paste (Clipspace)", + callback: () => { + ComfyApp.pasteFromClipspace(this); + }, + }); } - if(ComfyApp.isImageNode(this)) { + if (ComfyApp.isImageNode(this)) { options.push({ - content: "Open in MaskEditor", - callback: (obj) => { - ComfyApp.copyToClipspace(this); - ComfyApp.clipspace_return_node = this; - ComfyApp.open_maskeditor(); - } - }); + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); + }, + }); } } }; @@ -1291,10 +1365,92 @@ export class ComfyApp { await Promise.all(extensionPromises); } + async #migrateSettings() { + this.isNewUserSession = true; + // Store all current settings + const settings = Object.keys(this.ui.settings).reduce((p, n) => { + const v = localStorage[`Comfy.Settings.${n}`]; + if (v) { + try { + p[n] = JSON.parse(v); + } catch (error) {} + } + return p; + }, {}); + + await api.storeSettings(settings); + } + + async #setUser() { + const userConfig = await api.getUserConfig(); + this.storageLocation = userConfig.storage; + if (typeof userConfig.migrated == "boolean") { + // Single user mode migrated true/false for if the default user is created + if (!userConfig.migrated && this.storageLocation === "server") { + // Default user not created yet + await this.#migrateSettings(); + } + return; + } + + this.multiUserServer = true; + let user = localStorage["Comfy.userId"]; + const users = userConfig.users ?? {}; + if (!user || !users[user]) { + // This will rarely be hit so move the loading to on demand + const { UserSelectionScreen } = await import("./ui/userSelection.js"); + + this.ui.menuContainer.style.display = "none"; + const { userId, username, created } = await new UserSelectionScreen().show(users, user); + this.ui.menuContainer.style.display = ""; + + user = userId; + localStorage["Comfy.userName"] = username; + localStorage["Comfy.userId"] = user; + + if (created) { + api.user = user; + await this.#migrateSettings(); + } + } + + api.user = user; + + this.ui.settings.addSetting({ + id: "Comfy.SwitchUser", + name: "Switch User", + type: (name) => { + let currentUser = localStorage["Comfy.userName"]; + if (currentUser) { + currentUser = ` (${currentUser})`; + } + return $el("tr", [ + $el("td", [ + $el("label", { + textContent: name, + }), + ]), + $el("td", [ + $el("button", { + textContent: name + (currentUser ?? ""), + onclick: () => { + delete localStorage["Comfy.userId"]; + delete localStorage["Comfy.userName"]; + window.location.reload(); + }, + }), + ]), + ]); + }, + }); + } + /** * Set up the app on the page */ async setup() { + await this.#setUser(); + await this.ui.settings.load(); await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM @@ -1338,6 +1494,7 @@ export class ComfyApp { await this.#invokeExtensionsAsync("init"); await this.registerNodes(); + initWidgets(this); // Load previous workflow let restored = false; @@ -1692,6 +1849,14 @@ export class ComfyApp { */ async graphToPrompt() { for (const outerNode of this.graph.computeExecutionOrder(false)) { + if (outerNode.widgets) { + for (const widget of outerNode.widgets) { + // Allow widgets to run callbacks before a prompt has been queued + // e.g. random seed before every gen + widget.beforeQueued?.(); + } + } + const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode]; for (const node of innerNodes) { if (node.isVirtualNode) { @@ -1903,6 +2068,7 @@ export class ComfyApp { } finally { this.#processingQueue = false; } + api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } })); } /** @@ -2046,6 +2212,8 @@ export class ComfyApp { } } } + + await this.#invokeExtensionsAsync("refreshComboInNodes", defs); } /** diff --git a/comfy/web/scripts/domWidget.js b/comfy/web/scripts/domWidget.js index eb0742d38..d5eeebdbd 100644 --- a/comfy/web/scripts/domWidget.js +++ b/comfy/web/scripts/domWidget.js @@ -239,7 +239,8 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { node.flags?.collapsed || (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || widget.computedHeight <= 0 || - widget.type === "converted-widget"; + widget.type === "converted-widget"|| + widget.type === "hidden"; element.hidden = hidden; element.style.display = hidden ? "none" : null; if (hidden) { diff --git a/comfy/web/scripts/logging.js b/comfy/web/scripts/logging.js index c73462e1e..875dd970b 100644 --- a/comfy/web/scripts/logging.js +++ b/comfy/web/scripts/logging.js @@ -269,6 +269,9 @@ export class ComfyLogging { id: settingId, name: settingId, defaultValue: true, + onChange: (value) => { + this.enabled = value; + }, type: (name, setter, value) => { return $el("tr", [ $el("td", [ @@ -283,7 +286,7 @@ export class ComfyLogging { type: "checkbox", checked: value, onchange: (event) => { - setter((this.enabled = event.target.checked)); + setter(event.target.checked); }, }), $el("button", { diff --git a/comfy/web/scripts/ui.js b/comfy/web/scripts/ui.js index ebaf86fe4..d4835c6e4 100644 --- a/comfy/web/scripts/ui.js +++ b/comfy/web/scripts/ui.js @@ -1,5 +1,23 @@ -import {api} from "./api.js"; +import { api } from "./api.js"; +import { ComfyDialog as _ComfyDialog } from "./ui/dialog.js"; +import { toggleSwitch } from "./ui/toggleSwitch.js"; +import { ComfySettingsDialog } from "./ui/settings.js"; +export const ComfyDialog = _ComfyDialog; + +/** + * + * @param { string } tag HTML Element Tag and optional classes e.g. div.class1.class2 + * @param { string | Element | Element[] | { + * parent?: Element, + * $?: (el: Element) => void, + * dataset?: DOMStringMap, + * style?: CSSStyleDeclaration, + * for?: string + * } | undefined } propsOrChildren + * @param { Element[] | undefined } [children] + * @returns + */ export function $el(tag, propsOrChildren, children) { const split = tag.split("."); const element = document.createElement(split.shift()); @@ -8,6 +26,11 @@ export function $el(tag, propsOrChildren, children) { } if (propsOrChildren) { + if (typeof propsOrChildren === "string") { + propsOrChildren = { textContent: propsOrChildren }; + } else if (propsOrChildren instanceof Element) { + propsOrChildren = [propsOrChildren]; + } if (Array.isArray(propsOrChildren)) { element.append(...propsOrChildren); } else { @@ -31,7 +54,7 @@ export function $el(tag, propsOrChildren, children) { Object.assign(element, propsOrChildren); if (children) { - element.append(...children); + element.append(...(children instanceof Array ? children : [children])); } if (parent) { @@ -167,267 +190,6 @@ function dragElement(dragEl, settings) { } } -export class ComfyDialog { - constructor() { - this.element = $el("div.comfy-modal", {parent: document.body}, [ - $el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]), - ]); - } - - createButtons() { - return [ - $el("button", { - type: "button", - textContent: "Close", - onclick: () => this.close(), - }), - ]; - } - - close() { - this.element.style.display = "none"; - } - - show(html) { - if (typeof html === "string") { - this.textElement.innerHTML = html; - } else { - this.textElement.replaceChildren(html); - } - this.element.style.display = "flex"; - } -} - -class ComfySettingsDialog extends ComfyDialog { - constructor() { - super(); - this.element = $el("dialog", { - id: "comfy-settings-dialog", - parent: document.body, - }, [ - $el("table.comfy-modal-content.comfy-table", [ - $el("caption", {textContent: "Settings"}), - $el("tbody", {$: (tbody) => (this.textElement = tbody)}), - $el("button", { - type: "button", - textContent: "Close", - style: { - cursor: "pointer", - }, - onclick: () => { - this.element.close(); - }, - }), - ]), - ]); - this.settings = []; - } - - getSettingValue(id, defaultValue) { - const settingId = "Comfy.Settings." + id; - const v = localStorage[settingId]; - return v == null ? defaultValue : JSON.parse(v); - } - - setSettingValue(id, value) { - const settingId = "Comfy.Settings." + id; - localStorage[settingId] = JSON.stringify(value); - } - - addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) { - if (!id) { - throw new Error("Settings must have an ID"); - } - - if (this.settings.find((s) => s.id === id)) { - throw new Error(`Setting ${id} of type ${type} must have a unique ID.`); - } - - const settingId = `Comfy.Settings.${id}`; - const v = localStorage[settingId]; - let value = v == null ? defaultValue : JSON.parse(v); - - // Trigger initial setting of value - if (onChange) { - onChange(value, undefined); - } - - this.settings.push({ - render: () => { - const setter = (v) => { - if (onChange) { - onChange(v, value); - } - localStorage[settingId] = JSON.stringify(v); - value = v; - }; - value = this.getSettingValue(id, defaultValue); - - let element; - const htmlID = id.replaceAll(".", "-"); - - const labelCell = $el("td", [ - $el("label", { - for: htmlID, - classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""], - textContent: name, - }) - ]); - - if (typeof type === "function") { - element = type(name, setter, value, attrs); - } else { - switch (type) { - case "boolean": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - id: htmlID, - type: "checkbox", - checked: value, - onchange: (event) => { - const isChecked = event.target.checked; - if (onChange !== undefined) { - onChange(isChecked) - } - this.setSettingValue(id, isChecked); - }, - }), - ]), - ]) - break; - case "number": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - type, - value, - id: htmlID, - oninput: (e) => { - setter(e.target.value); - }, - ...attrs - }), - ]), - ]); - break; - case "slider": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("div", { - style: { - display: "grid", - gridAutoFlow: "column", - }, - }, [ - $el("input", { - ...attrs, - value, - type: "range", - oninput: (e) => { - setter(e.target.value); - e.target.nextElementSibling.value = e.target.value; - }, - }), - $el("input", { - ...attrs, - value, - id: htmlID, - type: "number", - style: {maxWidth: "4rem"}, - oninput: (e) => { - setter(e.target.value); - e.target.previousElementSibling.value = e.target.value; - }, - }), - ]), - ]), - ]); - break; - case "combo": - element = $el("tr", [ - labelCell, - $el("td", [ - $el( - "select", - { - oninput: (e) => { - setter(e.target.value); - }, - }, - (typeof options === "function" ? options(value) : options || []).map((opt) => { - if (typeof opt === "string") { - opt = { text: opt }; - } - const v = opt.value ?? opt.text; - return $el("option", { - value: v, - textContent: opt.text, - selected: value + "" === v + "", - }); - }) - ), - ]), - ]); - break; - case "text": - default: - if (type !== "text") { - console.warn(`Unsupported setting type '${type}, defaulting to text`); - } - - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - value, - id: htmlID, - oninput: (e) => { - setter(e.target.value); - }, - ...attrs, - }), - ]), - ]); - break; - } - } - if (tooltip) { - element.title = tooltip; - } - - return element; - }, - }); - - const self = this; - return { - get value() { - return self.getSettingValue(id, defaultValue); - }, - set value(v) { - self.setSettingValue(id, v); - }, - }; - } - - show() { - this.textElement.replaceChildren( - $el("tr", { - style: {display: "none"}, - }, [ - $el("th"), - $el("th", {style: {width: "33%"}}) - ]), - ...this.settings.map((s) => s.render()), - ) - this.element.showModal(); - } -} - class ComfyList { #type; #text; @@ -526,7 +288,7 @@ export class ComfyUI { constructor(app) { this.app = app; this.dialog = new ComfyDialog(); - this.settings = new ComfySettingsDialog(); + this.settings = new ComfySettingsDialog(app); this.batchCount = 1; this.lastQueueSize = 0; @@ -607,6 +369,31 @@ export class ComfyUI { }, }); + const autoQueueModeEl = toggleSwitch( + "autoQueueMode", + [ + { text: "instant", tooltip: "A new prompt will be queued as soon as the queue reaches 0" }, + { text: "change", tooltip: "A new prompt will be queued when the queue is at 0 and the graph is/has changed" }, + ], + { + onChange: (value) => { + this.autoQueueMode = value.item.value; + }, + } + ); + autoQueueModeEl.style.display = "none"; + + api.addEventListener("graphChanged", () => { + if (this.autoQueueMode === "change" && this.autoQueueEnabled === true) { + if (this.lastQueueSize === 0) { + this.graphHasChanged = false; + app.queuePrompt(0, this.batchCount); + } else { + this.graphHasChanged = true; + } + } + }); + this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [ $el("div.drag-handle", { style: { @@ -633,6 +420,7 @@ export class ComfyUI { document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none"; this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1; document.getElementById("autoQueueCheckbox").checked = false; + this.autoQueueEnabled = false; }, }), ]), @@ -664,20 +452,22 @@ export class ComfyUI { }, }), ]), - $el("div",[ $el("label",{ for:"autoQueueCheckbox", innerHTML: "Auto Queue" - // textContent: "Auto Queue" }), $el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, title: "Automatically queue prompt when the queue size hits 0", - + onchange: (e) => { + this.autoQueueEnabled = e.target.checked; + autoQueueModeEl.style.display = this.autoQueueEnabled ? "" : "none"; + } }), + autoQueueModeEl ]) ]), $el("div.comfy-menu-btns", [ @@ -811,10 +601,13 @@ export class ComfyUI { if ( this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && - document.getElementById("autoQueueCheckbox").checked && - ! app.lastExecutionError + this.autoQueueEnabled && + (this.autoQueueMode === "instant" || this.graphHasChanged) && + !app.lastExecutionError ) { app.queuePrompt(0, this.batchCount); + status.exec_info.queue_remaining += this.batchCount; + this.graphHasChanged = false; } this.lastQueueSize = status.exec_info.queue_remaining; } diff --git a/comfy/web/scripts/ui/dialog.js b/comfy/web/scripts/ui/dialog.js new file mode 100644 index 000000000..aee93b3c8 --- /dev/null +++ b/comfy/web/scripts/ui/dialog.js @@ -0,0 +1,32 @@ +import { $el } from "../ui.js"; + +export class ComfyDialog { + constructor() { + this.element = $el("div.comfy-modal", { parent: document.body }, [ + $el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]), + ]); + } + + createButtons() { + return [ + $el("button", { + type: "button", + textContent: "Close", + onclick: () => this.close(), + }), + ]; + } + + close() { + this.element.style.display = "none"; + } + + show(html) { + if (typeof html === "string") { + this.textElement.innerHTML = html; + } else { + this.textElement.replaceChildren(html); + } + this.element.style.display = "flex"; + } +} diff --git a/comfy/web/scripts/ui/draggableList.js b/comfy/web/scripts/ui/draggableList.js new file mode 100644 index 000000000..d53594886 --- /dev/null +++ b/comfy/web/scripts/ui/draggableList.js @@ -0,0 +1,287 @@ +// @ts-check +/* + Original implementation: + https://github.com/TahaSh/drag-to-reorder + MIT License + + Copyright (c) 2023 Taha Shashtari + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ + +import { $el } from "../ui.js"; + +$el("style", { + parent: document.head, + textContent: ` + .draggable-item { + position: relative; + will-change: transform; + user-select: none; + } + .draggable-item.is-idle { + transition: 0.25s ease transform; + } + .draggable-item.is-draggable { + z-index: 10; + } + ` +}); + +export class DraggableList extends EventTarget { + listContainer; + draggableItem; + pointerStartX; + pointerStartY; + scrollYMax; + itemsGap = 0; + items = []; + itemSelector; + handleClass = "drag-handle"; + off = []; + offDrag = []; + + constructor(element, itemSelector) { + super(); + this.listContainer = element; + this.itemSelector = itemSelector; + + if (!this.listContainer) return; + + this.off.push(this.on(this.listContainer, "mousedown", this.dragStart)); + this.off.push(this.on(this.listContainer, "touchstart", this.dragStart)); + this.off.push(this.on(document, "mouseup", this.dragEnd)); + this.off.push(this.on(document, "touchend", this.dragEnd)); + } + + getAllItems() { + if (!this.items?.length) { + this.items = Array.from(this.listContainer.querySelectorAll(this.itemSelector)); + this.items.forEach((element) => { + element.classList.add("is-idle"); + }); + } + return this.items; + } + + getIdleItems() { + return this.getAllItems().filter((item) => item.classList.contains("is-idle")); + } + + isItemAbove(item) { + return item.hasAttribute("data-is-above"); + } + + isItemToggled(item) { + return item.hasAttribute("data-is-toggled"); + } + + on(source, event, listener, options) { + listener = listener.bind(this); + source.addEventListener(event, listener, options); + return () => source.removeEventListener(event, listener); + } + + dragStart(e) { + if (e.target.classList.contains(this.handleClass)) { + this.draggableItem = e.target.closest(this.itemSelector); + } + + if (!this.draggableItem) return; + + this.pointerStartX = e.clientX || e.touches[0].clientX; + this.pointerStartY = e.clientY || e.touches[0].clientY; + this.scrollYMax = this.listContainer.scrollHeight - this.listContainer.clientHeight; + + this.setItemsGap(); + this.initDraggableItem(); + this.initItemsState(); + + this.offDrag.push(this.on(document, "mousemove", this.drag)); + this.offDrag.push(this.on(document, "touchmove", this.drag, { passive: false })); + + this.dispatchEvent( + new CustomEvent("dragstart", { + detail: { element: this.draggableItem, position: this.getAllItems().indexOf(this.draggableItem) }, + }) + ); + } + + setItemsGap() { + if (this.getIdleItems().length <= 1) { + this.itemsGap = 0; + return; + } + + const item1 = this.getIdleItems()[0]; + const item2 = this.getIdleItems()[1]; + + const item1Rect = item1.getBoundingClientRect(); + const item2Rect = item2.getBoundingClientRect(); + + this.itemsGap = Math.abs(item1Rect.bottom - item2Rect.top); + } + + initItemsState() { + this.getIdleItems().forEach((item, i) => { + if (this.getAllItems().indexOf(this.draggableItem) > i) { + item.dataset.isAbove = ""; + } + }); + } + + initDraggableItem() { + this.draggableItem.classList.remove("is-idle"); + this.draggableItem.classList.add("is-draggable"); + } + + drag(e) { + if (!this.draggableItem) return; + + e.preventDefault(); + + const clientX = e.clientX || e.touches[0].clientX; + const clientY = e.clientY || e.touches[0].clientY; + + const listRect = this.listContainer.getBoundingClientRect(); + + if (clientY > listRect.bottom) { + if (this.listContainer.scrollTop < this.scrollYMax) { + this.listContainer.scrollBy(0, 10); + this.pointerStartY -= 10; + } + } else if (clientY < listRect.top && this.listContainer.scrollTop > 0) { + this.pointerStartY += 10; + this.listContainer.scrollBy(0, -10); + } + + const pointerOffsetX = clientX - this.pointerStartX; + const pointerOffsetY = clientY - this.pointerStartY; + + this.updateIdleItemsStateAndPosition(); + this.draggableItem.style.transform = `translate(${pointerOffsetX}px, ${pointerOffsetY}px)`; + } + + updateIdleItemsStateAndPosition() { + const draggableItemRect = this.draggableItem.getBoundingClientRect(); + const draggableItemY = draggableItemRect.top + draggableItemRect.height / 2; + + // Update state + this.getIdleItems().forEach((item) => { + const itemRect = item.getBoundingClientRect(); + const itemY = itemRect.top + itemRect.height / 2; + if (this.isItemAbove(item)) { + if (draggableItemY <= itemY) { + item.dataset.isToggled = ""; + } else { + delete item.dataset.isToggled; + } + } else { + if (draggableItemY >= itemY) { + item.dataset.isToggled = ""; + } else { + delete item.dataset.isToggled; + } + } + }); + + // Update position + this.getIdleItems().forEach((item) => { + if (this.isItemToggled(item)) { + const direction = this.isItemAbove(item) ? 1 : -1; + item.style.transform = `translateY(${direction * (draggableItemRect.height + this.itemsGap)}px)`; + } else { + item.style.transform = ""; + } + }); + } + + dragEnd() { + if (!this.draggableItem) return; + + this.applyNewItemsOrder(); + this.cleanup(); + } + + applyNewItemsOrder() { + const reorderedItems = []; + + let oldPosition = -1; + this.getAllItems().forEach((item, index) => { + if (item === this.draggableItem) { + oldPosition = index; + return; + } + if (!this.isItemToggled(item)) { + reorderedItems[index] = item; + return; + } + const newIndex = this.isItemAbove(item) ? index + 1 : index - 1; + reorderedItems[newIndex] = item; + }); + + for (let index = 0; index < this.getAllItems().length; index++) { + const item = reorderedItems[index]; + if (typeof item === "undefined") { + reorderedItems[index] = this.draggableItem; + } + } + + reorderedItems.forEach((item) => { + this.listContainer.appendChild(item); + }); + + this.items = reorderedItems; + + this.dispatchEvent( + new CustomEvent("dragend", { + detail: { element: this.draggableItem, oldPosition, newPosition: reorderedItems.indexOf(this.draggableItem) }, + }) + ); + } + + cleanup() { + this.itemsGap = 0; + this.items = []; + this.unsetDraggableItem(); + this.unsetItemState(); + + this.offDrag.forEach((f) => f()); + this.offDrag = []; + } + + unsetDraggableItem() { + this.draggableItem.style = null; + this.draggableItem.classList.remove("is-draggable"); + this.draggableItem.classList.add("is-idle"); + this.draggableItem = null; + } + + unsetItemState() { + this.getIdleItems().forEach((item, i) => { + delete item.dataset.isAbove; + delete item.dataset.isToggled; + item.style.transform = ""; + }); + } + + dispose() { + this.off.forEach((f) => f()); + } +} diff --git a/comfy/web/scripts/ui/settings.js b/comfy/web/scripts/ui/settings.js new file mode 100644 index 000000000..1cdba5cfe --- /dev/null +++ b/comfy/web/scripts/ui/settings.js @@ -0,0 +1,307 @@ +import { $el } from "../ui.js"; +import { api } from "../api.js"; +import { ComfyDialog } from "./dialog.js"; + +export class ComfySettingsDialog extends ComfyDialog { + constructor(app) { + super(); + this.app = app; + this.settingsValues = {}; + this.settingsLookup = {}; + this.element = $el( + "dialog", + { + id: "comfy-settings-dialog", + parent: document.body, + }, + [ + $el("table.comfy-modal-content.comfy-table", [ + $el("caption", { textContent: "Settings" }), + $el("tbody", { $: (tbody) => (this.textElement = tbody) }), + $el("button", { + type: "button", + textContent: "Close", + style: { + cursor: "pointer", + }, + onclick: () => { + this.element.close(); + }, + }), + ]), + ] + ); + } + + get settings() { + return Object.values(this.settingsLookup); + } + + async load() { + if (this.app.storageLocation === "browser") { + this.settingsValues = localStorage; + } else { + this.settingsValues = await api.getSettings(); + } + + // Trigger onChange for any settings added before load + for (const id in this.settingsLookup) { + this.settingsLookup[id].onChange?.(this.settingsValues[this.getId(id)]); + } + } + + getId(id) { + if (this.app.storageLocation === "browser") { + id = "Comfy.Settings." + id; + } + return id; + } + + getSettingValue(id, defaultValue) { + let value = this.settingsValues[this.getId(id)]; + if(value != null) { + if(this.app.storageLocation === "browser") { + try { + value = JSON.parse(value); + } catch (error) { + } + } + } + return value ?? defaultValue; + } + + async setSettingValueAsync(id, value) { + const json = JSON.stringify(value); + localStorage["Comfy.Settings." + id] = json; // backwards compatibility for extensions keep setting in storage + + let oldValue = this.getSettingValue(id, undefined); + this.settingsValues[this.getId(id)] = value; + + if (id in this.settingsLookup) { + this.settingsLookup[id].onChange?.(value, oldValue); + } + + await api.storeSetting(id, value); + } + + setSettingValue(id, value) { + this.setSettingValueAsync(id, value).catch((err) => { + alert(`Error saving setting '${id}'`); + console.error(err); + }); + } + + addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined }) { + if (!id) { + throw new Error("Settings must have an ID"); + } + + if (id in this.settingsLookup) { + throw new Error(`Setting ${id} of type ${type} must have a unique ID.`); + } + + let skipOnChange = false; + let value = this.getSettingValue(id); + if (value == null) { + if (this.app.isNewUserSession) { + // Check if we have a localStorage value but not a setting value and we are a new user + const localValue = localStorage["Comfy.Settings." + id]; + if (localValue) { + value = JSON.parse(localValue); + this.setSettingValue(id, value); // Store on the server + } + } + if (value == null) { + value = defaultValue; + } + } + + // Trigger initial setting of value + if (!skipOnChange) { + onChange?.(value, undefined); + } + + this.settingsLookup[id] = { + id, + onChange, + name, + render: () => { + const setter = (v) => { + if (onChange) { + onChange(v, value); + } + + this.setSettingValue(id, v); + value = v; + }; + value = this.getSettingValue(id, defaultValue); + + let element; + const htmlID = id.replaceAll(".", "-"); + + const labelCell = $el("td", [ + $el("label", { + for: htmlID, + classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""], + textContent: name, + }), + ]); + + if (typeof type === "function") { + element = type(name, setter, value, attrs); + } else { + switch (type) { + case "boolean": + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + id: htmlID, + type: "checkbox", + checked: value, + onchange: (event) => { + const isChecked = event.target.checked; + if (onChange !== undefined) { + onChange(isChecked); + } + this.setSettingValue(id, isChecked); + }, + }), + ]), + ]); + break; + case "number": + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + type, + value, + id: htmlID, + oninput: (e) => { + setter(e.target.value); + }, + ...attrs, + }), + ]), + ]); + break; + case "slider": + element = $el("tr", [ + labelCell, + $el("td", [ + $el( + "div", + { + style: { + display: "grid", + gridAutoFlow: "column", + }, + }, + [ + $el("input", { + ...attrs, + value, + type: "range", + oninput: (e) => { + setter(e.target.value); + e.target.nextElementSibling.value = e.target.value; + }, + }), + $el("input", { + ...attrs, + value, + id: htmlID, + type: "number", + style: { maxWidth: "4rem" }, + oninput: (e) => { + setter(e.target.value); + e.target.previousElementSibling.value = e.target.value; + }, + }), + ] + ), + ]), + ]); + break; + case "combo": + element = $el("tr", [ + labelCell, + $el("td", [ + $el( + "select", + { + oninput: (e) => { + setter(e.target.value); + }, + }, + (typeof options === "function" ? options(value) : options || []).map((opt) => { + if (typeof opt === "string") { + opt = { text: opt }; + } + const v = opt.value ?? opt.text; + return $el("option", { + value: v, + textContent: opt.text, + selected: value + "" === v + "", + }); + }) + ), + ]), + ]); + break; + case "text": + default: + if (type !== "text") { + console.warn(`Unsupported setting type '${type}, defaulting to text`); + } + + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + value, + id: htmlID, + oninput: (e) => { + setter(e.target.value); + }, + ...attrs, + }), + ]), + ]); + break; + } + } + if (tooltip) { + element.title = tooltip; + } + + return element; + }, + }; + + const self = this; + return { + get value() { + return self.getSettingValue(id, defaultValue); + }, + set value(v) { + self.setSettingValue(id, v); + }, + }; + } + + show() { + this.textElement.replaceChildren( + $el( + "tr", + { + style: { display: "none" }, + }, + [$el("th"), $el("th", { style: { width: "33%" } })] + ), + ...this.settings.sort((a, b) => a.name.localeCompare(b.name)).map((s) => s.render()) + ); + this.element.showModal(); + } +} diff --git a/comfy/web/scripts/ui/spinner.css b/comfy/web/scripts/ui/spinner.css new file mode 100644 index 000000000..56da6072e --- /dev/null +++ b/comfy/web/scripts/ui/spinner.css @@ -0,0 +1,34 @@ +.lds-ring { + display: inline-block; + position: relative; + width: 1em; + height: 1em; +} +.lds-ring div { + box-sizing: border-box; + display: block; + position: absolute; + width: 100%; + height: 100%; + border: 0.15em solid #fff; + border-radius: 50%; + animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite; + border-color: #fff transparent transparent transparent; +} +.lds-ring div:nth-child(1) { + animation-delay: -0.45s; +} +.lds-ring div:nth-child(2) { + animation-delay: -0.3s; +} +.lds-ring div:nth-child(3) { + animation-delay: -0.15s; +} +@keyframes lds-ring { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} diff --git a/comfy/web/scripts/ui/spinner.js b/comfy/web/scripts/ui/spinner.js new file mode 100644 index 000000000..d049786f6 --- /dev/null +++ b/comfy/web/scripts/ui/spinner.js @@ -0,0 +1,9 @@ +import { addStylesheet } from "../utils.js"; + +addStylesheet(import.meta.url); + +export function createSpinner() { + const div = document.createElement("div"); + div.innerHTML = `
`; + return div.firstElementChild; +} diff --git a/comfy/web/scripts/ui/toggleSwitch.js b/comfy/web/scripts/ui/toggleSwitch.js new file mode 100644 index 000000000..59597ef90 --- /dev/null +++ b/comfy/web/scripts/ui/toggleSwitch.js @@ -0,0 +1,60 @@ +import { $el } from "../ui.js"; + +/** + * @typedef { { text: string, value?: string, tooltip?: string } } ToggleSwitchItem + */ +/** + * Creates a toggle switch element + * @param { string } name + * @param { Array void } [opts.onChange] + */ +export function toggleSwitch(name, items, { onChange } = {}) { + let selectedIndex; + let elements; + + function updateSelected(index) { + if (selectedIndex != null) { + elements[selectedIndex].classList.remove("comfy-toggle-selected"); + } + onChange?.({ item: items[index], prev: selectedIndex == null ? undefined : items[selectedIndex] }); + selectedIndex = index; + elements[selectedIndex].classList.add("comfy-toggle-selected"); + } + + elements = items.map((item, i) => { + if (typeof item === "string") item = { text: item }; + if (!item.value) item.value = item.text; + + const toggle = $el( + "label", + { + textContent: item.text, + title: item.tooltip ?? "", + }, + $el("input", { + name, + type: "radio", + value: item.value ?? item.text, + checked: item.selected, + onchange: () => { + updateSelected(i); + }, + }) + ); + if (item.selected) { + updateSelected(i); + } + return toggle; + }); + + const container = $el("div.comfy-toggle-switch", elements); + + if (selectedIndex == null) { + elements[0].children[0].checked = true; + updateSelected(0); + } + + return container; +} diff --git a/comfy/web/scripts/ui/userSelection.css b/comfy/web/scripts/ui/userSelection.css new file mode 100644 index 000000000..35c9d6614 --- /dev/null +++ b/comfy/web/scripts/ui/userSelection.css @@ -0,0 +1,135 @@ +.comfy-user-selection { + width: 100vw; + height: 100vh; + position: absolute; + top: 0; + left: 0; + z-index: 999; + display: flex; + align-items: center; + justify-content: center; + font-family: sans-serif; + background: linear-gradient(var(--tr-even-bg-color), var(--tr-odd-bg-color)); +} + +.comfy-user-selection-inner { + background: var(--comfy-menu-bg); + margin-top: -30vh; + padding: 20px 40px; + border-radius: 10px; + min-width: 365px; + position: relative; + box-shadow: 0 0 20px rgba(0, 0, 0, 0.3); +} + +.comfy-user-selection-inner form { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; +} + +.comfy-user-selection-inner h1 { + margin: 10px 0 30px 0; + font-weight: normal; +} + +.comfy-user-selection-inner label { + display: flex; + flex-direction: column; + width: 100%; +} + +.comfy-user-selection input, +.comfy-user-selection select { + background-color: var(--comfy-input-bg); + color: var(--input-text); + border: 0; + border-radius: 5px; + padding: 5px; + margin-top: 10px; +} + +.comfy-user-selection input::placeholder { + color: var(--descrip-text); + opacity: 1; +} + +.comfy-user-existing { + width: 100%; +} + +.no-users .comfy-user-existing { + display: none; +} + +.comfy-user-selection-inner .or-separator { + margin: 10px 0; + padding: 10px; + display: block; + text-align: center; + width: 100%; + color: var(--descrip-text); +} + +.comfy-user-selection-inner .or-separator { + overflow: hidden; + text-align: center; + margin-left: -10px; +} + +.comfy-user-selection-inner .or-separator::before, +.comfy-user-selection-inner .or-separator::after { + content: ""; + background-color: var(--border-color); + position: relative; + height: 1px; + vertical-align: middle; + display: inline-block; + width: calc(50% - 20px); + top: -1px; +} + +.comfy-user-selection-inner .or-separator::before { + right: 10px; + margin-left: -50%; +} + +.comfy-user-selection-inner .or-separator::after { + left: 10px; + margin-right: -50%; +} + +.comfy-user-selection-inner section { + width: 100%; + padding: 10px; + margin: -10px; + transition: background-color 0.2s; +} + +.comfy-user-selection-inner section.selected { + background: var(--border-color); + border-radius: 5px; +} + +.comfy-user-selection-inner footer { + display: flex; + flex-direction: column; + align-items: center; + margin-top: 20px; +} + +.comfy-user-selection-inner .comfy-user-error { + color: var(--error-text); + margin-bottom: 10px; +} + +.comfy-user-button-next { + font-size: 16px; + padding: 6px 10px; + width: 100px; + display: flex; + gap: 5px; + align-items: center; + justify-content: center; +} \ No newline at end of file diff --git a/comfy/web/scripts/ui/userSelection.js b/comfy/web/scripts/ui/userSelection.js new file mode 100644 index 000000000..f9f1ca807 --- /dev/null +++ b/comfy/web/scripts/ui/userSelection.js @@ -0,0 +1,114 @@ +import { api } from "../api.js"; +import { $el } from "../ui.js"; +import { addStylesheet } from "../utils.js"; +import { createSpinner } from "./spinner.js"; + +export class UserSelectionScreen { + async show(users, user) { + // This will rarely be hit so move the loading to on demand + await addStylesheet(import.meta.url); + const userSelection = document.getElementById("comfy-user-selection"); + userSelection.style.display = ""; + return new Promise((resolve) => { + const input = userSelection.getElementsByTagName("input")[0]; + const select = userSelection.getElementsByTagName("select")[0]; + const inputSection = input.closest("section"); + const selectSection = select.closest("section"); + const form = userSelection.getElementsByTagName("form")[0]; + const error = userSelection.getElementsByClassName("comfy-user-error")[0]; + const button = userSelection.getElementsByClassName("comfy-user-button-next")[0]; + + let inputActive = null; + input.addEventListener("focus", () => { + inputSection.classList.add("selected"); + selectSection.classList.remove("selected"); + inputActive = true; + }); + select.addEventListener("focus", () => { + inputSection.classList.remove("selected"); + selectSection.classList.add("selected"); + inputActive = false; + select.style.color = ""; + }); + select.addEventListener("blur", () => { + if (!select.value) { + select.style.color = "var(--descrip-text)"; + } + }); + + form.addEventListener("submit", async (e) => { + e.preventDefault(); + if (inputActive == null) { + error.textContent = "Please enter a username or select an existing user."; + } else if (inputActive) { + const username = input.value.trim(); + if (!username) { + error.textContent = "Please enter a username."; + return; + } + + // Create new user + input.disabled = select.disabled = input.readonly = select.readonly = true; + const spinner = createSpinner(); + button.prepend(spinner); + try { + const resp = await api.createUser(username); + if (resp.status >= 300) { + let message = "Error creating user: " + resp.status + " " + resp.statusText; + try { + const res = await resp.json(); + if(res.error) { + message = res.error; + } + } catch (error) { + } + throw new Error(message); + } + + resolve({ username, userId: await resp.json(), created: true }); + } catch (err) { + spinner.remove(); + error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred."; + input.disabled = select.disabled = input.readonly = select.readonly = false; + return; + } + } else if (!select.value) { + error.textContent = "Please select an existing user."; + return; + } else { + resolve({ username: users[select.value], userId: select.value, created: false }); + } + }); + + if (user) { + const name = localStorage["Comfy.userName"]; + if (name) { + input.value = name; + } + } + if (input.value) { + // Focus the input, do this separately as sometimes browsers like to fill in the value + input.focus(); + } + + const userIds = Object.keys(users ?? {}); + if (userIds.length) { + for (const u of userIds) { + $el("option", { textContent: users[u], value: u, parent: select }); + } + select.style.color = "var(--descrip-text)"; + + if (select.value) { + // Focus the select, do this separately as sometimes browsers like to fill in the value + select.focus(); + } + } else { + userSelection.classList.add("no-users"); + input.focus(); + } + }).then((r) => { + userSelection.remove(); + return r; + }); + } +} diff --git a/comfy/web/scripts/utils.js b/comfy/web/scripts/utils.js index 401aca9e4..01b988462 100644 --- a/comfy/web/scripts/utils.js +++ b/comfy/web/scripts/utils.js @@ -1,3 +1,5 @@ +import { $el } from "./ui.js"; + // Simple date formatter const parts = { d: (d) => d.getDate(), @@ -65,3 +67,22 @@ export function applyTextReplacements(app, value) { return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_"); }); } + +export async function addStylesheet(urlOrFile, relativeTo) { + return new Promise((res, rej) => { + let url; + if (urlOrFile.endsWith(".js")) { + url = urlOrFile.substr(0, urlOrFile.length - 2) + "css"; + } else { + url = new URL(urlOrFile, relativeTo ?? `${window.location.protocol}//${window.location.host}`).toString(); + } + $el("link", { + parent: document.head, + rel: "stylesheet", + type: "text/css", + href: url, + onload: res, + onerror: rej, + }); + }); +} diff --git a/comfy/web/scripts/widgets.js b/comfy/web/scripts/widgets.js index e2e21164d..0529b1d80 100644 --- a/comfy/web/scripts/widgets.js +++ b/comfy/web/scripts/widgets.js @@ -1,6 +1,19 @@ import { api } from "./api.js" import "./domWidget.js"; +let controlValueRunBefore = false; +export function updateControlWidgetLabel(widget) { + let replacement = "after"; + let find = "before"; + if (controlValueRunBefore) { + [find, replacement] = [replacement, find] + } + widget.label = (widget.label ?? widget.name).replace(find, replacement); +} + +const IS_CONTROL_WIDGET = Symbol(); +const HAS_EXECUTED = Symbol(); + function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; let { min, max, step, round} = inputData[1]; @@ -62,6 +75,8 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando serialize: false, // Don't include this in prompt. } ); + valueControl[IS_CONTROL_WIDGET] = true; + updateControlWidgetLabel(valueControl); widgets.push(valueControl); const isCombo = targetWidget.type === "combo"; @@ -76,10 +91,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando serialize: false, // Don't include this in prompt. } ); + updateControlWidgetLabel(comboFilter); + widgets.push(comboFilter); } - valueControl.afterQueued = () => { + const applyWidgetControl = () => { var v = valueControl.value; if (isCombo && v !== "fixed") { @@ -159,6 +176,23 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando targetWidget.callback(targetWidget.value); } }; + + valueControl.beforeQueued = () => { + if (controlValueRunBefore) { + // Don't run on first execution + if (valueControl[HAS_EXECUTED]) { + applyWidgetControl(); + } + } + valueControl[HAS_EXECUTED] = true; + }; + + valueControl.afterQueued = () => { + if (!controlValueRunBefore) { + applyWidgetControl(); + } + }; + return widgets; }; @@ -224,6 +258,34 @@ function isSlider(display, app) { return (display==="slider") ? "slider" : "number" } +export function initWidgets(app) { + app.ui.settings.addSetting({ + id: "Comfy.WidgetControlMode", + name: "Widget Value Control Mode", + type: "combo", + defaultValue: "after", + options: ["before", "after"], + tooltip: "Controls when widget values are updated (randomize/increment/decrement), either before the prompt is queued or after.", + onChange(value) { + controlValueRunBefore = value === "before"; + for (const n of app.graph._nodes) { + if (!n.widgets) continue; + for (const w of n.widgets) { + if (w[IS_CONTROL_WIDGET]) { + updateControlWidgetLabel(w); + if (w.linkedWidgets) { + for (const l of w.linkedWidgets) { + updateControlWidgetLabel(l); + } + } + } + } + } + app.graph.setDirtyCanvas(true); + }, + }); +} + export const ComfyWidgets = { "INT:seed": seedWidget, "INT:noise_seed": seedWidget, diff --git a/comfy/web/style.css b/comfy/web/style.css index 630eea12e..44ee60198 100644 --- a/comfy/web/style.css +++ b/comfy/web/style.css @@ -121,6 +121,8 @@ body { width: 100%; } +.comfy-toggle-switch, +.comfy-btn, .comfy-menu > button, .comfy-menu-btns button, .comfy-menu .comfy-list button, @@ -133,6 +135,7 @@ body { margin-top: 2px; } +.comfy-btn:hover:not(:disabled), .comfy-menu > button:hover, .comfy-menu-btns button:hover, .comfy-menu .comfy-list button:hover, @@ -143,6 +146,12 @@ body { } .comfy-menu span.drag-handle { + position: absolute; + top: 0; + left: 0; +} + +span.drag-handle { width: 10px; height: 20px; display: inline-block; @@ -158,12 +167,9 @@ body { letter-spacing: 2px; color: var(--drag-text); text-shadow: 1px 0 1px black; - position: absolute; - top: 0; - left: 0; } -.comfy-menu span.drag-handle::after { +span.drag-handle::after { content: '.. .. ..'; } @@ -429,6 +435,43 @@ dialog::backdrop { margin-left: 5px; } +.comfy-toggle-switch { + border-width: 2px; + display: flex; + background-color: var(--comfy-input-bg); + margin: 2px 0; + white-space: nowrap; +} + +.comfy-toggle-switch label { + padding: 2px 0px 3px 6px; + flex: auto; + border-radius: 8px; + align-items: center; + display: flex; + justify-content: center; +} + +.comfy-toggle-switch label:first-child { + border-top-left-radius: 8px; + border-bottom-left-radius: 8px; +} +.comfy-toggle-switch label:last-child { + border-top-right-radius: 8px; + border-bottom-right-radius: 8px; +} + +.comfy-toggle-switch .comfy-toggle-selected { + background-color: var(--comfy-menu-bg); +} + +#extraOptions { + padding: 4px; + background-color: var(--bg-color); + margin-bottom: 4px; + border-radius: 4px; +} + /* Search box */ .litegraph.litesearchbox { diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py index baf72ecc5..39c7e28e6 100644 --- a/comfy_extras/nodes/nodes_custom_sampler.py +++ b/comfy_extras/nodes/nodes_custom_sampler.py @@ -1,4 +1,5 @@ from comfy import samplers +from comfy import model_management from comfy import sample from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.cmd import latent_preview @@ -26,9 +27,8 @@ class BasicScheduler: if denoise < 1.0: total_steps = int(steps/denoise) - inner_model = model.patch_model(patch_weights=False) - sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu() - model.unpatch_model() + model_management.load_models_gpu([model]) + sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -106,9 +106,8 @@ class SDTurboScheduler: def get_sigmas(self, model, steps, denoise): start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] - inner_model = model.patch_model(patch_weights=False) - sigmas = inner_model.model_sampling.sigma(timesteps) - model.unpatch_model() + model_management.load_models_gpu([model]) + sigmas = model.model.model_sampling.sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, ) diff --git a/comfy_extras/nodes/nodes_freelunch.py b/comfy_extras/nodes/nodes_freelunch.py index 7512b841d..7764aa0b0 100644 --- a/comfy_extras/nodes/nodes_freelunch.py +++ b/comfy_extras/nodes/nodes_freelunch.py @@ -34,7 +34,7 @@ class FreeU: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] @@ -73,7 +73,7 @@ class FreeU_V2: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] diff --git a/comfy_extras/nodes/nodes_hypertile.py b/comfy_extras/nodes/nodes_hypertile.py index e7446b2e5..ae55d23dd 100644 --- a/comfy_extras/nodes/nodes_hypertile.py +++ b/comfy_extras/nodes/nodes_hypertile.py @@ -32,29 +32,29 @@ class HyperTile: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, tile_size, swap_size, max_depth, scale_depth): model_channels = model.model.model_config.unet_config["model_channels"] - apply_to = set() - temp = model_channels - for x in range(max_depth + 1): - apply_to.add(temp) - temp *= 2 - latent_tile_size = max(32, tile_size) // 8 self.temp = None def hypertile_in(q, k, v, extra_options): - if q.shape[-1] in apply_to: + model_chans = q.shape[-2] + orig_shape = extra_options['original_shape'] + apply_to = [] + for i in range(max_depth + 1): + apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))) + + if model_chans in apply_to: shape = extra_options["original_shape"] aspect_ratio = shape[-1] / shape[-2] hw = q.size(1) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) - factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 + factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1 nh = random_divisor(h, latent_tile_size * factor, swap_size) nw = random_divisor(w, latent_tile_size * factor, swap_size) diff --git a/comfy_extras/nodes/nodes_latent.py b/comfy_extras/nodes/nodes_latent.py index 2eefc4c55..b7fd8cd68 100644 --- a/comfy_extras/nodes/nodes_latent.py +++ b/comfy_extras/nodes/nodes_latent.py @@ -122,10 +122,34 @@ class LatentBatch: samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) return (samples_out,) +class LatentBatchSeedBehavior: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "seed_behavior": (["random", "fixed"],),}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, seed_behavior): + samples_out = samples.copy() + latent = samples["samples"] + if seed_behavior == "random": + if 'batch_index' in samples_out: + samples_out.pop('batch_index') + elif seed_behavior == "fixed": + batch_number = samples_out.get("batch_index", [0])[0] + samples_out["batch_index"] = [batch_number] * latent.shape[0] + + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, "LatentBatch": LatentBatch, + "LatentBatchSeedBehavior": LatentBatchSeedBehavior, } diff --git a/comfy_extras/nodes/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py index 1ed455f04..239b69809 100644 --- a/comfy_extras/nodes/nodes_mask.py +++ b/comfy_extras/nodes/nodes_mask.py @@ -1,7 +1,7 @@ import numpy as np import scipy.ndimage import torch -import comfy.utils +from comfy import utils from comfy.nodes.common import MAX_RESOLUTION @@ -11,7 +11,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou if resize_source: source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") - source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) + source = utils.repeat_to_batch_size(source, destination.shape[0]) x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) @@ -24,7 +24,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou else: mask = mask.to(destination.device, copy=True) mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") - mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) + mask = utils.repeat_to_batch_size(mask, source.shape[0]) # calculate the bounds of the source that will be overlapping the destination # this prevents the source trying to overwrite latent pixels that are out of bounds diff --git a/comfy_extras/nodes/nodes_model_downscale.py b/comfy_extras/nodes/nodes_model_downscale.py index 48bcc6892..c1b116c97 100644 --- a/comfy_extras/nodes/nodes_model_downscale.py +++ b/comfy_extras/nodes/nodes_model_downscale.py @@ -1,5 +1,5 @@ import torch -import comfy.utils +from comfy import utils class PatchModelAddDownscale: upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] @@ -27,12 +27,12 @@ class PatchModelAddDownscale: if transformer_options["block"][1] == block_number: sigma = transformer_options["sigmas"][0].item() if sigma <= sigma_start and sigma >= sigma_end: - h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") + h = utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") return h def output_block_patch(h, hsp, transformer_options): if h.shape[2] != hsp.shape[2]: - h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") + h = utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") return h, hsp m = model.clone() diff --git a/comfy_extras/nodes/nodes_model_merging.py b/comfy_extras/nodes/nodes_model_merging.py index a74a1ad5a..b46f41576 100644 --- a/comfy_extras/nodes/nodes_model_merging.py +++ b/comfy_extras/nodes/nodes_model_merging.py @@ -1,6 +1,6 @@ -from comfy import sd +from comfy import sd, utils from comfy import model_base -import comfy.model_management +from comfy import model_management from comfy.cmd import folder_paths import json @@ -118,6 +118,48 @@ class ModelMergeBlocks: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + + enable_modelspec = True + if isinstance(model.model, model_base.SDXL): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + elif isinstance(model.model, model_base.SDXLRefiner): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + else: + enable_modelspec = False + + if enable_modelspec: + metadata["modelspec.sai_model_spec"] = "1.0.0" + metadata["modelspec.implementation"] = "sgm" + metadata["modelspec.title"] = "{} {}".format(filename, counter) + + #TODO: + # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", + # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", + # "v2-inpainting" + + if model.model.model_type == model_base.ModelType.EPS: + metadata["modelspec.predict_key"] = "epsilon" + elif model.model.model_type == model_base.ModelType.V_PREDICTION: + metadata["modelspec.predict_key"] = "v" + + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) + class CheckpointSave: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -136,46 +178,7 @@ class CheckpointSave: CATEGORY = "advanced/model_merging" def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - prompt_info = "" - if prompt is not None: - prompt_info = json.dumps(prompt) - - metadata = {} - - enable_modelspec = True - if isinstance(model.model, model_base.SDXL): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" - elif isinstance(model.model, model_base.SDXLRefiner): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" - else: - enable_modelspec = False - - if enable_modelspec: - metadata["modelspec.sai_model_spec"] = "1.0.0" - metadata["modelspec.implementation"] = "sgm" - metadata["modelspec.title"] = "{} {}".format(filename, counter) - - #TODO: - # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", - # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", - # "v2-inpainting" - - if model.model.model_type == model_base.ModelType.EPS: - metadata["modelspec.predict_key"] = "epsilon" - elif model.model.model_type == model_base.ModelType.V_PREDICTION: - metadata["modelspec.predict_key"] = "v" - - if not args.disable_metadata: - metadata["prompt"] = prompt_info - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - output_checkpoint = f"{filename}_{counter:05}_.safetensors" - output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - - sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) + save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) return {} class CLIPSave: @@ -205,7 +208,7 @@ class CLIPSave: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) - comfy.model_management.load_models_gpu([clip.load_model()]) + model_management.load_models_gpu([clip.load_model()]) clip_sd = clip.get_sd() for prefix in ["clip_l.", "clip_g.", ""]: @@ -229,9 +232,9 @@ class CLIPSave: output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix) + current_clip_sd = utils.state_dict_prefix_replace(current_clip_sd, replace_prefix) - comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata) + utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata) return {} class VAESave: @@ -265,7 +268,7 @@ class VAESave: output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) + utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) return {} NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes/nodes_photomaker.py b/comfy_extras/nodes/nodes_photomaker.py new file mode 100644 index 000000000..4e99b9e59 --- /dev/null +++ b/comfy_extras/nodes/nodes_photomaker.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +from comfy.cmd import folder_paths +from comfy import clip_model +from comfy import clip_vision +from comfy import ops + +# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 +VISION_CONFIG_DICT = { + "hidden_size": 1024, + "image_size": 224, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "hidden_act": "quick_gelu", +} + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=ops): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = operations.LayerNorm(in_dim) + self.fc1 = operations.Linear(in_dim, hidden_dim) + self.fc2 = operations.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class FuseModule(nn.Module): + def __init__(self, embed_dim, operations): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations) + self.layer_norm = operations.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + ) -> torch.Tensor: + # id_embeds shape: [b, max_num_inputs, 1, 2048] + id_embeds = id_embeds.to(prompt_embeds.dtype) + num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case + batch_size, max_num_inputs = id_embeds.shape[:2] + # seq_length: 77 + seq_length = prompt_embeds.shape[1] + # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] + flat_id_embeds = id_embeds.view( + -1, id_embeds.shape[-2], id_embeds.shape[-1] + ) + # valid_id_mask [b*max_num_inputs] + valid_id_mask = ( + torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] + < num_inputs[:, None] + ) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) + class_tokens_mask = class_tokens_mask.view(-1) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) + # slice out the image token embeddings + image_token_embeds = prompt_embeds[class_tokens_mask] + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + return updated_prompt_embeds + +class PhotoMakerIDEncoder(clip_model.CLIPVisionModelProjection): + def __init__(self): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + dtype = comfy.model_management.text_encoder_dtype(self.load_device) + + super().__init__(VISION_CONFIG_DICT, dtype, offload_device, ops.manual_cast) + self.visual_projection_2 = ops.manual_cast.Linear(1024, 1280, bias=False) + self.fuse_module = FuseModule(2048, ops.manual_cast) + + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[2] + id_embeds = self.visual_projection(shared_id_embeds) + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + + return updated_prompt_embeds + + +class PhotoMakerLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + + RETURN_TYPES = ("PHOTOMAKER",) + FUNCTION = "load_photomaker_model" + + CATEGORY = "_for_testing/photomaker" + + def load_photomaker_model(self, photomaker_model_name): + photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) + photomaker_model = PhotoMakerIDEncoder() + data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) + if "id_encoder" in data: + data = data["id_encoder"] + photomaker_model.load_state_dict(data) + return (photomaker_model,) + + +class PhotoMakerEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker": ("PHOTOMAKER",), + "image": ("IMAGE",), + "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_photomaker" + + CATEGORY = "_for_testing/photomaker" + + def apply_photomaker(self, photomaker, image, clip, text): + special_token = "photomaker" + pixel_values = clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() + try: + index = text.split(" ").index(special_token) + 1 + except ValueError: + index = -1 + tokens = clip.tokenize(text, return_word_ids=True) + out_tokens = {} + for k in tokens: + out_tokens[k] = [] + for t in tokens[k]: + f = list(filter(lambda x: x[2] != index, t)) + while len(f) < len(t): + f.append(t[-1]) + out_tokens[k].append(f) + + cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True) + + if index > 0: + token_index = index - 1 + num_id_images = 1 + class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)] + out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device), + class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0)) + else: + out = cond + + return ([[out, {"pooled_output": pooled}]], ) + + +NODE_CLASS_MAPPINGS = { + "PhotoMakerLoader": PhotoMakerLoader, + "PhotoMakerEncode": PhotoMakerEncode, +} + diff --git a/comfy_extras/nodes/nodes_post_processing.py b/comfy_extras/nodes/nodes_post_processing.py index 334214ad1..7012c2702 100644 --- a/comfy_extras/nodes/nodes_post_processing.py +++ b/comfy_extras/nodes/nodes_post_processing.py @@ -33,6 +33,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) image2 = utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') diff --git a/comfy_extras/nodes/nodes_sag.py b/comfy_extras/nodes/nodes_sag.py index 66606e328..16ccf04b1 100644 --- a/comfy_extras/nodes/nodes_sag.py +++ b/comfy_extras/nodes/nodes_sag.py @@ -58,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - ratio = math.ceil(math.sqrt(lh * lw / hw1)) + ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length() mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] # Reshape @@ -143,6 +143,8 @@ class SelfAttentionGuidance: sigma = args["sigma"] model_options = args["model_options"] x = args["input"] + if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding + return cfg_result # create the adversarially blurred image degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) diff --git a/comfy_extras/nodes/nodes_stable3d.py b/comfy_extras/nodes/nodes_stable3d.py index 7aa6ec858..05c76926b 100644 --- a/comfy_extras/nodes/nodes_stable3d.py +++ b/comfy_extras/nodes/nodes_stable3d.py @@ -1,5 +1,4 @@ import torch -import comfy.utils from comfy.nodes.common import MAX_RESOLUTION from comfy import utils @@ -49,13 +48,57 @@ class StableZero123_Conditioning: encode_pixels = pixels[:,:,:,:3] t = vae.encode(encode_pixels) cam_embeds = camera_embeddings(elevation, azimuth) - cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1) + cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1) positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (positive, negative, {"samples":latent}) +class StableZero123_Conditioning_Batched: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + cam_embeds = [] + for i in range(batch_size): + cam_embeds.append(camera_embeddings(elevation, azimuth)) + elevation += elevation_batch_increment + azimuth += azimuth_batch_increment + + cam_embeds = torch.cat(cam_embeds, dim=0) + cond = torch.cat([utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1) + + positive = [[cond, {"concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) + + NODE_CLASS_MAPPINGS = { "StableZero123_Conditioning": StableZero123_Conditioning, + "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, } diff --git a/comfy_extras/nodes/nodes_video_model.py b/comfy_extras/nodes/nodes_video_model.py index 8b8584100..4f54b5459 100644 --- a/comfy_extras/nodes/nodes_video_model.py +++ b/comfy_extras/nodes/nodes_video_model.py @@ -3,6 +3,7 @@ import torch import comfy.utils import comfy.sd from comfy.cmd import folder_paths +from . import nodes_model_merging class ImageOnlyCheckpointLoader: @@ -78,10 +79,26 @@ class VideoLinearCFGGuidance: m.set_model_sampler_cfg_function(linear_cfg) return (m, ) +class ImageOnlyCheckpointSave(nodes_model_merging.CheckpointSave): + CATEGORY = "_for_testing" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "clip_vision": ("CLIP_VISION",), + "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + + def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None): + nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + return {} + NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/models/photomaker/put_photomaker_models_here b/models/photomaker/put_photomaker_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/tests-ui/babel.config.json b/tests-ui/babel.config.json index 526ddfd8d..f27d6c397 100644 --- a/tests-ui/babel.config.json +++ b/tests-ui/babel.config.json @@ -1,3 +1,4 @@ { - "presets": ["@babel/preset-env"] + "presets": ["@babel/preset-env"], + "plugins": ["babel-plugin-transform-import-meta"] } diff --git a/tests-ui/package-lock.json b/tests-ui/package-lock.json index 35911cd7f..0f409ca24 100644 --- a/tests-ui/package-lock.json +++ b/tests-ui/package-lock.json @@ -11,6 +11,7 @@ "devDependencies": { "@babel/preset-env": "^7.22.20", "@types/jest": "^29.5.5", + "babel-plugin-transform-import-meta": "^2.2.1", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0" } @@ -2591,6 +2592,19 @@ "@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0" } }, + "node_modules/babel-plugin-transform-import-meta": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/babel-plugin-transform-import-meta/-/babel-plugin-transform-import-meta-2.2.1.tgz", + "integrity": "sha512-AxNh27Pcg8Kt112RGa3Vod2QS2YXKKJ6+nSvRtv7qQTJAdx0MZa4UHZ4lnxHUWA2MNbLuZQv5FVab4P1CoLOWw==", + "dev": true, + "dependencies": { + "@babel/template": "^7.4.4", + "tslib": "^2.4.0" + }, + "peerDependencies": { + "@babel/core": "^7.10.0" + } + }, "node_modules/babel-preset-current-node-syntax": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz", @@ -5233,6 +5247,12 @@ "node": ">=12" } }, + "node_modules/tslib": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==", + "dev": true + }, "node_modules/type-detect": { "version": "4.0.8", "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", diff --git a/tests-ui/package.json b/tests-ui/package.json index e7b60ad8e..ae7e49084 100644 --- a/tests-ui/package.json +++ b/tests-ui/package.json @@ -24,6 +24,7 @@ "devDependencies": { "@babel/preset-env": "^7.22.20", "@types/jest": "^29.5.5", + "babel-plugin-transform-import-meta": "^2.2.1", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0" } diff --git a/tests-ui/tests/users.test.js b/tests-ui/tests/users.test.js new file mode 100644 index 000000000..07a7017ba --- /dev/null +++ b/tests-ui/tests/users.test.js @@ -0,0 +1,295 @@ +// @ts-check +/// +const { start } = require("../utils"); +const lg = require("../utils/litegraph"); + +describe("users", () => { + beforeEach(() => { + lg.setup(global); + }); + + afterEach(() => { + lg.teardown(global); + }); + + function expectNoUserScreen() { + // Ensure login isnt visible + const selection = document.querySelectorAll("#comfy-user-selection")?.[0]; + expect(selection["style"].display).toBe("none"); + const menu = document.querySelectorAll(".comfy-menu")?.[0]; + expect(window.getComputedStyle(menu)?.display).not.toBe("none"); + } + + describe("multi-user", () => { + function mockAddStylesheet() { + const utils = require("../../comfy/web/scripts/utils"); + utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve()); + } + + async function waitForUserScreenShow() { + mockAddStylesheet(); + + // Wait for "show" to be called + const { UserSelectionScreen } = require("../../comfy/web/scripts/ui/userSelection"); + let resolve, reject; + const fn = UserSelectionScreen.prototype.show; + const p = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => { + const res = fn(...args); + await new Promise(process.nextTick); // wait for promises to resolve + resolve(); + return res; + }); + // @ts-ignore + setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500); + await p; + await new Promise(process.nextTick); // wait for promises to resolve + } + + async function testUserScreen(onShown, users) { + if (!users) { + users = {}; + } + const starting = start({ + resetEnv: true, + userConfig: { storage: "server", users }, + }); + + // Ensure no current user + expect(localStorage["Comfy.userId"]).toBeFalsy(); + expect(localStorage["Comfy.userName"]).toBeFalsy(); + + await waitForUserScreenShow(); + + const selection = document.querySelectorAll("#comfy-user-selection")?.[0]; + expect(selection).toBeTruthy(); + + // Ensure login is visible + expect(window.getComputedStyle(selection)?.display).not.toBe("none"); + // Ensure menu is hidden + const menu = document.querySelectorAll(".comfy-menu")?.[0]; + expect(window.getComputedStyle(menu)?.display).toBe("none"); + + const isCreate = await onShown(selection); + + // Submit form + selection.querySelectorAll("form")[0].submit(); + await new Promise(process.nextTick); // wait for promises to resolve + + // Wait for start + const s = await starting; + + // Ensure login is removed + expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0); + expect(window.getComputedStyle(menu)?.display).not.toBe("none"); + + // Ensure settings + templates are saved + const { api } = require("../../comfy/web/scripts/api"); + expect(api.createUser).toHaveBeenCalledTimes(+isCreate); + expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate); + expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate); + if (isCreate) { + expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false }); + expect(s.app.isNewUserSession).toBeTruthy(); + } else { + expect(s.app.isNewUserSession).toBeFalsy(); + } + + return { users, selection, ...s }; + } + + it("allows user creation if no users", async () => { + const { users } = await testUserScreen((selection) => { + // Ensure we have no users flag added + expect(selection.classList.contains("no-users")).toBeTruthy(); + + // Enter a username + const input = selection.getElementsByTagName("input")[0]; + input.focus(); + input.value = "Test User"; + + return true; + }); + + expect(users).toStrictEqual({ + "Test User!": "Test User", + }); + + expect(localStorage["Comfy.userId"]).toBe("Test User!"); + expect(localStorage["Comfy.userName"]).toBe("Test User"); + }); + it("allows user creation if no current user but other users", async () => { + const users = { + "Test User 2!": "Test User 2", + }; + + await testUserScreen((selection) => { + expect(selection.classList.contains("no-users")).toBeFalsy(); + + // Enter a username + const input = selection.getElementsByTagName("input")[0]; + input.focus(); + input.value = "Test User 3"; + return true; + }, users); + + expect(users).toStrictEqual({ + "Test User 2!": "Test User 2", + "Test User 3!": "Test User 3", + }); + + expect(localStorage["Comfy.userId"]).toBe("Test User 3!"); + expect(localStorage["Comfy.userName"]).toBe("Test User 3"); + }); + it("allows user selection if no current user but other users", async () => { + const users = { + "A!": "A", + "B!": "B", + "C!": "C", + }; + + await testUserScreen((selection) => { + expect(selection.classList.contains("no-users")).toBeFalsy(); + + // Check user list + const select = selection.getElementsByTagName("select")[0]; + const options = select.getElementsByTagName("option"); + expect( + [...options] + .filter((o) => !o.disabled) + .reduce((p, n) => { + p[n.getAttribute("value")] = n.textContent; + return p; + }, {}) + ).toStrictEqual(users); + + // Select an option + select.focus(); + select.value = options[2].value; + + return false; + }, users); + + expect(users).toStrictEqual(users); + + expect(localStorage["Comfy.userId"]).toBe("B!"); + expect(localStorage["Comfy.userName"]).toBe("B"); + }); + it("doesnt show user screen if current user", async () => { + const starting = start({ + resetEnv: true, + userConfig: { + storage: "server", + users: { + "User!": "User", + }, + }, + localStorage: { + "Comfy.userId": "User!", + "Comfy.userName": "User", + }, + }); + await new Promise(process.nextTick); // wait for promises to resolve + + expectNoUserScreen(); + + await starting; + }); + it("allows user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { + storage: "server", + users: { + "User!": "User", + }, + }, + localStorage: { + "Comfy.userId": "User!", + "Comfy.userName": "User", + }, + }); + + // cant actually test switching user easily but can check the setting is present + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy(); + }); + }); + describe("single-user", () => { + it("doesnt show user creation if no default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: false, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../comfy/web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(1); + expect(api.storeUserData).toHaveBeenCalledTimes(1); + expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false }); + expect(app.isNewUserSession).toBeTruthy(); + }); + it("doesnt show user creation if default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../comfy/web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt allow user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy(); + }); + }); + describe("browser-user", () => { + it("doesnt show user creation if no default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: false, storage: "browser" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../comfy/web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt show user creation if default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../comfy/web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt allow user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "browser" }, + }); + expectNoUserScreen(); + + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy(); + }); + }); +}); diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index 19f530f91..df1b13eef 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -1,10 +1,18 @@ const { mockApi } = require("./setup"); const { Ez } = require("./ezgraph"); const lg = require("./litegraph"); +const fs = require("fs"); +const path = require("path"); + +const html = fs.readFileSync(path.resolve("../comfy/web/index.html")) /** * - * @param { Parameters[0] & { resetEnv?: boolean, preSetup?(app): Promise } } config + * @param { Parameters[0] & { + * resetEnv?: boolean, + * preSetup?(app): Promise, + * localStorage?: Record + * } } config * @returns */ export async function start(config = {}) { @@ -12,12 +20,18 @@ export async function start(config = {}) { jest.resetModules(); jest.resetAllMocks(); lg.setup(global); + localStorage.clear(); + sessionStorage.clear(); } + Object.assign(localStorage, config.localStorage ?? {}); + document.body.innerHTML = html; + mockApi(config); const { app } = require("../../comfy/web/scripts/app"); config.preSetup?.(app); await app.setup(); + return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; } diff --git a/tests-ui/utils/setup.js b/tests-ui/utils/setup.js index 4807167ee..ec65dff9f 100644 --- a/tests-ui/utils/setup.js +++ b/tests-ui/utils/setup.js @@ -18,9 +18,21 @@ function* walkSync(dir) { */ /** - * @param { { mockExtensions?: string[], mockNodeDefs?: Record } } config + * @param {{ + * mockExtensions?: string[], + * mockNodeDefs?: Record, +* settings?: Record +* userConfig?: {storage: "server" | "browser", users?: Record, migrated?: boolean }, +* userData?: Record + * }} config */ -export function mockApi({ mockExtensions, mockNodeDefs } = {}) { +export function mockApi(config = {}) { + let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = { + userConfig, + settings: {}, + userData: {}, + ...config, + }; if (!mockExtensions) { mockExtensions = Array.from(walkSync(path.resolve("../comfy/web/extensions/core"))) .filter((x) => x.endsWith(".js")) @@ -40,6 +52,26 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) { getNodeDefs: jest.fn(() => mockNodeDefs), init: jest.fn(), apiURL: jest.fn((x) => "../../web/" + x), + createUser: jest.fn((username) => { + if(username in userConfig.users) { + return { status: 400, json: () => "Duplicate" } + } + userConfig.users[username + "!"] = username; + return { status: 200, json: () => username + "!" } + }), + getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }), + getSettings: jest.fn(() => settings), + storeSettings: jest.fn((v) => Object.assign(settings, v)), + getUserData: jest.fn((f) => { + if (f in userData) { + return { status: 200, json: () => userData[f] }; + } else { + return { status: 404 }; + } + }), + storeUserData: jest.fn((file, data) => { + userData[file] = data; + }), }; jest.mock("../../comfy/web/scripts/api", () => ({ get api() {