Merge with latest upstream.

This commit is contained in:
doctorpangloss 2024-01-29 14:41:30 -08:00
commit 82edb2ff0e
70 changed files with 3546 additions and 536 deletions

1
.gitignore vendored
View File

@ -173,3 +173,4 @@ cython_debug/
.openapi-generator/ .openapi-generator/
/tests-ui/data/object_info.json /tests-ui/data/object_info.json
/user/

View File

@ -101,6 +101,8 @@ There is a portable standalone build for Windows that should work for running on
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
If you have trouble extracting it, right click the file -> properties -> unblock
#### How do I share models between another UI and ComfyUI? #### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -192,7 +194,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
To run tests: To run tests:
```shell ```shell
pytest tests/inference pytest tests/inference
(cd tests-ui && npm ci && npm test:generate && npm test) (cd tests-ui && npm ci && npm run test:generate && npm test)
``` ```
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language. You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.

0
comfy/app/__init__.py Normal file
View File

54
comfy/app/app_settings.py Normal file
View File

@ -0,0 +1,54 @@
import os
import json
from aiohttp import web
class AppSettings():
def __init__(self, user_manager):
self.user_manager = user_manager
def get_settings(self, request):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
if os.path.isfile(file):
with open(file) as f:
return json.load(f)
else:
return {}
def save_settings(self, request, settings):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
with open(file, "w") as f:
f.write(json.dumps(settings, indent=4))
def add_routes(self, routes):
@routes.get("/settings")
async def get_settings(request):
return web.json_response(self.get_settings(request))
@routes.get("/settings/{id}")
async def get_setting(request):
value = None
settings = self.get_settings(request)
setting_id = request.match_info.get("id", None)
if setting_id and setting_id in settings:
value = settings[setting_id]
return web.json_response(value)
@routes.post("/settings")
async def post_settings(request):
settings = self.get_settings(request)
new_settings = await request.json()
self.save_settings(request, {**settings, **new_settings})
return web.Response(status=200)
@routes.post("/settings/{id}")
async def post_setting(request):
setting_id = request.match_info.get("id", None)
if not setting_id:
return web.Response(status=400)
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)

135
comfy/app/user_manager.py Normal file
View File

@ -0,0 +1,135 @@
import json
import os
import re
import uuid
from aiohttp import web
from ..cli_args import args
from ..cmd.folder_paths import user_directory
from .app_settings import AppSettings
class UserManager():
def __init__(self):
self.default_user = "default"
self.users_file = os.path.join(user_directory, "users.json")
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(self.users_file):
with open(self.users_file) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
if user not in self.users:
raise KeyError("Unknown user: " + user)
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
if type == "userdata":
root_dir = user_directory
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
return None
parent = user_root
if file is not None:
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
return None
if create_dir and not os.path.exists(parent):
os.mkdir(parent)
return path
def add_user(self, name):
name = name.strip()
if not name:
raise ValueError("username not provided")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
with open(self.users_file, "w") as f:
json.dump(self.users, f)
return user_id
def add_routes(self, routes):
self.settings.add_routes(routes)
@routes.get("/users")
async def get_users(request):
if args.multi_user:
return web.json_response({"storage": "server", "users": self.users})
else:
user_dir = self.get_request_user_filepath(request, None, create_dir=False)
return web.json_response({
"storage": "server",
"migrated": os.path.exists(user_dir)
})
@routes.post("/users")
async def post_users(request):
body = await request.json()
username = body["username"]
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
user_id = self.add_user(username)
return web.json_response(user_id)
@routes.get("/userdata/{file}")
async def getuserdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if not os.path.exists(path):
return web.Response(status=404)
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
body = await request.read()
with open(path, "wb") as f:
f.write(body)
return web.Response(status=200)

View File

@ -112,6 +112,8 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
if options.args_parsing: if options.args_parsing:
args = parser.parse_args() args = parser.parse_args()
else: else:

View File

@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module):
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None): def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None: if intermediate_output is not None:
if intermediate_output < 0: if intermediate_output < 0:

View File

@ -1,4 +1,4 @@
from .utils import load_torch_file, transformers_convert from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os import os
import torch import torch
import json import json
@ -43,6 +43,9 @@ class ClipVisionModel():
def load_sd(self, sd): def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False) return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image): def encode_image(self, image):
model_management.load_model_gpu(self.patcher) model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device)).float() pixel_values = clip_preprocess(image.to(self.load_device)).float()
@ -75,6 +78,9 @@ def convert_to_transformers(sd, prefix):
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
sd = transformers_convert(sd, prefix, "vision_model.", 48) sd = transformers_convert(sd, prefix, "vision_model.", 48)
else:
replace_prefix = {prefix: ""}
sd = state_dict_prefix_replace(sd, replace_prefix)
return sd return sd
def load_clipvision_from_sd(sd, prefix="", convert_keys=False): def load_clipvision_from_sd(sd, prefix="", convert_keys=False):

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import sys
import copy import copy
import datetime import datetime
import heapq import heapq
@ -15,6 +16,7 @@ from typing import Tuple
import sys import sys
import gc import gc
import inspect import inspect
from typing import List, Literal, NamedTuple, Optional
import torch import torch
@ -325,11 +327,22 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor: class PromptExecutor:
def __init__(self, server): def __init__(self, server):
self.success = None
self.server = server
self.reset()
def reset(self):
self.outputs = {} self.outputs = {}
self.object_storage = {} self.object_storage = {}
self.outputs_ui = {} self.outputs_ui = {}
self.status_messages = []
self.success = True
self.old_prompt = {} self.old_prompt = {}
self.server = server
def add_message(self, event, data, broadcast: bool):
self.status_messages.append((event, data))
if self.server.client_id is not None or broadcast:
self.server.send_sync(event, data, self.server.client_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"] node_id = error["node_id"]
@ -344,22 +357,21 @@ class PromptExecutor:
"node_type": class_type, "node_type": class_type,
"executed": list(executed), "executed": list(executed),
} }
self.server.send_sync("execution_interrupted", mes, self.server.client_id) self.add_message("execution_interrupted", mes, broadcast=True)
else: else:
if self.server.client_id is not None: mes = {
mes = { "prompt_id": prompt_id,
"prompt_id": prompt_id, "node_id": node_id,
"node_id": node_id, "node_type": class_type,
"node_type": class_type, "executed": list(executed),
"executed": list(executed),
"exception_message": error["exception_message"], "exception_message": error["exception_message"],
"exception_type": error["exception_type"], "exception_type": error["exception_type"],
"traceback": error["traceback"], "traceback": error["traceback"],
"current_inputs": error["current_inputs"], "current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"], "current_outputs": error["current_outputs"],
} }
self.server.send_sync("execution_error", mes, self.server.client_id) self.add_message("execution_error", mes, broadcast=False)
# Next, remove the subsequent outputs since they will not be executed # Next, remove the subsequent outputs since they will not be executed
to_delete = [] to_delete = []
@ -381,8 +393,8 @@ class PromptExecutor:
else: else:
self.server.client_id = None self.server.client_id = None
if self.server.client_id is not None: self.status_messages = []
self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id) self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode(): with torch.inference_mode():
# delete cached outputs if nodes don't exist for them # delete cached outputs if nodes don't exist for them
@ -415,9 +427,9 @@ class PromptExecutor:
del d del d
model_management.cleanup_models() model_management.cleanup_models()
if self.server.client_id is not None: self.add_message("execution_cached",
self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id}, {"nodes": list(current_outputs), "prompt_id": prompt_id},
self.server.client_id) broadcast=False)
executed = set() executed = set()
output_node_id = None output_node_id = None
to_execute = [] to_execute = []
@ -434,11 +446,9 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in # This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the # the actual SD code, instead it will report the node where the
# error was raised # error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
self.outputs_ui, self.object_storage) if self.success is not True:
if success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
self.handle_execution_error( prompt_id,
prompt, current_outputs, executed, error, ex)
break break
for x in executed: for x in executed:
@ -779,6 +789,7 @@ class PromptQueue:
self.queue = [] self.queue = []
self.currently_running = {} self.currently_running = {}
self.history = {} self.history = {}
self.flags = {}
server.prompt_queue = self server.prompt_queue = self
def size(self) -> int: def size(self) -> int:
@ -803,15 +814,28 @@ class PromptQueue:
self.server.queue_updated() self.server.queue_updated()
return copy.deepcopy(item_with_future.queue_tuple), task_id return copy.deepcopy(item_with_future.queue_tuple), task_id
def task_done(self, item_id, outputs: dict): class ExecutionStatus(NamedTuple):
status_str: Literal['success', 'error']
completed: bool
messages: List[str]
def task_done(self, item_id, outputs: dict,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex: with self.mutex:
queue_item = self.currently_running.pop(item_id) queue_item = self.currently_running.pop(item_id)
prompt = queue_item.queue_tuple prompt = queue_item.queue_tuple
if len(self.history) > MAXIMUM_HISTORY_SIZE: if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history))) self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = {"prompt": prompt, "outputs": {}, "timestamp": time.time()}
for o in outputs: status_dict: Optional[dict] = None
self.history[prompt[1]]["outputs"][o] = outputs[o] if status is not None:
status_dict = copy.deepcopy(status._asdict())
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
'status': status_dict,
}
self.server.queue_updated() self.server.queue_updated()
if queue_item.completed: if queue_item.completed:
queue_item.completed.set_result(outputs) queue_item.completed.set_result(outputs)
@ -880,3 +904,17 @@ class PromptQueue:
def delete_history_item(self, id_to_delete: int): def delete_history_item(self, id_to_delete: int):
with self.mutex: with self.mutex:
self.history.pop(id_to_delete, None) self.history.pop(id_to_delete, None)
def set_flag(self, name, data):
with self.mutex:
self.flags[name] = data
self.not_empty.notify()
def get_flags(self, reset=True):
with self.mutex:
if reset:
ret = self.flags
self.flags = {}
return ret
else:
return self.flags.copy()

View File

@ -31,11 +31,14 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(base_path, "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp") temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input") input_directory = os.path.join(base_path, "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
filename_list_cache = {} filename_list_cache = {}
@ -139,15 +142,27 @@ def recursive_search(directory, excluded_dir_names=None):
excluded_dir_names = [] excluded_dir_names = []
result = [] result = []
dirs = {directory: os.path.getmtime(directory)} dirs = {}
# Attempt to add the initial directory to dirs with error handling
try:
dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError:
print(f"Warning: Unable to access {directory}. Skipping this path.")
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames: for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path) result.append(relative_path)
for d in subdirs: for d in subdirs:
path = os.path.join(dirpath, d) path = os.path.join(dirpath, d)
dirs[path] = os.path.getmtime(path) try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
print(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs return result, dirs
def filter_files_extensions(files, extensions): def filter_files_extensions(files, extensions):

View File

@ -89,7 +89,7 @@ def prompt_worker(q, _server):
gc_collect_interval = 10.0 gc_collect_interval = 10.0
current_time = 0.0 current_time = 0.0
while True: while True:
timeout = None timeout = 1000.0
if need_gc: if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
@ -102,14 +102,32 @@ def prompt_worker(q, _server):
e.execute(item[2], prompt_id, item[3], item[4]) e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
q.task_done(item_id, e.outputs_ui) q.task_done(item_id,
e.outputs_ui,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
if _server.client_id is not None: if _server.client_id is not None:
_server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id) _server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time)) print("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)
if flags.get("unload_models", free_memory):
model_management.unload_all_models()
need_gc = True
last_gc_collect = 0
if free_memory:
e.reset()
need_gc = True
last_gc_collect = 0
if need_gc: if need_gc:
current_time = time.perf_counter() current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval: if (current_time - last_gc_collect) > gc_collect_interval:

View File

@ -35,6 +35,7 @@ from ..vendor.appdirs import user_data_dir
nodes = import_all_nodes_in_workspace() nodes = import_all_nodes_in_workspace()
from ..app.user_manager import UserManager
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
@ -91,6 +92,7 @@ class PromptServer():
mimetypes.init() mimetypes.init()
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.supports = ["custom_nodes_from_web"] self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
@ -532,6 +534,17 @@ class PromptServer():
model_management.interrupt_current_processing() model_management.interrupt_current_processing()
return web.Response(status=200) return web.Response(status=200)
@routes.post("/free")
async def post_free(request):
json_data = await request.json()
unload_models = json_data.get("unload_models", False)
free_memory = json_data.get("free_memory", False)
if unload_models:
self.prompt_queue.set_flag("unload_models", unload_models)
if free_memory:
self.prompt_queue.set_flag("free_memory", free_memory)
return web.Response(status=200)
@routes.post("/history") @routes.post("/history")
async def post_history(request): async def post_history(request):
json_data = await request.json() json_data = await request.json()
@ -671,6 +684,7 @@ class PromptServer():
return web.json_response(prompt, status=200) return web.json_response(prompt, status=200)
def add_routes(self): def add_routes(self):
self.user_manager.add_routes(self.routes)
self.app.add_routes(self.routes) self.app.add_routes(self.routes)
for name, dir in nodes.EXTENSION_WEB_DIRS.items(): for name, dir in nodes.EXTENSION_WEB_DIRS.items():
@ -768,14 +782,9 @@ class PromptServer():
site = web.TCPSite(runner, address, port) site = web.TCPSite(runner, address, port)
await site.start() await site.start()
address_to_print = 'localhost'
if address == '' or address == '0.0.0.0':
address = '0.0.0.0'
else:
address_to_print = address
if verbose: if verbose:
print("Starting server\n") print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address_to_print, port)) print("To see the GUI go to: http://{}:{}".format(address, port))
if call_on_start is not None: if call_on_start is not None:
call_on_start(address, port) call_on_start(address, port)

View File

@ -1,4 +1,3 @@
import enum
import torch import torch
import math import math
from . import utils from . import utils

View File

@ -1,7 +1,6 @@
import torch import torch
import math import math
import os import os
import contextlib
from . import utils from . import utils
from . import model_management from . import model_management
@ -127,7 +126,10 @@ class ControlBase:
if o[i] is None: if o[i] is None:
o[i] = prev_val o[i] = prev_val
else: else:
o[i] += prev_val if o[i].shape[0] < prev_val.shape[0]:
o[i] = prev_val + o[i]
else:
o[i] += prev_val
return out return out
class ControlNet(ControlBase): class ControlNet(ControlBase):

View File

@ -1,7 +1,7 @@
import math import math
import torch import torch
from torch import nn, einsum from torch import nn
from .ldm.modules.attention import CrossAttention from .ldm.modules.attention import CrossAttention
from inspect import isfunction from inspect import isfunction

View File

@ -1,12 +1,9 @@
from inspect import isfunction
import math import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from typing import Optional, Any from typing import Optional, Any
from functools import partial
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
@ -176,6 +173,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
kv_chunk_size_min=kv_chunk_size_min, kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=False, use_checkpoint=False,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
mask=mask,
) )
hidden_states = hidden_states.to(dtype) hidden_states = hidden_states.to(dtype)
@ -238,6 +236,12 @@ def attention_split(q, k, v, heads, mask=None):
else: else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
if mask is not None:
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype) s2 = s1.softmax(dim=-1).to(v.dtype)
del s1 del s1
first_op_done = True first_op_done = True
@ -293,11 +297,14 @@ def attention_xformers(q, k, v, heads, mask=None):
(q, k, v), (q, k, v),
) )
# actually compute the attention, what we cannot get enough of if mask is not None:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if exists(mask):
raise NotImplementedError
out = ( out = (
out.unsqueeze(0) out.unsqueeze(0)
.reshape(b, heads, -1, dim_head) .reshape(b, heads, -1, dim_head)
@ -322,7 +329,6 @@ def attention_pytorch(q, k, v, heads, mask=None):
optimized_attention = attention_basic optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled(): if model_management.xformers_enabled():
print("Using xformers cross attention") print("Using xformers cross attention")
@ -338,15 +344,18 @@ else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled(): optimized_attention_masked = optimized_attention
optimized_attention_masked = attention_pytorch
def optimized_attention_for_device(device, mask=False): def optimized_attention_for_device(device, mask=False, small_input=False):
if device == torch.device("cpu"): #TODO if small_input:
if model_management.pytorch_attention_enabled(): if model_management.pytorch_attention_enabled():
return attention_pytorch return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
else: else:
return attention_basic return attention_basic
if device == torch.device("cpu"):
return attention_sub_quad
if mask: if mask:
return optimized_attention_masked return optimized_attention_masked

View File

@ -1,12 +1,9 @@
from abc import abstractmethod from abc import abstractmethod
import math
import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from functools import partial
from .util import ( from .util import (
checkpoint, checkpoint,

View File

@ -23,13 +23,13 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x return x
def forward(self, x, noise_level=None): def forward(self, x, noise_level=None, seed=None):
if noise_level is None: if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else: else:
assert isinstance(noise_level, torch.Tensor) assert isinstance(noise_level, torch.Tensor)
x = self.scale(x) x = self.scale(x)
z = self.q_sample(x, noise_level) z = self.q_sample(x, noise_level, seed=seed)
z = self.unscale(z) z = self.unscale(z)
noise_level = self.time_embed(noise_level) noise_level = self.time_embed(noise_level)
return z, noise_level return z, noise_level

View File

@ -61,6 +61,7 @@ def _summarize_chunk(
value: Tensor, value: Tensor,
scale: float, scale: float,
upcast_attention: bool, upcast_attention: bool,
mask,
) -> AttnChunk: ) -> AttnChunk:
if upcast_attention: if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type = 'cuda'):
@ -84,6 +85,8 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach() max_score = max_score.detach()
attn_weights -= max_score attn_weights -= max_score
if mask is not None:
attn_weights += mask
torch.exp(attn_weights, out=attn_weights) torch.exp(attn_weights, out=attn_weights)
exp_weights = attn_weights.to(value.dtype) exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value) exp_values = torch.bmm(exp_weights, value)
@ -96,11 +99,12 @@ def _query_chunk_attention(
value: Tensor, value: Tensor,
summarize_chunk: SummarizeChunk, summarize_chunk: SummarizeChunk,
kv_chunk_size: int, kv_chunk_size: int,
mask,
) -> Tensor: ) -> Tensor:
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
_, _, v_channels_per_head = value.shape _, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk: def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
key_chunk = dynamic_slice( key_chunk = dynamic_slice(
key_t, key_t,
(0, 0, chunk_idx), (0, 0, chunk_idx),
@ -111,10 +115,13 @@ def _query_chunk_attention(
(0, chunk_idx, 0), (0, chunk_idx, 0),
(batch_x_heads, kv_chunk_size, v_channels_per_head) (batch_x_heads, kv_chunk_size, v_channels_per_head)
) )
return summarize_chunk(query, key_chunk, value_chunk) if mask is not None:
mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
chunks: List[AttnChunk] = [ chunks: List[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
] ]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk chunk_values, chunk_weights, chunk_max = acc_chunk
@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
value: Tensor, value: Tensor,
scale: float, scale: float,
upcast_attention: bool, upcast_attention: bool,
mask,
) -> Tensor: ) -> Tensor:
if upcast_attention: if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type = 'cuda'):
@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
beta=0, beta=0,
) )
if mask is not None:
attn_scores += mask
try: try:
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
@ -183,6 +193,7 @@ def efficient_dot_product_attention(
kv_chunk_size_min: Optional[int] = None, kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True, use_checkpoint=True,
upcast_attention=False, upcast_attention=False,
mask = None,
): ):
"""Computes efficient dot-product attention given query, transposed key, and value. """Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in This is efficient version of attention presented in
@ -209,6 +220,9 @@ def efficient_dot_product_attention(
if kv_chunk_size_min is not None: if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
if mask is not None and len(mask.shape) == 2:
mask = mask.unsqueeze(0)
def get_query_chunk(chunk_idx: int) -> Tensor: def get_query_chunk(chunk_idx: int) -> Tensor:
return dynamic_slice( return dynamic_slice(
query, query,
@ -216,6 +230,12 @@ def efficient_dot_product_attention(
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
) )
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial( compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@ -237,6 +257,7 @@ def efficient_dot_product_attention(
query=query, query=query,
key_t=key_t, key_t=key_t,
value=value, value=value,
mask=mask,
) )
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@ -246,6 +267,7 @@ def efficient_dot_product_attention(
query=get_query_chunk(i * query_chunk_size), query=get_query_chunk(i * query_chunk_size),
key_t=key_t, key_t=key_t,
value=value, value=value,
mask=get_mask_chunk(i * query_chunk_size)
) for i in range(math.ceil(q_tokens / query_chunk_size)) ) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1) ], dim=1)
return res return res

View File

@ -6,7 +6,6 @@ from . import model_management
from . import conds from . import conds
from . import ops from . import ops
from enum import Enum from enum import Enum
import contextlib
from . import utils from . import utils
class ModelType(Enum): class ModelType(Enum):
@ -100,11 +99,29 @@ class BaseModel(torch.nn.Module):
if self.inpaint_model: if self.inpaint_model:
concat_keys = ("mask", "masked_image") concat_keys = ("mask", "masked_image")
cond_concat = [] cond_concat = []
denoise_mask = kwargs.get("denoise_mask", None) denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
latent_image = kwargs.get("latent_image", None) concat_latent_image = kwargs.get("concat_latent_image", None)
if concat_latent_image is None:
concat_latent_image = kwargs.get("latent_image", None)
else:
concat_latent_image = self.process_latent_in(concat_latent_image)
noise = kwargs.get("noise", None) noise = kwargs.get("noise", None)
device = kwargs["device"] device = kwargs["device"]
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
def blank_inpaint_image_like(latent_image): def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image) blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space # these are the values for "zero" in pixel space translated to latent space
@ -117,9 +134,9 @@ class BaseModel(torch.nn.Module):
for ck in concat_keys: for ck in concat_keys:
if denoise_mask is not None: if denoise_mask is not None:
if ck == "mask": if ck == "mask":
cond_concat.append(denoise_mask[:,:1].to(device)) cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else: else:
if ck == "mask": if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1]) cond_concat.append(torch.ones_like(noise)[:,:1])
@ -159,19 +176,28 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent): def process_latent_out(self, latent):
return self.latent_format.process_out(latent) return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict, vae_state_dict): def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
if vae_state_dict is not None:
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16: if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
if self.model_type == ModelType.V_PREDICTION: if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([]) unet_state_dict["v_pred"] = torch.tensor([])
return {**unet_state_dict, **vae_state_dict, **clip_state_dict} for sd in extra_sds:
unet_state_dict.update(sd)
return unet_state_dict
def set_inpaint(self): def set_inpaint(self):
self.inpaint_model = True self.inpaint_model = True
@ -190,7 +216,7 @@ class BaseModel(torch.nn.Module):
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = [] adm_inputs = []
weights = [] weights = []
noise_aug = [] noise_aug = []
@ -199,7 +225,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
weight = unclip_cond["strength"] weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"] noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight) weights.append(weight)
noise_aug.append(noise_augment) noise_aug.append(noise_augment)
@ -225,11 +251,11 @@ class SD21UNCLIP(BaseModel):
if unclip_conditioning is None: if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels)) return torch.zeros((1, self.adm_channels))
else: else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05)) return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
def sdxl_pooled(args, noise_augmentor): def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args: if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280] return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
else: else:
return args["pooled_output"] return args["pooled_output"]

View File

@ -176,7 +176,7 @@ try:
if int(torch_version[0]) >= 2: if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
VAE_DTYPE = torch.bfloat16 VAE_DTYPE = torch.bfloat16
if is_intel_xpu(): if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:

View File

@ -166,6 +166,26 @@ class ConditioningSetAreaPercentage:
c.append(n) c.append(n)
return (c, ) return (c, )
class ConditioningSetAreaStrength:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['strength'] = strength
c.append(n)
return (c, )
class ConditioningSetMask: class ConditioningSetMask:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -341,6 +361,62 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class InpaintModelConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
"mask": ("MASK", ),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, positive, negative, pixels, vae, mask):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
orig_pixels = pixels
pixels = orig_pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
m = (1.0 - mask.round()).squeeze(1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5
concat_latent = vae.encode(pixels)
orig_latent = vae.encode(orig_pixels)
out_latent = {}
out_latent["samples"] = orig_latent
out_latent["noise_mask"] = mask
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
d["concat_mask"] = mask
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1], out_latent)
class SaveLatent: class SaveLatent:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -1400,6 +1476,8 @@ class LoadImage:
output_masks = [] output_masks = []
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i) i = ImageOps.exif_transpose(i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
@ -1455,6 +1533,8 @@ class LoadImageMask:
i = Image.open(image_path) i = Image.open(image_path)
i = ImageOps.exif_transpose(i) i = ImageOps.exif_transpose(i)
if i.getbands() != ("R", "G", "B", "A"): if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA") i = i.convert("RGBA")
mask = None mask = None
c = channel[0].upper() c = channel[0].upper()
@ -1609,10 +1689,11 @@ class ImagePadForOutpaint:
def expand_image(self, image, left, top, right, bottom, feathering): def expand_image(self, image, left, top, right, bottom, feathering):
d1, d2, d3, d4 = image.size() d1, d2, d3, d4 = image.size()
new_image = torch.zeros( new_image = torch.ones(
(d1, d2 + top + bottom, d3 + left + right, d4), (d1, d2 + top + bottom, d3 + left + right, d4),
dtype=torch.float32, dtype=torch.float32,
) ) * 0.5
new_image[:, top:top + d2, left:left + d3, :] = image new_image[:, top:top + d2, left:left + d3, :] = image
mask = torch.ones( mask = torch.ones(
@ -1678,6 +1759,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningConcat": ConditioningConcat, "ConditioningConcat": ConditioningConcat,
"ConditioningSetArea": ConditioningSetArea, "ConditioningSetArea": ConditioningSetArea,
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
"ConditioningSetMask": ConditioningSetMask, "ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced, "KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask, "SetLatentNoiseMask": SetLatentNoiseMask,
@ -1704,6 +1786,7 @@ NODE_CLASS_MAPPINGS = {
"unCLIPCheckpointLoader": unCLIPCheckpointLoader, "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader, "GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply, "GLIGENTextBoxApply": GLIGENTextBoxApply,
"InpaintModelConditioning": InpaintModelConditioning,
"CheckpointLoader": CheckpointLoader, "CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,

View File

@ -1,10 +1,9 @@
import torch import torch
from contextlib import contextmanager from . import model_management
import comfy.model_management
def cast_bias_weight(s, input): def cast_bias_weight(s, input):
bias = None bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device) non_blocking = model_management.device_supports_non_blocking(input.device)
if s.bias is not None: if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)

View File

@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
def prepare_mask(noise_mask, shape, device): def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions""" """ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device) noise_mask = noise_mask.to(device)

View File

@ -531,7 +531,14 @@ def load_unet(unet_path):
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model return model
def save_checkpoint(output_path, model, clip, vae, metadata=None): def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
model_management.load_models_gpu([model, clip.load_model()]) clip_sd = None
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
utils.save_torch_file(sd, output_path, metadata=metadata) utils.save_torch_file(sd, output_path, metadata=metadata)

View File

@ -7,7 +7,6 @@ import traceback
import zipfile import zipfile
from . import model_management from . import model_management
from pkg_resources import resource_filename from pkg_resources import resource_filename
import contextlib
from . import clip_model from . import clip_model
import json import json

View File

@ -65,6 +65,12 @@ class BASE:
replace_prefix = {"": "cond_stage_model."} replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_clip_vision_state_dict_for_saving(self, state_dict):
replace_prefix = {}
if self.clip_vision_prefix is not None:
replace_prefix[""] = self.clip_vision_prefix
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_unet_state_dict_for_saving(self, state_dict): def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.diffusion_model."} replace_prefix = {"": "model.diffusion_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -1,6 +1,7 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js"; import { api } from "../../scripts/api.js";
import { mergeIfValid } from "./widgetInputs.js"; import { mergeIfValid } from "./widgetInputs.js";
import { ManageGroupDialog } from "./groupNodeManage.js";
const GROUP = Symbol(); const GROUP = Symbol();
@ -61,11 +62,7 @@ class GroupNodeBuilder {
); );
return; return;
case Workflow.InUse.Registered: case Workflow.InUse.Registered:
if ( if (!confirm("A group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?")) {
!confirm(
"An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?"
)
) {
return; return;
} }
break; break;
@ -151,6 +148,8 @@ export class GroupNodeConfig {
this.primitiveDefs = {}; this.primitiveDefs = {};
this.widgetToPrimitive = {}; this.widgetToPrimitive = {};
this.primitiveToWidget = {}; this.primitiveToWidget = {};
this.nodeInputs = {};
this.outputVisibility = [];
} }
async registerType(source = "workflow") { async registerType(source = "workflow") {
@ -158,6 +157,7 @@ export class GroupNodeConfig {
output: [], output: [],
output_name: [], output_name: [],
output_is_list: [], output_is_list: [],
output_is_hidden: [],
name: source + "/" + this.name, name: source + "/" + this.name,
display_name: this.name, display_name: this.name,
category: "group nodes" + ("/" + source), category: "group nodes" + ("/" + source),
@ -277,8 +277,7 @@ export class GroupNodeConfig {
} }
if (input.widget) { if (input.widget) {
const targetDef = globalDefs[node.type]; const targetDef = globalDefs[node.type];
const targetWidget = const targetWidget = targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
const widget = [targetWidget[0], config]; const widget = [targetWidget[0], config];
const res = mergeIfValid( const res = mergeIfValid(
@ -330,7 +329,8 @@ export class GroupNodeConfig {
} }
getInputConfig(node, inputName, seenInputs, config, extra) { getInputConfig(node, inputName, seenInputs, config, extra) {
let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; const customConfig = this.nodeData.config?.[node.index]?.input?.[inputName];
let name = customConfig?.name ?? node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
let key = name; let key = name;
let prefix = ""; let prefix = "";
// Special handling for primitive to include the title if it is set rather than just "value" // Special handling for primitive to include the title if it is set rather than just "value"
@ -349,14 +349,14 @@ export class GroupNodeConfig {
} }
if (config[0] === "IMAGEUPLOAD") { if (config[0] === "IMAGEUPLOAD") {
if (!extra) extra = {}; if (!extra) extra = {};
extra.widget = `${prefix}${config[1]?.widget ?? "image"}`; extra.widget = this.oldToNewWidgetMap[node.index]?.[config[1]?.widget ?? "image"] ?? "image";
} }
if (extra) { if (extra) {
config = [config[0], { ...config[1], ...extra }]; config = [config[0], { ...config[1], ...extra }];
} }
return { name, config }; return { name, config, customConfig };
} }
processWidgetInputs(inputs, node, inputNames, seenInputs) { processWidgetInputs(inputs, node, inputNames, seenInputs) {
@ -366,9 +366,7 @@ export class GroupNodeConfig {
for (const inputName of inputNames) { for (const inputName of inputNames) {
let widgetType = app.getWidgetType(inputs[inputName], inputName); let widgetType = app.getWidgetType(inputs[inputName], inputName);
if (widgetType) { if (widgetType) {
const convertedIndex = node.inputs?.findIndex( const convertedIndex = node.inputs?.findIndex((inp) => inp.name === inputName && inp.widget?.name === inputName);
(inp) => inp.name === inputName && inp.widget?.name === inputName
);
if (convertedIndex > -1) { if (convertedIndex > -1) {
// This widget has been converted to a widget // This widget has been converted to a widget
// We need to store this in the correct position so link ids line up // We need to store this in the correct position so link ids line up
@ -424,6 +422,7 @@ export class GroupNodeConfig {
} }
processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) { processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) {
this.nodeInputs[node.index] = {};
for (let i = 0; i < slots.length; i++) { for (let i = 0; i < slots.length; i++) {
const inputName = slots[i]; const inputName = slots[i];
if (linksTo[i]) { if (linksTo[i]) {
@ -432,7 +431,11 @@ export class GroupNodeConfig {
continue; continue;
} }
const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); const { name, config, customConfig } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]);
this.nodeInputs[node.index][inputName] = name;
if(customConfig?.visible === false) continue;
this.nodeDef.input.required[name] = config; this.nodeDef.input.required[name] = config;
inputMap[i] = this.inputCount++; inputMap[i] = this.inputCount++;
} }
@ -452,6 +455,7 @@ export class GroupNodeConfig {
const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], { const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], {
defaultInput: true, defaultInput: true,
}); });
this.nodeDef.input.required[name] = config; this.nodeDef.input.required[name] = config;
this.newToOldWidgetMap[name] = { node, inputName }; this.newToOldWidgetMap[name] = { node, inputName };
@ -477,9 +481,7 @@ export class GroupNodeConfig {
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
// Converted inputs have to be processed after all other nodes as they'll be at the end of the list // Converted inputs have to be processed after all other nodes as they'll be at the end of the list
this.#convertedToProcess.push(() => this.#convertedToProcess.push(() => this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs));
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
);
return inputMapping; return inputMapping;
} }
@ -490,8 +492,12 @@ export class GroupNodeConfig {
// Add outputs // Add outputs
for (let outputId = 0; outputId < def.output.length; outputId++) { for (let outputId = 0; outputId < def.output.length; outputId++) {
const linksFrom = this.linksFrom[node.index]; const linksFrom = this.linksFrom[node.index];
if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) { // If this output is linked internally we flag it to hide
// This output is linked internally so we can skip it const hasLink = linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId];
const customConfig = this.nodeData.config?.[node.index]?.output?.[outputId];
const visible = customConfig?.visible ?? !hasLink;
this.outputVisibility.push(visible);
if (!visible) {
continue; continue;
} }
@ -500,11 +506,15 @@ export class GroupNodeConfig {
this.nodeDef.output.push(def.output[outputId]); this.nodeDef.output.push(def.output[outputId]);
this.nodeDef.output_is_list.push(def.output_is_list[outputId]); this.nodeDef.output_is_list.push(def.output_is_list[outputId]);
let label = def.output_name?.[outputId] ?? def.output[outputId]; let label = customConfig?.name;
const output = node.outputs.find((o) => o.name === label); if (!label) {
if (output?.label) { label = def.output_name?.[outputId] ?? def.output[outputId];
label = output.label; const output = node.outputs.find((o) => o.name === label);
if (output?.label) {
label = output.label;
}
} }
let name = label; let name = label;
if (name in seenOutputs) { if (name in seenOutputs) {
const prefix = `${node.title ?? node.type} `; const prefix = `${node.title ?? node.type} `;
@ -677,6 +687,25 @@ export class GroupNodeHandler {
return this.innerNodes; return this.innerNodes;
}; };
this.node.recreate = async () => {
const id = this.node.id;
const sz = this.node.size;
const nodes = this.node.convertToNodes();
const groupNode = LiteGraph.createNode(this.node.type);
groupNode.id = id;
// Reuse the existing nodes for this instance
groupNode.setInnerNodes(nodes);
groupNode[GROUP].populateWidgets();
app.graph.add(groupNode);
groupNode.size = [Math.max(groupNode.size[0], sz[0]), Math.max(groupNode.size[1], sz[1])];
// Remove all converted nodes and relink them
groupNode[GROUP].replaceNodes(nodes);
return groupNode;
};
this.node.convertToNodes = () => { this.node.convertToNodes = () => {
const addInnerNodes = () => { const addInnerNodes = () => {
const backup = localStorage.getItem("litegrapheditor_clipboard"); const backup = localStorage.getItem("litegrapheditor_clipboard");
@ -769,6 +798,7 @@ export class GroupNodeHandler {
const slot = node.inputs[groupSlotId]; const slot = node.inputs[groupSlotId];
if (slot.link == null) continue; if (slot.link == null) continue;
const link = app.graph.links[slot.link]; const link = app.graph.links[slot.link];
if (!link) continue;
// connect this node output to the input of another node // connect this node output to the input of another node
const originNode = app.graph.getNodeById(link.origin_id); const originNode = app.graph.getNodeById(link.origin_id);
originNode.connect(link.origin_slot, newNode, +innerInputId); originNode.connect(link.origin_slot, newNode, +innerInputId);
@ -806,12 +836,23 @@ export class GroupNodeHandler {
let optionIndex = options.findIndex((o) => o.content === "Outputs"); let optionIndex = options.findIndex((o) => o.content === "Outputs");
if (optionIndex === -1) optionIndex = options.length; if (optionIndex === -1) optionIndex = options.length;
else optionIndex++; else optionIndex++;
options.splice(optionIndex, 0, null, { options.splice(
content: "Convert to nodes", optionIndex,
callback: () => { 0,
return this.convertToNodes(); null,
{
content: "Convert to nodes",
callback: () => {
return this.convertToNodes();
},
}, },
}); {
content: "Manage Group Node",
callback: () => {
new ManageGroupDialog(app).show(this.type);
},
}
);
}; };
// Draw custom collapse icon to identity this as a group // Draw custom collapse icon to identity this as a group
@ -843,6 +884,7 @@ export class GroupNodeHandler {
const r = onDrawForeground?.apply?.(this, arguments); const r = onDrawForeground?.apply?.(this, arguments);
if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) { if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) {
const n = groupData.nodes[this.runningInternalNodeId]; const n = groupData.nodes[this.runningInternalNodeId];
if(!n) return;
const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`; const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`;
ctx.save(); ctx.save();
ctx.font = "12px sans-serif"; ctx.font = "12px sans-serif";
@ -865,6 +907,28 @@ export class GroupNodeHandler {
return onExecutionStart?.apply(this, arguments); return onExecutionStart?.apply(this, arguments);
}; };
const self = this;
const onNodeCreated = this.node.onNodeCreated;
this.node.onNodeCreated = function () {
const config = self.groupData.nodeData.config;
if (config) {
for (const n in config) {
const inputs = config[n]?.input;
for (const w in inputs) {
if (inputs[w].visible !== false) continue;
const widgetName = self.groupData.oldToNewWidgetMap[n][w];
const widget = this.widgets.find((w) => w.name === widgetName);
if (widget) {
widget.type = "hidden";
widget.computeSize = () => [0, -4];
}
}
}
}
return onNodeCreated?.apply(this, arguments);
};
function handleEvent(type, getId, getEvent) { function handleEvent(type, getId, getEvent) {
const handler = ({ detail }) => { const handler = ({ detail }) => {
const id = getId(detail); const id = getId(detail);
@ -902,6 +966,26 @@ export class GroupNodeHandler {
api.removeEventListener("executing", executing); api.removeEventListener("executing", executing);
api.removeEventListener("executed", executed); api.removeEventListener("executed", executed);
}; };
this.node.refreshComboInNode = (defs) => {
// Update combo widget options
for (const widgetName in this.groupData.newToOldWidgetMap) {
const widget = this.node.widgets.find((w) => w.name === widgetName);
if (widget?.type === "combo") {
const old = this.groupData.newToOldWidgetMap[widgetName];
const def = defs[old.node.type];
const input = def?.input?.required?.[old.inputName] ?? def?.input?.optional?.[old.inputName];
if (!input) continue;
widget.options.values = input[0];
if (old.inputName !== "image" && !widget.options.values.includes(widget.value)) {
widget.value = widget.options.values[0];
widget.callback(widget.value);
}
}
}
};
} }
updateInnerWidgets() { updateInnerWidgets() {
@ -927,13 +1011,15 @@ export class GroupNodeHandler {
continue; continue;
} else if (innerNode.type === "Reroute") { } else if (innerNode.type === "Reroute") {
const rerouteLinks = this.groupData.linksFrom[old.node.index]; const rerouteLinks = this.groupData.linksFrom[old.node.index];
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { if (rerouteLinks) {
const node = this.innerNodes[targetNodeId]; for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
const input = node.inputs[targetSlot]; const node = this.innerNodes[targetNodeId];
if (input.widget) { const input = node.inputs[targetSlot];
const widget = node.widgets?.find((w) => w.name === input.widget.name); if (input.widget) {
if (widget) { const widget = node.widgets?.find((w) => w.name === input.widget.name);
widget.value = newValue; if (widget) {
widget.value = newValue;
}
} }
} }
} }
@ -975,7 +1061,7 @@ export class GroupNodeHandler {
const [, , targetNodeId, targetNodeSlot] = link; const [, , targetNodeId, targetNodeSlot] = link;
const targetNode = this.groupData.nodeData.nodes[targetNodeId]; const targetNode = this.groupData.nodeData.nodes[targetNodeId];
const inputs = targetNode.inputs; const inputs = targetNode.inputs;
const targetWidget = inputs?.[targetNodeSlot].widget; const targetWidget = inputs?.[targetNodeSlot]?.widget;
if (!targetWidget) return; if (!targetWidget) return;
const offset = inputs.length - (targetNode.widgets_values?.length ?? 0); const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
@ -983,13 +1069,12 @@ export class GroupNodeHandler {
if (v == null) return; if (v == null) return;
const widgetName = Object.values(map)[0]; const widgetName = Object.values(map)[0];
const widget = this.node.widgets.find(w => w.name === widgetName); const widget = this.node.widgets.find((w) => w.name === widgetName);
if(widget) { if (widget) {
widget.value = v; widget.value = v;
} }
} }
populateWidgets() { populateWidgets() {
if (!this.node.widgets) return; if (!this.node.widgets) return;
@ -1080,7 +1165,7 @@ export class GroupNodeHandler {
} }
static getGroupData(node) { static getGroupData(node) {
return node.constructor?.nodeData?.[GROUP]; return (node.nodeData ?? node.constructor?.nodeData)?.[GROUP];
} }
static isGroupNode(node) { static isGroupNode(node) {
@ -1112,7 +1197,7 @@ export class GroupNodeHandler {
} }
function addConvertToGroupOptions() { function addConvertToGroupOptions() {
function addOption(options, index) { function addConvertOption(options, index) {
const selected = Object.values(app.canvas.selected_nodes ?? {}); const selected = Object.values(app.canvas.selected_nodes ?? {});
const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n)); const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n));
options.splice(index + 1, null, { options.splice(index + 1, null, {
@ -1124,12 +1209,25 @@ function addConvertToGroupOptions() {
}); });
} }
function addManageOption(options, index) {
const groups = app.graph.extra?.groupNodes;
const disabled = !groups || !Object.keys(groups).length;
options.splice(index + 1, null, {
content: `Manage Group Nodes`,
disabled,
callback: () => {
new ManageGroupDialog(app).show();
},
});
}
// Add to canvas // Add to canvas
const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions;
LGraphCanvas.prototype.getCanvasMenuOptions = function () { LGraphCanvas.prototype.getCanvasMenuOptions = function () {
const options = getCanvasMenuOptions.apply(this, arguments); const options = getCanvasMenuOptions.apply(this, arguments);
const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length; const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length;
addOption(options, index); addConvertOption(options, index);
addManageOption(options, index + 1);
return options; return options;
}; };
@ -1139,7 +1237,7 @@ function addConvertToGroupOptions() {
const options = getNodeMenuOptions.apply(this, arguments); const options = getNodeMenuOptions.apply(this, arguments);
if (!GroupNodeHandler.isGroupNode(node)) { if (!GroupNodeHandler.isGroupNode(node)) {
const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1; const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1;
addOption(options, index); addConvertOption(options, index);
} }
return options; return options;
}; };
@ -1167,6 +1265,14 @@ const ext = {
node[GROUP] = new GroupNodeHandler(node); node[GROUP] = new GroupNodeHandler(node);
} }
}, },
async refreshComboInNodes(defs) {
// Re-register group nodes so new ones are created with the correct options
Object.assign(globalDefs, defs);
const nodes = app.graph.extra?.groupNodes;
if (nodes) {
await GroupNodeConfig.registerFromWorkflow(nodes, {});
}
}
}; };
app.registerExtension(ext); app.registerExtension(ext);

View File

@ -0,0 +1,149 @@
.comfy-group-manage {
background: var(--bg-color);
color: var(--fg-color);
padding: 0;
font-family: Arial, Helvetica, sans-serif;
border-color: black;
margin: 20vh auto;
max-height: 60vh;
}
.comfy-group-manage-outer {
max-height: 60vh;
min-width: 500px;
display: flex;
flex-direction: column;
}
.comfy-group-manage-outer > header {
display: flex;
align-items: center;
gap: 10px;
justify-content: space-between;
background: var(--comfy-menu-bg);
padding: 15px 20px;
}
.comfy-group-manage-outer > header select {
background: var(--comfy-input-bg);
border: 1px solid var(--border-color);
color: var(--input-text);
padding: 5px 10px;
border-radius: 5px;
}
.comfy-group-manage h2 {
margin: 0;
font-weight: normal;
}
.comfy-group-manage main {
display: flex;
overflow: hidden;
}
.comfy-group-manage .drag-handle {
font-weight: bold;
}
.comfy-group-manage-list {
border-right: 1px solid var(--comfy-menu-bg);
}
.comfy-group-manage-list ul {
margin: 40px 0 0;
padding: 0;
list-style: none;
}
.comfy-group-manage-list-items {
max-height: 70vh;
overflow-y: scroll;
overflow-x: hidden;
}
.comfy-group-manage-list li {
display: flex;
padding: 10px 20px 10px 10px;
cursor: pointer;
align-items: center;
gap: 5px;
}
.comfy-group-manage-list div {
display: flex;
flex-direction: column;
}
.comfy-group-manage-list li:not(.selected):hover div {
text-decoration: underline;
}
.comfy-group-manage-list li.selected {
background: var(--border-color);
}
.comfy-group-manage-list li span {
opacity: 0.7;
font-size: smaller;
}
.comfy-group-manage-node {
flex: auto;
background: var(--border-color);
display: flex;
flex-direction: column;
}
.comfy-group-manage-node > div {
overflow: auto;
}
.comfy-group-manage-node header {
display: flex;
background: var(--bg-color);
height: 40px;
}
.comfy-group-manage-node header a {
text-align: center;
flex: auto;
border-right: 1px solid var(--comfy-menu-bg);
border-bottom: 1px solid var(--comfy-menu-bg);
padding: 10px;
cursor: pointer;
font-size: 15px;
}
.comfy-group-manage-node header a:last-child {
border-right: none;
}
.comfy-group-manage-node header a:not(.active):hover {
text-decoration: underline;
}
.comfy-group-manage-node header a.active {
background: var(--border-color);
border-bottom: none;
}
.comfy-group-manage-node-page {
display: none;
overflow: auto;
}
.comfy-group-manage-node-page.active {
display: block;
}
.comfy-group-manage-node-page div {
padding: 10px;
display: flex;
align-items: center;
gap: 10px;
}
.comfy-group-manage-node-page input {
border: none;
color: var(--input-text);
background: var(--comfy-input-bg);
padding: 5px 10px;
}
.comfy-group-manage-node-page input[type="text"] {
flex: auto;
}
.comfy-group-manage-node-page label {
display: flex;
gap: 5px;
align-items: center;
}
.comfy-group-manage footer {
border-top: 1px solid var(--comfy-menu-bg);
padding: 10px;
display: flex;
gap: 10px;
}
.comfy-group-manage footer button {
font-size: 14px;
padding: 5px 10px;
border-radius: 0;
}
.comfy-group-manage footer button:first-child {
margin-right: auto;
}

View File

@ -0,0 +1,422 @@
import { $el, ComfyDialog } from "../../scripts/ui.js";
import { DraggableList } from "../../scripts/ui/draggableList.js";
import { addStylesheet } from "../../scripts/utils.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
addStylesheet(import.meta.url);
const ORDER = Symbol();
function merge(target, source) {
if (typeof target === "object" && typeof source === "object") {
for (const key in source) {
const sv = source[key];
if (typeof sv === "object") {
let tv = target[key];
if (!tv) tv = target[key] = {};
merge(tv, source[key]);
} else {
target[key] = sv;
}
}
}
return target;
}
export class ManageGroupDialog extends ComfyDialog {
/** @type { Record<"Inputs" | "Outputs" | "Widgets", {tab: HTMLAnchorElement, page: HTMLElement}> } */
tabs = {};
/** @type { number | null | undefined } */
selectedNodeIndex;
/** @type { keyof ManageGroupDialog["tabs"] } */
selectedTab = "Inputs";
/** @type { string | undefined } */
selectedGroup;
/** @type { Record<string, Record<string, Record<string, { name?: string | undefined, visible?: boolean | undefined }>>> } */
modifications = {};
get selectedNodeInnerIndex() {
return +this.nodeItems[this.selectedNodeIndex].dataset.nodeindex;
}
constructor(app) {
super();
this.app = app;
this.element = $el("dialog.comfy-group-manage", {
parent: document.body,
});
}
changeTab(tab) {
this.tabs[this.selectedTab].tab.classList.remove("active");
this.tabs[this.selectedTab].page.classList.remove("active");
this.tabs[tab].tab.classList.add("active");
this.tabs[tab].page.classList.add("active");
this.selectedTab = tab;
}
changeNode(index, force) {
if (!force && this.selectedNodeIndex === index) return;
if (this.selectedNodeIndex != null) {
this.nodeItems[this.selectedNodeIndex].classList.remove("selected");
}
this.nodeItems[index].classList.add("selected");
this.selectedNodeIndex = index;
if (!this.buildInputsPage() && this.selectedTab === "Inputs") {
this.changeTab("Widgets");
}
if (!this.buildWidgetsPage() && this.selectedTab === "Widgets") {
this.changeTab("Outputs");
}
if (!this.buildOutputsPage() && this.selectedTab === "Outputs") {
this.changeTab("Inputs");
}
this.changeTab(this.selectedTab);
}
getGroupData() {
this.groupNodeType = LiteGraph.registered_node_types["workflow/" + this.selectedGroup];
this.groupNodeDef = this.groupNodeType.nodeData;
this.groupData = GroupNodeHandler.getGroupData(this.groupNodeType);
}
changeGroup(group, reset = true) {
this.selectedGroup = group;
this.getGroupData();
const nodes = this.groupData.nodeData.nodes;
this.nodeItems = nodes.map((n, i) =>
$el(
"li.draggable-item",
{
dataset: {
nodeindex: n.index + "",
},
onclick: () => {
this.changeNode(i);
},
},
[
$el("span.drag-handle"),
$el(
"div",
{
textContent: n.title ?? n.type,
},
n.title
? $el("span", {
textContent: n.type,
})
: []
),
]
)
);
this.innerNodesList.replaceChildren(...this.nodeItems);
if (reset) {
this.selectedNodeIndex = null;
this.changeNode(0);
} else {
const items = this.draggable.getAllItems();
let index = items.findIndex(item => item.classList.contains("selected"));
if(index === -1) index = this.selectedNodeIndex;
this.changeNode(index, true);
}
const ordered = [...nodes];
this.draggable?.dispose();
this.draggable = new DraggableList(this.innerNodesList, "li");
this.draggable.addEventListener("dragend", ({ detail: { oldPosition, newPosition } }) => {
if (oldPosition === newPosition) return;
ordered.splice(newPosition, 0, ordered.splice(oldPosition, 1)[0]);
for (let i = 0; i < ordered.length; i++) {
this.storeModification({ nodeIndex: ordered[i].index, section: ORDER, prop: "order", value: i });
}
});
}
storeModification({ nodeIndex, section, prop, value }) {
const groupMod = (this.modifications[this.selectedGroup] ??= {});
const nodesMod = (groupMod.nodes ??= {});
const nodeMod = (nodesMod[nodeIndex ?? this.selectedNodeInnerIndex] ??= {});
const typeMod = (nodeMod[section] ??= {});
if (typeof value === "object") {
const objMod = (typeMod[prop] ??= {});
Object.assign(objMod, value);
} else {
typeMod[prop] = value;
}
}
getEditElement(section, prop, value, placeholder, checked, checkable = true) {
if (value === placeholder) value = "";
const mods = this.modifications[this.selectedGroup]?.nodes?.[this.selectedNodeInnerIndex]?.[section]?.[prop];
if (mods) {
if (mods.name != null) {
value = mods.name;
}
if (mods.visible != null) {
checked = mods.visible;
}
}
return $el("div", [
$el("input", {
value,
placeholder,
type: "text",
onchange: (e) => {
this.storeModification({ section, prop, value: { name: e.target.value } });
},
}),
$el("label", { textContent: "Visible" }, [
$el("input", {
type: "checkbox",
checked,
disabled: !checkable,
onchange: (e) => {
this.storeModification({ section, prop, value: { visible: !!e.target.checked } });
},
}),
]),
]);
}
buildWidgetsPage() {
const widgets = this.groupData.oldToNewWidgetMap[this.selectedNodeInnerIndex];
const items = Object.keys(widgets ?? {});
const type = app.graph.extra.groupNodes[this.selectedGroup];
const config = type.config?.[this.selectedNodeInnerIndex]?.input;
this.widgetsPage.replaceChildren(
...items.map((oldName) => {
return this.getEditElement("input", oldName, widgets[oldName], oldName, config?.[oldName]?.visible !== false);
})
);
return !!items.length;
}
buildInputsPage() {
const inputs = this.groupData.nodeInputs[this.selectedNodeInnerIndex];
const items = Object.keys(inputs ?? {});
const type = app.graph.extra.groupNodes[this.selectedGroup];
const config = type.config?.[this.selectedNodeInnerIndex]?.input;
this.inputsPage.replaceChildren(
...items
.map((oldName) => {
let value = inputs[oldName];
if (!value) {
return;
}
return this.getEditElement("input", oldName, value, oldName, config?.[oldName]?.visible !== false);
})
.filter(Boolean)
);
return !!items.length;
}
buildOutputsPage() {
const nodes = this.groupData.nodeData.nodes;
const innerNodeDef = this.groupData.getNodeDef(nodes[this.selectedNodeInnerIndex]);
const outputs = innerNodeDef?.output ?? [];
const groupOutputs = this.groupData.oldToNewOutputMap[this.selectedNodeInnerIndex];
const type = app.graph.extra.groupNodes[this.selectedGroup];
const config = type.config?.[this.selectedNodeInnerIndex]?.output;
const node = this.groupData.nodeData.nodes[this.selectedNodeInnerIndex];
const checkable = node.type !== "PrimitiveNode";
this.outputsPage.replaceChildren(
...outputs
.map((type, slot) => {
const groupOutputIndex = groupOutputs?.[slot];
const oldName = innerNodeDef.output_name?.[slot] ?? type;
let value = config?.[slot]?.name;
const visible = config?.[slot]?.visible || groupOutputIndex != null;
if (!value || value === oldName) {
value = "";
}
return this.getEditElement("output", slot, value, oldName, visible, checkable);
})
.filter(Boolean)
);
return !!outputs.length;
}
show(type) {
const groupNodes = Object.keys(app.graph.extra?.groupNodes ?? {}).sort((a, b) => a.localeCompare(b));
this.innerNodesList = $el("ul.comfy-group-manage-list-items");
this.widgetsPage = $el("section.comfy-group-manage-node-page");
this.inputsPage = $el("section.comfy-group-manage-node-page");
this.outputsPage = $el("section.comfy-group-manage-node-page");
const pages = $el("div", [this.widgetsPage, this.inputsPage, this.outputsPage]);
this.tabs = [
["Inputs", this.inputsPage],
["Widgets", this.widgetsPage],
["Outputs", this.outputsPage],
].reduce((p, [name, page]) => {
p[name] = {
tab: $el("a", {
onclick: () => {
this.changeTab(name);
},
textContent: name,
}),
page,
};
return p;
}, {});
const outer = $el("div.comfy-group-manage-outer", [
$el("header", [
$el("h2", "Group Nodes"),
$el(
"select",
{
onchange: (e) => {
this.changeGroup(e.target.value);
},
},
groupNodes.map((g) =>
$el("option", {
textContent: g,
selected: "workflow/" + g === type,
value: g,
})
)
),
]),
$el("main", [
$el("section.comfy-group-manage-list", this.innerNodesList),
$el("section.comfy-group-manage-node", [
$el(
"header",
Object.values(this.tabs).map((t) => t.tab)
),
pages,
]),
]),
$el("footer", [
$el(
"button.comfy-btn",
{
onclick: (e) => {
const node = app.graph._nodes.find((n) => n.type === "workflow/" + this.selectedGroup);
if (node) {
alert("This group node is in use in the current workflow, please first remove these.");
return;
}
if (confirm(`Are you sure you want to remove the node: "${this.selectedGroup}"`)) {
delete app.graph.extra.groupNodes[this.selectedGroup];
LiteGraph.unregisterNodeType("workflow/" + this.selectedGroup);
}
this.show();
},
},
"Delete Group Node"
),
$el(
"button.comfy-btn",
{
onclick: async () => {
let nodesByType;
let recreateNodes = [];
const types = {};
for (const g in this.modifications) {
const type = app.graph.extra.groupNodes[g];
let config = (type.config ??= {});
let nodeMods = this.modifications[g]?.nodes;
if (nodeMods) {
const keys = Object.keys(nodeMods);
if (nodeMods[keys[0]][ORDER]) {
// If any node is reordered, they will all need sequencing
const orderedNodes = [];
const orderedMods = {};
const orderedConfig = {};
for (const n of keys) {
const order = nodeMods[n][ORDER].order;
orderedNodes[order] = type.nodes[+n];
orderedMods[order] = nodeMods[n];
orderedNodes[order].index = order;
}
// Rewrite links
for (const l of type.links) {
if (l[0] != null) l[0] = type.nodes[l[0]].index;
if (l[2] != null) l[2] = type.nodes[l[2]].index;
}
// Rewrite externals
if (type.external) {
for (const ext of type.external) {
ext[0] = type.nodes[ext[0]];
}
}
// Rewrite modifications
for (const id of keys) {
if (config[id]) {
orderedConfig[type.nodes[id].index] = config[id];
}
delete config[id];
}
type.nodes = orderedNodes;
nodeMods = orderedMods;
type.config = config = orderedConfig;
}
merge(config, nodeMods);
}
types[g] = type;
if (!nodesByType) {
nodesByType = app.graph._nodes.reduce((p, n) => {
p[n.type] ??= [];
p[n.type].push(n);
return p;
}, {});
}
const nodes = nodesByType["workflow/" + g];
if (nodes) recreateNodes.push(...nodes);
}
await GroupNodeConfig.registerFromWorkflow(types, {});
for (const node of recreateNodes) {
node.recreate();
}
this.modifications = {};
this.app.graph.setDirtyCanvas(true, true);
this.changeGroup(this.selectedGroup, false);
},
},
"Save"
),
$el("button.comfy-btn", { onclick: () => this.element.close() }, "Close"),
]),
]);
this.element.replaceChildren(outer);
this.changeGroup(type ? groupNodes.find((g) => "workflow/" + g === type) : groupNodes[0]);
this.element.showModal();
this.element.addEventListener("close", () => {
this.draggable?.dispose();
});
}
}

View File

@ -1,4 +1,5 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { ComfyDialog, $el } from "../../scripts/ui.js"; import { ComfyDialog, $el } from "../../scripts/ui.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
@ -20,16 +21,20 @@ import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
// Open the manage dialog and Drag and drop elements using the "Name:" label as handle // Open the manage dialog and Drag and drop elements using the "Name:" label as handle
const id = "Comfy.NodeTemplates"; const id = "Comfy.NodeTemplates";
const file = "comfy.templates.json";
class ManageTemplates extends ComfyDialog { class ManageTemplates extends ComfyDialog {
constructor() { constructor() {
super(); super();
this.load().then((v) => {
this.templates = v;
});
this.element.classList.add("comfy-manage-templates"); this.element.classList.add("comfy-manage-templates");
this.templates = this.load();
this.draggedEl = null; this.draggedEl = null;
this.saveVisualCue = null; this.saveVisualCue = null;
this.emptyImg = new Image(); this.emptyImg = new Image();
this.emptyImg.src = ''; this.emptyImg.src = "";
this.importInput = $el("input", { this.importInput = $el("input", {
type: "file", type: "file",
@ -67,17 +72,50 @@ class ManageTemplates extends ComfyDialog {
return btns; return btns;
} }
load() { async load() {
const templates = localStorage.getItem(id); let templates = [];
if (templates) { if (app.storageLocation === "server") {
return JSON.parse(templates); if (app.isNewUserSession) {
// New user so migrate existing templates
const json = localStorage.getItem(id);
if (json) {
templates = JSON.parse(json);
}
await api.storeUserData(file, json, { stringify: false });
} else {
const res = await api.getUserData(file);
if (res.status === 200) {
try {
templates = await res.json();
} catch (error) {
}
} else if (res.status !== 404) {
console.error(res.status + " " + res.statusText);
}
}
} else { } else {
return []; const json = localStorage.getItem(id);
if (json) {
templates = JSON.parse(json);
}
} }
return templates ?? [];
} }
store() { async store() {
localStorage.setItem(id, JSON.stringify(this.templates)); if(app.storageLocation === "server") {
const templates = JSON.stringify(this.templates, undefined, 4);
localStorage.setItem(id, templates); // Backwards compatibility
try {
await api.storeUserData(file, templates, { stringify: false });
} catch (error) {
console.error(error);
alert(error.message);
}
} else {
localStorage.setItem(id, JSON.stringify(this.templates));
}
} }
async importAll() { async importAll() {
@ -85,14 +123,14 @@ class ManageTemplates extends ComfyDialog {
if (file.type === "application/json" || file.name.endsWith(".json")) { if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader(); const reader = new FileReader();
reader.onload = async () => { reader.onload = async () => {
var importFile = JSON.parse(reader.result); const importFile = JSON.parse(reader.result);
if (importFile && importFile?.templates) { if (importFile?.templates) {
for (const template of importFile.templates) { for (const template of importFile.templates) {
if (template?.name && template?.data) { if (template?.name && template?.data) {
this.templates.push(template); this.templates.push(template);
} }
} }
this.store(); await this.store();
} }
}; };
await reader.readAsText(file); await reader.readAsText(file);
@ -159,7 +197,7 @@ class ManageTemplates extends ComfyDialog {
e.currentTarget.style.border = "1px dashed transparent"; e.currentTarget.style.border = "1px dashed transparent";
e.currentTarget.removeAttribute("draggable"); e.currentTarget.removeAttribute("draggable");
// rearrange the elements in the localStorage // rearrange the elements
this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
var prev_i = el.dataset.id; var prev_i = el.dataset.id;

View File

@ -1,4 +1,5 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js"
const MAX_HISTORY = 50; const MAX_HISTORY = 50;
@ -15,6 +16,7 @@ function checkState() {
} }
activeState = clone(currentState); activeState = clone(currentState);
redo.length = 0; redo.length = 0;
api.dispatchEvent(new CustomEvent("graphChanged", { detail: activeState }));
} }
} }
@ -92,7 +94,7 @@ const undoRedo = async (e) => {
}; };
const bindInput = (activeEl) => { const bindInput = (activeEl) => {
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") { if (activeEl && activeEl.tagName !== "CANVAS" && activeEl.tagName !== "BODY") {
for (const evt of ["change", "input", "blur"]) { for (const evt of ["change", "input", "blur"]) {
if (`on${evt}` in activeEl) { if (`on${evt}` in activeEl) {
const listener = () => { const listener = () => {
@ -106,16 +108,24 @@ const bindInput = (activeEl) => {
} }
}; };
let keyIgnored = false;
window.addEventListener( window.addEventListener(
"keydown", "keydown",
(e) => { (e) => {
requestAnimationFrame(async () => { requestAnimationFrame(async () => {
const activeEl = document.activeElement; let activeEl;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { // If we are auto queue in change mode then we do want to trigger on inputs
// Ignore events on inputs, they have their native history if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") {
return; activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
} }
keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta";
if (keyIgnored) return;
// Check if this is a ctrl+z ctrl+y // Check if this is a ctrl+z ctrl+y
if (await undoRedo(e)) return; if (await undoRedo(e)) return;
@ -127,11 +137,23 @@ window.addEventListener(
true true
); );
window.addEventListener("keyup", (e) => {
if (keyIgnored) {
keyIgnored = false;
checkState();
}
});
// Handle clicking DOM elements (e.g. widgets) // Handle clicking DOM elements (e.g. widgets)
window.addEventListener("mouseup", () => { window.addEventListener("mouseup", () => {
checkState(); checkState();
}); });
// Handle prompt queue event for dynamic widget changes
api.addEventListener("promptQueued", () => {
checkState();
});
// Handle litegraph clicks // Handle litegraph clicks
const processMouseUp = LGraphCanvas.prototype.processMouseUp; const processMouseUp = LGraphCanvas.prototype.processMouseUp;
LGraphCanvas.prototype.processMouseUp = function (e) { LGraphCanvas.prototype.processMouseUp = function (e) {
@ -145,3 +167,11 @@ LGraphCanvas.prototype.processMouseDown = function (e) {
checkState(); checkState();
return v; return v;
}; };
// Handle litegraph context menu for COMBO widgets
const close = LiteGraph.ContextMenu.prototype.close;
LiteGraph.ContextMenu.prototype.close = function(e) {
const v = close.apply(this, arguments);
checkState();
return v;
}

View File

@ -16,5 +16,33 @@
window.graph = app.graph; window.graph = app.graph;
</script> </script>
</head> </head>
<body class="litegraph"></body> <body class="litegraph">
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
<main class="comfy-user-selection-inner">
<h1>ComfyUI</h1>
<form>
<section>
<label>New user:
<input placeholder="Enter a username" />
</label>
</section>
<div class="comfy-user-existing">
<span class="or-separator">OR</span>
<section>
<label>
Existing user:
<select>
<option hidden disabled selected value> Select a user </option>
</select>
</label>
</section>
</div>
<footer>
<span class="comfy-user-error">&nbsp;</span>
<button class="comfy-btn comfy-user-button-next">Next</button>
</footer>
</form>
</main>
</div>
</body>
</html> </html>

View File

@ -3,7 +3,8 @@
"baseUrl": ".", "baseUrl": ".",
"paths": { "paths": {
"/*": ["./*"] "/*": ["./*"]
} },
"lib": ["DOM", "ES2022"]
}, },
"include": ["."] "include": ["."]
} }

View File

@ -11496,7 +11496,7 @@ LGraphNode.prototype.executeAction = function(action)
} }
timeout_close = setTimeout(function() { timeout_close = setTimeout(function() {
dialog.close(); dialog.close();
}, 500); }, typeof options.hide_on_mouse_leave === "number" ? options.hide_on_mouse_leave : 500);
}); });
// if filtering, check focus changed to comboboxes and prevent closing // if filtering, check focus changed to comboboxes and prevent closing
if (options.do_type_filter){ if (options.do_type_filter){

View File

@ -12,6 +12,13 @@ class ComfyApi extends EventTarget {
} }
fetchApi(route, options) { fetchApi(route, options) {
if (!options) {
options = {};
}
if (!options.headers) {
options.headers = {};
}
options.headers["Comfy-User"] = this.user;
return fetch(this.apiURL(route), options); return fetch(this.apiURL(route), options);
} }
@ -315,6 +322,99 @@ class ComfyApi extends EventTarget {
async interrupt() { async interrupt() {
await this.#postItem("interrupt", null); await this.#postItem("interrupt", null);
} }
/**
* Gets user configuration data and where data should be stored
* @returns { Promise<{ storage: "server" | "browser", users?: Promise<string, unknown>, migrated?: boolean }> }
*/
async getUserConfig() {
return (await this.fetchApi("/users")).json();
}
/**
* Creates a new user
* @param { string } username
* @returns The fetch response
*/
createUser(username) {
return this.fetchApi("/users", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ username }),
});
}
/**
* Gets all setting values for the current user
* @returns { Promise<string, unknown> } A dictionary of id -> value
*/
async getSettings() {
return (await this.fetchApi("/settings")).json();
}
/**
* Gets a setting for the current user
* @param { string } id The id of the setting to fetch
* @returns { Promise<unknown> } The setting value
*/
async getSetting(id) {
return (await this.fetchApi(`/settings/${encodeURIComponent(id)}`)).json();
}
/**
* Stores a dictionary of settings for the current user
* @param { Record<string, unknown> } settings Dictionary of setting id -> value to save
* @returns { Promise<void> }
*/
async storeSettings(settings) {
return this.fetchApi(`/settings`, {
method: "POST",
body: JSON.stringify(settings)
});
}
/**
* Stores a setting for the current user
* @param { string } id The id of the setting to update
* @param { unknown } value The value of the setting
* @returns { Promise<void> }
*/
async storeSetting(id, value) {
return this.fetchApi(`/settings/${encodeURIComponent(id)}`, {
method: "POST",
body: JSON.stringify(value)
});
}
/**
* Gets a user data file for the current user
* @param { string } file The name of the userdata file to load
* @param { RequestInit } [options]
* @returns { Promise<unknown> } The fetch response object
*/
async getUserData(file, options) {
return this.fetchApi(`/userdata/${encodeURIComponent(file)}`, options);
}
/**
* Stores a user data file for the current user
* @param { string } file The name of the userdata file to save
* @param { unknown } data The data to save to the file
* @param { RequestInit & { stringify?: boolean, throwOnError?: boolean } } [options]
* @returns { Promise<void> }
*/
async storeUserData(file, data, options = { stringify: true, throwOnError: true }) {
const resp = await this.fetchApi(`/userdata/${encodeURIComponent(file)}`, {
method: "POST",
body: options?.stringify ? JSON.stringify(data) : data,
...options,
});
if (resp.status !== 200) {
throw new Error(`Error storing user data file '${file}': ${resp.status} ${(await resp).statusText}`);
}
}
} }
export const api = new ComfyApi(); export const api = new ComfyApi();

View File

@ -1,5 +1,5 @@
import { ComfyLogging } from "./logging.js"; import { ComfyLogging } from "./logging.js";
import { ComfyWidgets } from "./widgets.js"; import { ComfyWidgets, initWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js"; import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js"; import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js"; import { defaultGraph } from "./defaultGraph.js";
@ -269,6 +269,71 @@ export class ComfyApp {
* @param {*} node The node to add the menu handler * @param {*} node The node to add the menu handler
*/ */
#addNodeContextMenuHandler(node) { #addNodeContextMenuHandler(node) {
function getCopyImageOption(img) {
if (typeof window.ClipboardItem === "undefined") return [];
return [
{
content: "Copy Image",
callback: async () => {
const url = new URL(img.src);
url.searchParams.delete("preview");
const writeImage = async (blob) => {
await navigator.clipboard.write([
new ClipboardItem({
[blob.type]: blob,
}),
]);
};
try {
const data = await fetch(url);
const blob = await data.blob();
try {
await writeImage(blob);
} catch (error) {
// Chrome seems to only support PNG on write, convert and try again
if (blob.type !== "image/png") {
const canvas = $el("canvas", {
width: img.naturalWidth,
height: img.naturalHeight,
});
const ctx = canvas.getContext("2d");
let image;
if (typeof window.createImageBitmap === "undefined") {
image = new Image();
const p = new Promise((resolve, reject) => {
image.onload = resolve;
image.onerror = reject;
}).finally(() => {
URL.revokeObjectURL(image.src);
});
image.src = URL.createObjectURL(blob);
await p;
} else {
image = await createImageBitmap(blob);
}
try {
ctx.drawImage(image, 0, 0);
canvas.toBlob(writeImage, "image/png");
} finally {
if (typeof image.close === "function") {
image.close();
}
}
return;
}
throw error;
}
} catch (error) {
alert("Error copying image: " + (error.message ?? error));
}
},
},
];
}
node.prototype.getExtraMenuOptions = function (_, options) { node.prototype.getExtraMenuOptions = function (_, options) {
if (this.imgs) { if (this.imgs) {
// If this node has images then we add an open in new tab item // If this node has images then we add an open in new tab item
@ -286,16 +351,17 @@ export class ComfyApp {
content: "Open Image", content: "Open Image",
callback: () => { callback: () => {
let url = new URL(img.src); let url = new URL(img.src);
url.searchParams.delete('preview'); url.searchParams.delete("preview");
window.open(url, "_blank") window.open(url, "_blank");
}, },
}, },
...getCopyImageOption(img),
{ {
content: "Save Image", content: "Save Image",
callback: () => { callback: () => {
const a = document.createElement("a"); const a = document.createElement("a");
let url = new URL(img.src); let url = new URL(img.src);
url.searchParams.delete('preview'); url.searchParams.delete("preview");
a.href = url; a.href = url;
a.setAttribute("download", new URLSearchParams(url.search).get("filename")); a.setAttribute("download", new URLSearchParams(url.search).get("filename"));
document.body.append(a); document.body.append(a);
@ -308,33 +374,41 @@ export class ComfyApp {
} }
options.push({ options.push({
content: "Bypass", content: "Bypass",
callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); } callback: (obj) => {
}); if (this.mode === 4) this.mode = 0;
else this.mode = 4;
this.graph.change();
},
});
// prevent conflict of clipspace content // prevent conflict of clipspace content
if(!ComfyApp.clipspace_return_node) { if (!ComfyApp.clipspace_return_node) {
options.push({ options.push({
content: "Copy (Clipspace)", content: "Copy (Clipspace)",
callback: (obj) => { ComfyApp.copyToClipspace(this); } callback: (obj) => {
}); ComfyApp.copyToClipspace(this);
},
});
if(ComfyApp.clipspace != null) { if (ComfyApp.clipspace != null) {
options.push({ options.push({
content: "Paste (Clipspace)", content: "Paste (Clipspace)",
callback: () => { ComfyApp.pasteFromClipspace(this); } callback: () => {
}); ComfyApp.pasteFromClipspace(this);
},
});
} }
if(ComfyApp.isImageNode(this)) { if (ComfyApp.isImageNode(this)) {
options.push({ options.push({
content: "Open in MaskEditor", content: "Open in MaskEditor",
callback: (obj) => { callback: (obj) => {
ComfyApp.copyToClipspace(this); ComfyApp.copyToClipspace(this);
ComfyApp.clipspace_return_node = this; ComfyApp.clipspace_return_node = this;
ComfyApp.open_maskeditor(); ComfyApp.open_maskeditor();
} },
}); });
} }
} }
}; };
@ -1291,10 +1365,92 @@ export class ComfyApp {
await Promise.all(extensionPromises); await Promise.all(extensionPromises);
} }
async #migrateSettings() {
this.isNewUserSession = true;
// Store all current settings
const settings = Object.keys(this.ui.settings).reduce((p, n) => {
const v = localStorage[`Comfy.Settings.${n}`];
if (v) {
try {
p[n] = JSON.parse(v);
} catch (error) {}
}
return p;
}, {});
await api.storeSettings(settings);
}
async #setUser() {
const userConfig = await api.getUserConfig();
this.storageLocation = userConfig.storage;
if (typeof userConfig.migrated == "boolean") {
// Single user mode migrated true/false for if the default user is created
if (!userConfig.migrated && this.storageLocation === "server") {
// Default user not created yet
await this.#migrateSettings();
}
return;
}
this.multiUserServer = true;
let user = localStorage["Comfy.userId"];
const users = userConfig.users ?? {};
if (!user || !users[user]) {
// This will rarely be hit so move the loading to on demand
const { UserSelectionScreen } = await import("./ui/userSelection.js");
this.ui.menuContainer.style.display = "none";
const { userId, username, created } = await new UserSelectionScreen().show(users, user);
this.ui.menuContainer.style.display = "";
user = userId;
localStorage["Comfy.userName"] = username;
localStorage["Comfy.userId"] = user;
if (created) {
api.user = user;
await this.#migrateSettings();
}
}
api.user = user;
this.ui.settings.addSetting({
id: "Comfy.SwitchUser",
name: "Switch User",
type: (name) => {
let currentUser = localStorage["Comfy.userName"];
if (currentUser) {
currentUser = ` (${currentUser})`;
}
return $el("tr", [
$el("td", [
$el("label", {
textContent: name,
}),
]),
$el("td", [
$el("button", {
textContent: name + (currentUser ?? ""),
onclick: () => {
delete localStorage["Comfy.userId"];
delete localStorage["Comfy.userName"];
window.location.reload();
},
}),
]),
]);
},
});
}
/** /**
* Set up the app on the page * Set up the app on the page
*/ */
async setup() { async setup() {
await this.#setUser();
await this.ui.settings.load();
await this.#loadExtensions(); await this.#loadExtensions();
// Create and mount the LiteGraph in the DOM // Create and mount the LiteGraph in the DOM
@ -1338,6 +1494,7 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("init"); await this.#invokeExtensionsAsync("init");
await this.registerNodes(); await this.registerNodes();
initWidgets(this);
// Load previous workflow // Load previous workflow
let restored = false; let restored = false;
@ -1692,6 +1849,14 @@ export class ComfyApp {
*/ */
async graphToPrompt() { async graphToPrompt() {
for (const outerNode of this.graph.computeExecutionOrder(false)) { for (const outerNode of this.graph.computeExecutionOrder(false)) {
if (outerNode.widgets) {
for (const widget of outerNode.widgets) {
// Allow widgets to run callbacks before a prompt has been queued
// e.g. random seed before every gen
widget.beforeQueued?.();
}
}
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode]; const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
for (const node of innerNodes) { for (const node of innerNodes) {
if (node.isVirtualNode) { if (node.isVirtualNode) {
@ -1903,6 +2068,7 @@ export class ComfyApp {
} finally { } finally {
this.#processingQueue = false; this.#processingQueue = false;
} }
api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } }));
} }
/** /**
@ -2046,6 +2212,8 @@ export class ComfyApp {
} }
} }
} }
await this.#invokeExtensionsAsync("refreshComboInNodes", defs);
} }
/** /**

View File

@ -239,7 +239,8 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
node.flags?.collapsed || node.flags?.collapsed ||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
widget.computedHeight <= 0 || widget.computedHeight <= 0 ||
widget.type === "converted-widget"; widget.type === "converted-widget"||
widget.type === "hidden";
element.hidden = hidden; element.hidden = hidden;
element.style.display = hidden ? "none" : null; element.style.display = hidden ? "none" : null;
if (hidden) { if (hidden) {

View File

@ -269,6 +269,9 @@ export class ComfyLogging {
id: settingId, id: settingId,
name: settingId, name: settingId,
defaultValue: true, defaultValue: true,
onChange: (value) => {
this.enabled = value;
},
type: (name, setter, value) => { type: (name, setter, value) => {
return $el("tr", [ return $el("tr", [
$el("td", [ $el("td", [
@ -283,7 +286,7 @@ export class ComfyLogging {
type: "checkbox", type: "checkbox",
checked: value, checked: value,
onchange: (event) => { onchange: (event) => {
setter((this.enabled = event.target.checked)); setter(event.target.checked);
}, },
}), }),
$el("button", { $el("button", {

View File

@ -1,5 +1,23 @@
import {api} from "./api.js"; import { api } from "./api.js";
import { ComfyDialog as _ComfyDialog } from "./ui/dialog.js";
import { toggleSwitch } from "./ui/toggleSwitch.js";
import { ComfySettingsDialog } from "./ui/settings.js";
export const ComfyDialog = _ComfyDialog;
/**
*
* @param { string } tag HTML Element Tag and optional classes e.g. div.class1.class2
* @param { string | Element | Element[] | {
* parent?: Element,
* $?: (el: Element) => void,
* dataset?: DOMStringMap,
* style?: CSSStyleDeclaration,
* for?: string
* } | undefined } propsOrChildren
* @param { Element[] | undefined } [children]
* @returns
*/
export function $el(tag, propsOrChildren, children) { export function $el(tag, propsOrChildren, children) {
const split = tag.split("."); const split = tag.split(".");
const element = document.createElement(split.shift()); const element = document.createElement(split.shift());
@ -8,6 +26,11 @@ export function $el(tag, propsOrChildren, children) {
} }
if (propsOrChildren) { if (propsOrChildren) {
if (typeof propsOrChildren === "string") {
propsOrChildren = { textContent: propsOrChildren };
} else if (propsOrChildren instanceof Element) {
propsOrChildren = [propsOrChildren];
}
if (Array.isArray(propsOrChildren)) { if (Array.isArray(propsOrChildren)) {
element.append(...propsOrChildren); element.append(...propsOrChildren);
} else { } else {
@ -31,7 +54,7 @@ export function $el(tag, propsOrChildren, children) {
Object.assign(element, propsOrChildren); Object.assign(element, propsOrChildren);
if (children) { if (children) {
element.append(...children); element.append(...(children instanceof Array ? children : [children]));
} }
if (parent) { if (parent) {
@ -167,267 +190,6 @@ function dragElement(dragEl, settings) {
} }
} }
export class ComfyDialog {
constructor() {
this.element = $el("div.comfy-modal", {parent: document.body}, [
$el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]),
]);
}
createButtons() {
return [
$el("button", {
type: "button",
textContent: "Close",
onclick: () => this.close(),
}),
];
}
close() {
this.element.style.display = "none";
}
show(html) {
if (typeof html === "string") {
this.textElement.innerHTML = html;
} else {
this.textElement.replaceChildren(html);
}
this.element.style.display = "flex";
}
}
class ComfySettingsDialog extends ComfyDialog {
constructor() {
super();
this.element = $el("dialog", {
id: "comfy-settings-dialog",
parent: document.body,
}, [
$el("table.comfy-modal-content.comfy-table", [
$el("caption", {textContent: "Settings"}),
$el("tbody", {$: (tbody) => (this.textElement = tbody)}),
$el("button", {
type: "button",
textContent: "Close",
style: {
cursor: "pointer",
},
onclick: () => {
this.element.close();
},
}),
]),
]);
this.settings = [];
}
getSettingValue(id, defaultValue) {
const settingId = "Comfy.Settings." + id;
const v = localStorage[settingId];
return v == null ? defaultValue : JSON.parse(v);
}
setSettingValue(id, value) {
const settingId = "Comfy.Settings." + id;
localStorage[settingId] = JSON.stringify(value);
}
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
if (!id) {
throw new Error("Settings must have an ID");
}
if (this.settings.find((s) => s.id === id)) {
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
}
const settingId = `Comfy.Settings.${id}`;
const v = localStorage[settingId];
let value = v == null ? defaultValue : JSON.parse(v);
// Trigger initial setting of value
if (onChange) {
onChange(value, undefined);
}
this.settings.push({
render: () => {
const setter = (v) => {
if (onChange) {
onChange(v, value);
}
localStorage[settingId] = JSON.stringify(v);
value = v;
};
value = this.getSettingValue(id, defaultValue);
let element;
const htmlID = id.replaceAll(".", "-");
const labelCell = $el("td", [
$el("label", {
for: htmlID,
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
textContent: name,
})
]);
if (typeof type === "function") {
element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
id: htmlID,
type: "checkbox",
checked: value,
onchange: (event) => {
const isChecked = event.target.checked;
if (onChange !== undefined) {
onChange(isChecked)
}
this.setSettingValue(id, isChecked);
},
}),
]),
])
break;
case "number":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
type,
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
case "slider":
element = $el("tr", [
labelCell,
$el("td", [
$el("div", {
style: {
display: "grid",
gridAutoFlow: "column",
},
}, [
$el("input", {
...attrs,
value,
type: "range",
oninput: (e) => {
setter(e.target.value);
e.target.nextElementSibling.value = e.target.value;
},
}),
$el("input", {
...attrs,
value,
id: htmlID,
type: "number",
style: {maxWidth: "4rem"},
oninput: (e) => {
setter(e.target.value);
e.target.previousElementSibling.value = e.target.value;
},
}),
]),
]),
]);
break;
case "combo":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"select",
{
oninput: (e) => {
setter(e.target.value);
},
},
(typeof options === "function" ? options(value) : options || []).map((opt) => {
if (typeof opt === "string") {
opt = { text: opt };
}
const v = opt.value ?? opt.text;
return $el("option", {
value: v,
textContent: opt.text,
selected: value + "" === v + "",
});
})
),
]),
]);
break;
case "text":
default:
if (type !== "text") {
console.warn(`Unsupported setting type '${type}, defaulting to text`);
}
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
}
}
if (tooltip) {
element.title = tooltip;
}
return element;
},
});
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
}
show() {
this.textElement.replaceChildren(
$el("tr", {
style: {display: "none"},
}, [
$el("th"),
$el("th", {style: {width: "33%"}})
]),
...this.settings.map((s) => s.render()),
)
this.element.showModal();
}
}
class ComfyList { class ComfyList {
#type; #type;
#text; #text;
@ -526,7 +288,7 @@ export class ComfyUI {
constructor(app) { constructor(app) {
this.app = app; this.app = app;
this.dialog = new ComfyDialog(); this.dialog = new ComfyDialog();
this.settings = new ComfySettingsDialog(); this.settings = new ComfySettingsDialog(app);
this.batchCount = 1; this.batchCount = 1;
this.lastQueueSize = 0; this.lastQueueSize = 0;
@ -607,6 +369,31 @@ export class ComfyUI {
}, },
}); });
const autoQueueModeEl = toggleSwitch(
"autoQueueMode",
[
{ text: "instant", tooltip: "A new prompt will be queued as soon as the queue reaches 0" },
{ text: "change", tooltip: "A new prompt will be queued when the queue is at 0 and the graph is/has changed" },
],
{
onChange: (value) => {
this.autoQueueMode = value.item.value;
},
}
);
autoQueueModeEl.style.display = "none";
api.addEventListener("graphChanged", () => {
if (this.autoQueueMode === "change" && this.autoQueueEnabled === true) {
if (this.lastQueueSize === 0) {
this.graphHasChanged = false;
app.queuePrompt(0, this.batchCount);
} else {
this.graphHasChanged = true;
}
}
});
this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [ this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [
$el("div.drag-handle", { $el("div.drag-handle", {
style: { style: {
@ -633,6 +420,7 @@ export class ComfyUI {
document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none"; document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none";
this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1; this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1;
document.getElementById("autoQueueCheckbox").checked = false; document.getElementById("autoQueueCheckbox").checked = false;
this.autoQueueEnabled = false;
}, },
}), }),
]), ]),
@ -664,20 +452,22 @@ export class ComfyUI {
}, },
}), }),
]), ]),
$el("div",[ $el("div",[
$el("label",{ $el("label",{
for:"autoQueueCheckbox", for:"autoQueueCheckbox",
innerHTML: "Auto Queue" innerHTML: "Auto Queue"
// textContent: "Auto Queue"
}), }),
$el("input", { $el("input", {
id: "autoQueueCheckbox", id: "autoQueueCheckbox",
type: "checkbox", type: "checkbox",
checked: false, checked: false,
title: "Automatically queue prompt when the queue size hits 0", title: "Automatically queue prompt when the queue size hits 0",
onchange: (e) => {
this.autoQueueEnabled = e.target.checked;
autoQueueModeEl.style.display = this.autoQueueEnabled ? "" : "none";
}
}), }),
autoQueueModeEl
]) ])
]), ]),
$el("div.comfy-menu-btns", [ $el("div.comfy-menu-btns", [
@ -811,10 +601,13 @@ export class ComfyUI {
if ( if (
this.lastQueueSize != 0 && this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 && status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked && this.autoQueueEnabled &&
! app.lastExecutionError (this.autoQueueMode === "instant" || this.graphHasChanged) &&
!app.lastExecutionError
) { ) {
app.queuePrompt(0, this.batchCount); app.queuePrompt(0, this.batchCount);
status.exec_info.queue_remaining += this.batchCount;
this.graphHasChanged = false;
} }
this.lastQueueSize = status.exec_info.queue_remaining; this.lastQueueSize = status.exec_info.queue_remaining;
} }

View File

@ -0,0 +1,32 @@
import { $el } from "../ui.js";
export class ComfyDialog {
constructor() {
this.element = $el("div.comfy-modal", { parent: document.body }, [
$el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]),
]);
}
createButtons() {
return [
$el("button", {
type: "button",
textContent: "Close",
onclick: () => this.close(),
}),
];
}
close() {
this.element.style.display = "none";
}
show(html) {
if (typeof html === "string") {
this.textElement.innerHTML = html;
} else {
this.textElement.replaceChildren(html);
}
this.element.style.display = "flex";
}
}

View File

@ -0,0 +1,287 @@
// @ts-check
/*
Original implementation:
https://github.com/TahaSh/drag-to-reorder
MIT License
Copyright (c) 2023 Taha Shashtari
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
import { $el } from "../ui.js";
$el("style", {
parent: document.head,
textContent: `
.draggable-item {
position: relative;
will-change: transform;
user-select: none;
}
.draggable-item.is-idle {
transition: 0.25s ease transform;
}
.draggable-item.is-draggable {
z-index: 10;
}
`
});
export class DraggableList extends EventTarget {
listContainer;
draggableItem;
pointerStartX;
pointerStartY;
scrollYMax;
itemsGap = 0;
items = [];
itemSelector;
handleClass = "drag-handle";
off = [];
offDrag = [];
constructor(element, itemSelector) {
super();
this.listContainer = element;
this.itemSelector = itemSelector;
if (!this.listContainer) return;
this.off.push(this.on(this.listContainer, "mousedown", this.dragStart));
this.off.push(this.on(this.listContainer, "touchstart", this.dragStart));
this.off.push(this.on(document, "mouseup", this.dragEnd));
this.off.push(this.on(document, "touchend", this.dragEnd));
}
getAllItems() {
if (!this.items?.length) {
this.items = Array.from(this.listContainer.querySelectorAll(this.itemSelector));
this.items.forEach((element) => {
element.classList.add("is-idle");
});
}
return this.items;
}
getIdleItems() {
return this.getAllItems().filter((item) => item.classList.contains("is-idle"));
}
isItemAbove(item) {
return item.hasAttribute("data-is-above");
}
isItemToggled(item) {
return item.hasAttribute("data-is-toggled");
}
on(source, event, listener, options) {
listener = listener.bind(this);
source.addEventListener(event, listener, options);
return () => source.removeEventListener(event, listener);
}
dragStart(e) {
if (e.target.classList.contains(this.handleClass)) {
this.draggableItem = e.target.closest(this.itemSelector);
}
if (!this.draggableItem) return;
this.pointerStartX = e.clientX || e.touches[0].clientX;
this.pointerStartY = e.clientY || e.touches[0].clientY;
this.scrollYMax = this.listContainer.scrollHeight - this.listContainer.clientHeight;
this.setItemsGap();
this.initDraggableItem();
this.initItemsState();
this.offDrag.push(this.on(document, "mousemove", this.drag));
this.offDrag.push(this.on(document, "touchmove", this.drag, { passive: false }));
this.dispatchEvent(
new CustomEvent("dragstart", {
detail: { element: this.draggableItem, position: this.getAllItems().indexOf(this.draggableItem) },
})
);
}
setItemsGap() {
if (this.getIdleItems().length <= 1) {
this.itemsGap = 0;
return;
}
const item1 = this.getIdleItems()[0];
const item2 = this.getIdleItems()[1];
const item1Rect = item1.getBoundingClientRect();
const item2Rect = item2.getBoundingClientRect();
this.itemsGap = Math.abs(item1Rect.bottom - item2Rect.top);
}
initItemsState() {
this.getIdleItems().forEach((item, i) => {
if (this.getAllItems().indexOf(this.draggableItem) > i) {
item.dataset.isAbove = "";
}
});
}
initDraggableItem() {
this.draggableItem.classList.remove("is-idle");
this.draggableItem.classList.add("is-draggable");
}
drag(e) {
if (!this.draggableItem) return;
e.preventDefault();
const clientX = e.clientX || e.touches[0].clientX;
const clientY = e.clientY || e.touches[0].clientY;
const listRect = this.listContainer.getBoundingClientRect();
if (clientY > listRect.bottom) {
if (this.listContainer.scrollTop < this.scrollYMax) {
this.listContainer.scrollBy(0, 10);
this.pointerStartY -= 10;
}
} else if (clientY < listRect.top && this.listContainer.scrollTop > 0) {
this.pointerStartY += 10;
this.listContainer.scrollBy(0, -10);
}
const pointerOffsetX = clientX - this.pointerStartX;
const pointerOffsetY = clientY - this.pointerStartY;
this.updateIdleItemsStateAndPosition();
this.draggableItem.style.transform = `translate(${pointerOffsetX}px, ${pointerOffsetY}px)`;
}
updateIdleItemsStateAndPosition() {
const draggableItemRect = this.draggableItem.getBoundingClientRect();
const draggableItemY = draggableItemRect.top + draggableItemRect.height / 2;
// Update state
this.getIdleItems().forEach((item) => {
const itemRect = item.getBoundingClientRect();
const itemY = itemRect.top + itemRect.height / 2;
if (this.isItemAbove(item)) {
if (draggableItemY <= itemY) {
item.dataset.isToggled = "";
} else {
delete item.dataset.isToggled;
}
} else {
if (draggableItemY >= itemY) {
item.dataset.isToggled = "";
} else {
delete item.dataset.isToggled;
}
}
});
// Update position
this.getIdleItems().forEach((item) => {
if (this.isItemToggled(item)) {
const direction = this.isItemAbove(item) ? 1 : -1;
item.style.transform = `translateY(${direction * (draggableItemRect.height + this.itemsGap)}px)`;
} else {
item.style.transform = "";
}
});
}
dragEnd() {
if (!this.draggableItem) return;
this.applyNewItemsOrder();
this.cleanup();
}
applyNewItemsOrder() {
const reorderedItems = [];
let oldPosition = -1;
this.getAllItems().forEach((item, index) => {
if (item === this.draggableItem) {
oldPosition = index;
return;
}
if (!this.isItemToggled(item)) {
reorderedItems[index] = item;
return;
}
const newIndex = this.isItemAbove(item) ? index + 1 : index - 1;
reorderedItems[newIndex] = item;
});
for (let index = 0; index < this.getAllItems().length; index++) {
const item = reorderedItems[index];
if (typeof item === "undefined") {
reorderedItems[index] = this.draggableItem;
}
}
reorderedItems.forEach((item) => {
this.listContainer.appendChild(item);
});
this.items = reorderedItems;
this.dispatchEvent(
new CustomEvent("dragend", {
detail: { element: this.draggableItem, oldPosition, newPosition: reorderedItems.indexOf(this.draggableItem) },
})
);
}
cleanup() {
this.itemsGap = 0;
this.items = [];
this.unsetDraggableItem();
this.unsetItemState();
this.offDrag.forEach((f) => f());
this.offDrag = [];
}
unsetDraggableItem() {
this.draggableItem.style = null;
this.draggableItem.classList.remove("is-draggable");
this.draggableItem.classList.add("is-idle");
this.draggableItem = null;
}
unsetItemState() {
this.getIdleItems().forEach((item, i) => {
delete item.dataset.isAbove;
delete item.dataset.isToggled;
item.style.transform = "";
});
}
dispose() {
this.off.forEach((f) => f());
}
}

View File

@ -0,0 +1,307 @@
import { $el } from "../ui.js";
import { api } from "../api.js";
import { ComfyDialog } from "./dialog.js";
export class ComfySettingsDialog extends ComfyDialog {
constructor(app) {
super();
this.app = app;
this.settingsValues = {};
this.settingsLookup = {};
this.element = $el(
"dialog",
{
id: "comfy-settings-dialog",
parent: document.body,
},
[
$el("table.comfy-modal-content.comfy-table", [
$el("caption", { textContent: "Settings" }),
$el("tbody", { $: (tbody) => (this.textElement = tbody) }),
$el("button", {
type: "button",
textContent: "Close",
style: {
cursor: "pointer",
},
onclick: () => {
this.element.close();
},
}),
]),
]
);
}
get settings() {
return Object.values(this.settingsLookup);
}
async load() {
if (this.app.storageLocation === "browser") {
this.settingsValues = localStorage;
} else {
this.settingsValues = await api.getSettings();
}
// Trigger onChange for any settings added before load
for (const id in this.settingsLookup) {
this.settingsLookup[id].onChange?.(this.settingsValues[this.getId(id)]);
}
}
getId(id) {
if (this.app.storageLocation === "browser") {
id = "Comfy.Settings." + id;
}
return id;
}
getSettingValue(id, defaultValue) {
let value = this.settingsValues[this.getId(id)];
if(value != null) {
if(this.app.storageLocation === "browser") {
try {
value = JSON.parse(value);
} catch (error) {
}
}
}
return value ?? defaultValue;
}
async setSettingValueAsync(id, value) {
const json = JSON.stringify(value);
localStorage["Comfy.Settings." + id] = json; // backwards compatibility for extensions keep setting in storage
let oldValue = this.getSettingValue(id, undefined);
this.settingsValues[this.getId(id)] = value;
if (id in this.settingsLookup) {
this.settingsLookup[id].onChange?.(value, oldValue);
}
await api.storeSetting(id, value);
}
setSettingValue(id, value) {
this.setSettingValueAsync(id, value).catch((err) => {
alert(`Error saving setting '${id}'`);
console.error(err);
});
}
addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined }) {
if (!id) {
throw new Error("Settings must have an ID");
}
if (id in this.settingsLookup) {
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
}
let skipOnChange = false;
let value = this.getSettingValue(id);
if (value == null) {
if (this.app.isNewUserSession) {
// Check if we have a localStorage value but not a setting value and we are a new user
const localValue = localStorage["Comfy.Settings." + id];
if (localValue) {
value = JSON.parse(localValue);
this.setSettingValue(id, value); // Store on the server
}
}
if (value == null) {
value = defaultValue;
}
}
// Trigger initial setting of value
if (!skipOnChange) {
onChange?.(value, undefined);
}
this.settingsLookup[id] = {
id,
onChange,
name,
render: () => {
const setter = (v) => {
if (onChange) {
onChange(v, value);
}
this.setSettingValue(id, v);
value = v;
};
value = this.getSettingValue(id, defaultValue);
let element;
const htmlID = id.replaceAll(".", "-");
const labelCell = $el("td", [
$el("label", {
for: htmlID,
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
textContent: name,
}),
]);
if (typeof type === "function") {
element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
id: htmlID,
type: "checkbox",
checked: value,
onchange: (event) => {
const isChecked = event.target.checked;
if (onChange !== undefined) {
onChange(isChecked);
}
this.setSettingValue(id, isChecked);
},
}),
]),
]);
break;
case "number":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
type,
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
case "slider":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"div",
{
style: {
display: "grid",
gridAutoFlow: "column",
},
},
[
$el("input", {
...attrs,
value,
type: "range",
oninput: (e) => {
setter(e.target.value);
e.target.nextElementSibling.value = e.target.value;
},
}),
$el("input", {
...attrs,
value,
id: htmlID,
type: "number",
style: { maxWidth: "4rem" },
oninput: (e) => {
setter(e.target.value);
e.target.previousElementSibling.value = e.target.value;
},
}),
]
),
]),
]);
break;
case "combo":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"select",
{
oninput: (e) => {
setter(e.target.value);
},
},
(typeof options === "function" ? options(value) : options || []).map((opt) => {
if (typeof opt === "string") {
opt = { text: opt };
}
const v = opt.value ?? opt.text;
return $el("option", {
value: v,
textContent: opt.text,
selected: value + "" === v + "",
});
})
),
]),
]);
break;
case "text":
default:
if (type !== "text") {
console.warn(`Unsupported setting type '${type}, defaulting to text`);
}
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
}
}
if (tooltip) {
element.title = tooltip;
}
return element;
},
};
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
}
show() {
this.textElement.replaceChildren(
$el(
"tr",
{
style: { display: "none" },
},
[$el("th"), $el("th", { style: { width: "33%" } })]
),
...this.settings.sort((a, b) => a.name.localeCompare(b.name)).map((s) => s.render())
);
this.element.showModal();
}
}

View File

@ -0,0 +1,34 @@
.lds-ring {
display: inline-block;
position: relative;
width: 1em;
height: 1em;
}
.lds-ring div {
box-sizing: border-box;
display: block;
position: absolute;
width: 100%;
height: 100%;
border: 0.15em solid #fff;
border-radius: 50%;
animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite;
border-color: #fff transparent transparent transparent;
}
.lds-ring div:nth-child(1) {
animation-delay: -0.45s;
}
.lds-ring div:nth-child(2) {
animation-delay: -0.3s;
}
.lds-ring div:nth-child(3) {
animation-delay: -0.15s;
}
@keyframes lds-ring {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}

View File

@ -0,0 +1,9 @@
import { addStylesheet } from "../utils.js";
addStylesheet(import.meta.url);
export function createSpinner() {
const div = document.createElement("div");
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
return div.firstElementChild;
}

View File

@ -0,0 +1,60 @@
import { $el } from "../ui.js";
/**
* @typedef { { text: string, value?: string, tooltip?: string } } ToggleSwitchItem
*/
/**
* Creates a toggle switch element
* @param { string } name
* @param { Array<string | ToggleSwitchItem } items
* @param { Object } [opts]
* @param { (e: { item: ToggleSwitchItem, prev?: ToggleSwitchItem }) => void } [opts.onChange]
*/
export function toggleSwitch(name, items, { onChange } = {}) {
let selectedIndex;
let elements;
function updateSelected(index) {
if (selectedIndex != null) {
elements[selectedIndex].classList.remove("comfy-toggle-selected");
}
onChange?.({ item: items[index], prev: selectedIndex == null ? undefined : items[selectedIndex] });
selectedIndex = index;
elements[selectedIndex].classList.add("comfy-toggle-selected");
}
elements = items.map((item, i) => {
if (typeof item === "string") item = { text: item };
if (!item.value) item.value = item.text;
const toggle = $el(
"label",
{
textContent: item.text,
title: item.tooltip ?? "",
},
$el("input", {
name,
type: "radio",
value: item.value ?? item.text,
checked: item.selected,
onchange: () => {
updateSelected(i);
},
})
);
if (item.selected) {
updateSelected(i);
}
return toggle;
});
const container = $el("div.comfy-toggle-switch", elements);
if (selectedIndex == null) {
elements[0].children[0].checked = true;
updateSelected(0);
}
return container;
}

View File

@ -0,0 +1,135 @@
.comfy-user-selection {
width: 100vw;
height: 100vh;
position: absolute;
top: 0;
left: 0;
z-index: 999;
display: flex;
align-items: center;
justify-content: center;
font-family: sans-serif;
background: linear-gradient(var(--tr-even-bg-color), var(--tr-odd-bg-color));
}
.comfy-user-selection-inner {
background: var(--comfy-menu-bg);
margin-top: -30vh;
padding: 20px 40px;
border-radius: 10px;
min-width: 365px;
position: relative;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.3);
}
.comfy-user-selection-inner form {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
}
.comfy-user-selection-inner h1 {
margin: 10px 0 30px 0;
font-weight: normal;
}
.comfy-user-selection-inner label {
display: flex;
flex-direction: column;
width: 100%;
}
.comfy-user-selection input,
.comfy-user-selection select {
background-color: var(--comfy-input-bg);
color: var(--input-text);
border: 0;
border-radius: 5px;
padding: 5px;
margin-top: 10px;
}
.comfy-user-selection input::placeholder {
color: var(--descrip-text);
opacity: 1;
}
.comfy-user-existing {
width: 100%;
}
.no-users .comfy-user-existing {
display: none;
}
.comfy-user-selection-inner .or-separator {
margin: 10px 0;
padding: 10px;
display: block;
text-align: center;
width: 100%;
color: var(--descrip-text);
}
.comfy-user-selection-inner .or-separator {
overflow: hidden;
text-align: center;
margin-left: -10px;
}
.comfy-user-selection-inner .or-separator::before,
.comfy-user-selection-inner .or-separator::after {
content: "";
background-color: var(--border-color);
position: relative;
height: 1px;
vertical-align: middle;
display: inline-block;
width: calc(50% - 20px);
top: -1px;
}
.comfy-user-selection-inner .or-separator::before {
right: 10px;
margin-left: -50%;
}
.comfy-user-selection-inner .or-separator::after {
left: 10px;
margin-right: -50%;
}
.comfy-user-selection-inner section {
width: 100%;
padding: 10px;
margin: -10px;
transition: background-color 0.2s;
}
.comfy-user-selection-inner section.selected {
background: var(--border-color);
border-radius: 5px;
}
.comfy-user-selection-inner footer {
display: flex;
flex-direction: column;
align-items: center;
margin-top: 20px;
}
.comfy-user-selection-inner .comfy-user-error {
color: var(--error-text);
margin-bottom: 10px;
}
.comfy-user-button-next {
font-size: 16px;
padding: 6px 10px;
width: 100px;
display: flex;
gap: 5px;
align-items: center;
justify-content: center;
}

View File

@ -0,0 +1,114 @@
import { api } from "../api.js";
import { $el } from "../ui.js";
import { addStylesheet } from "../utils.js";
import { createSpinner } from "./spinner.js";
export class UserSelectionScreen {
async show(users, user) {
// This will rarely be hit so move the loading to on demand
await addStylesheet(import.meta.url);
const userSelection = document.getElementById("comfy-user-selection");
userSelection.style.display = "";
return new Promise((resolve) => {
const input = userSelection.getElementsByTagName("input")[0];
const select = userSelection.getElementsByTagName("select")[0];
const inputSection = input.closest("section");
const selectSection = select.closest("section");
const form = userSelection.getElementsByTagName("form")[0];
const error = userSelection.getElementsByClassName("comfy-user-error")[0];
const button = userSelection.getElementsByClassName("comfy-user-button-next")[0];
let inputActive = null;
input.addEventListener("focus", () => {
inputSection.classList.add("selected");
selectSection.classList.remove("selected");
inputActive = true;
});
select.addEventListener("focus", () => {
inputSection.classList.remove("selected");
selectSection.classList.add("selected");
inputActive = false;
select.style.color = "";
});
select.addEventListener("blur", () => {
if (!select.value) {
select.style.color = "var(--descrip-text)";
}
});
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (inputActive == null) {
error.textContent = "Please enter a username or select an existing user.";
} else if (inputActive) {
const username = input.value.trim();
if (!username) {
error.textContent = "Please enter a username.";
return;
}
// Create new user
input.disabled = select.disabled = input.readonly = select.readonly = true;
const spinner = createSpinner();
button.prepend(spinner);
try {
const resp = await api.createUser(username);
if (resp.status >= 300) {
let message = "Error creating user: " + resp.status + " " + resp.statusText;
try {
const res = await resp.json();
if(res.error) {
message = res.error;
}
} catch (error) {
}
throw new Error(message);
}
resolve({ username, userId: await resp.json(), created: true });
} catch (err) {
spinner.remove();
error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred.";
input.disabled = select.disabled = input.readonly = select.readonly = false;
return;
}
} else if (!select.value) {
error.textContent = "Please select an existing user.";
return;
} else {
resolve({ username: users[select.value], userId: select.value, created: false });
}
});
if (user) {
const name = localStorage["Comfy.userName"];
if (name) {
input.value = name;
}
}
if (input.value) {
// Focus the input, do this separately as sometimes browsers like to fill in the value
input.focus();
}
const userIds = Object.keys(users ?? {});
if (userIds.length) {
for (const u of userIds) {
$el("option", { textContent: users[u], value: u, parent: select });
}
select.style.color = "var(--descrip-text)";
if (select.value) {
// Focus the select, do this separately as sometimes browsers like to fill in the value
select.focus();
}
} else {
userSelection.classList.add("no-users");
input.focus();
}
}).then((r) => {
userSelection.remove();
return r;
});
}
}

View File

@ -1,3 +1,5 @@
import { $el } from "./ui.js";
// Simple date formatter // Simple date formatter
const parts = { const parts = {
d: (d) => d.getDate(), d: (d) => d.getDate(),
@ -65,3 +67,22 @@ export function applyTextReplacements(app, value) {
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_"); return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
}); });
} }
export async function addStylesheet(urlOrFile, relativeTo) {
return new Promise((res, rej) => {
let url;
if (urlOrFile.endsWith(".js")) {
url = urlOrFile.substr(0, urlOrFile.length - 2) + "css";
} else {
url = new URL(urlOrFile, relativeTo ?? `${window.location.protocol}//${window.location.host}`).toString();
}
$el("link", {
parent: document.head,
rel: "stylesheet",
type: "text/css",
href: url,
onload: res,
onerror: rej,
});
});
}

View File

@ -1,6 +1,19 @@
import { api } from "./api.js" import { api } from "./api.js"
import "./domWidget.js"; import "./domWidget.js";
let controlValueRunBefore = false;
export function updateControlWidgetLabel(widget) {
let replacement = "after";
let find = "before";
if (controlValueRunBefore) {
[find, replacement] = [replacement, find]
}
widget.label = (widget.label ?? widget.name).replace(find, replacement);
}
const IS_CONTROL_WIDGET = Symbol();
const HAS_EXECUTED = Symbol();
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
let defaultVal = inputData[1]["default"]; let defaultVal = inputData[1]["default"];
let { min, max, step, round} = inputData[1]; let { min, max, step, round} = inputData[1];
@ -62,6 +75,8 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
serialize: false, // Don't include this in prompt. serialize: false, // Don't include this in prompt.
} }
); );
valueControl[IS_CONTROL_WIDGET] = true;
updateControlWidgetLabel(valueControl);
widgets.push(valueControl); widgets.push(valueControl);
const isCombo = targetWidget.type === "combo"; const isCombo = targetWidget.type === "combo";
@ -76,10 +91,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
serialize: false, // Don't include this in prompt. serialize: false, // Don't include this in prompt.
} }
); );
updateControlWidgetLabel(comboFilter);
widgets.push(comboFilter); widgets.push(comboFilter);
} }
valueControl.afterQueued = () => { const applyWidgetControl = () => {
var v = valueControl.value; var v = valueControl.value;
if (isCombo && v !== "fixed") { if (isCombo && v !== "fixed") {
@ -159,6 +176,23 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
targetWidget.callback(targetWidget.value); targetWidget.callback(targetWidget.value);
} }
}; };
valueControl.beforeQueued = () => {
if (controlValueRunBefore) {
// Don't run on first execution
if (valueControl[HAS_EXECUTED]) {
applyWidgetControl();
}
}
valueControl[HAS_EXECUTED] = true;
};
valueControl.afterQueued = () => {
if (!controlValueRunBefore) {
applyWidgetControl();
}
};
return widgets; return widgets;
}; };
@ -224,6 +258,34 @@ function isSlider(display, app) {
return (display==="slider") ? "slider" : "number" return (display==="slider") ? "slider" : "number"
} }
export function initWidgets(app) {
app.ui.settings.addSetting({
id: "Comfy.WidgetControlMode",
name: "Widget Value Control Mode",
type: "combo",
defaultValue: "after",
options: ["before", "after"],
tooltip: "Controls when widget values are updated (randomize/increment/decrement), either before the prompt is queued or after.",
onChange(value) {
controlValueRunBefore = value === "before";
for (const n of app.graph._nodes) {
if (!n.widgets) continue;
for (const w of n.widgets) {
if (w[IS_CONTROL_WIDGET]) {
updateControlWidgetLabel(w);
if (w.linkedWidgets) {
for (const l of w.linkedWidgets) {
updateControlWidgetLabel(l);
}
}
}
}
}
app.graph.setDirtyCanvas(true);
},
});
}
export const ComfyWidgets = { export const ComfyWidgets = {
"INT:seed": seedWidget, "INT:seed": seedWidget,
"INT:noise_seed": seedWidget, "INT:noise_seed": seedWidget,

View File

@ -121,6 +121,8 @@ body {
width: 100%; width: 100%;
} }
.comfy-toggle-switch,
.comfy-btn,
.comfy-menu > button, .comfy-menu > button,
.comfy-menu-btns button, .comfy-menu-btns button,
.comfy-menu .comfy-list button, .comfy-menu .comfy-list button,
@ -133,6 +135,7 @@ body {
margin-top: 2px; margin-top: 2px;
} }
.comfy-btn:hover:not(:disabled),
.comfy-menu > button:hover, .comfy-menu > button:hover,
.comfy-menu-btns button:hover, .comfy-menu-btns button:hover,
.comfy-menu .comfy-list button:hover, .comfy-menu .comfy-list button:hover,
@ -143,6 +146,12 @@ body {
} }
.comfy-menu span.drag-handle { .comfy-menu span.drag-handle {
position: absolute;
top: 0;
left: 0;
}
span.drag-handle {
width: 10px; width: 10px;
height: 20px; height: 20px;
display: inline-block; display: inline-block;
@ -158,12 +167,9 @@ body {
letter-spacing: 2px; letter-spacing: 2px;
color: var(--drag-text); color: var(--drag-text);
text-shadow: 1px 0 1px black; text-shadow: 1px 0 1px black;
position: absolute;
top: 0;
left: 0;
} }
.comfy-menu span.drag-handle::after { span.drag-handle::after {
content: '.. .. ..'; content: '.. .. ..';
} }
@ -429,6 +435,43 @@ dialog::backdrop {
margin-left: 5px; margin-left: 5px;
} }
.comfy-toggle-switch {
border-width: 2px;
display: flex;
background-color: var(--comfy-input-bg);
margin: 2px 0;
white-space: nowrap;
}
.comfy-toggle-switch label {
padding: 2px 0px 3px 6px;
flex: auto;
border-radius: 8px;
align-items: center;
display: flex;
justify-content: center;
}
.comfy-toggle-switch label:first-child {
border-top-left-radius: 8px;
border-bottom-left-radius: 8px;
}
.comfy-toggle-switch label:last-child {
border-top-right-radius: 8px;
border-bottom-right-radius: 8px;
}
.comfy-toggle-switch .comfy-toggle-selected {
background-color: var(--comfy-menu-bg);
}
#extraOptions {
padding: 4px;
background-color: var(--bg-color);
margin-bottom: 4px;
border-radius: 4px;
}
/* Search box */ /* Search box */
.litegraph.litesearchbox { .litegraph.litesearchbox {

View File

@ -1,4 +1,5 @@
from comfy import samplers from comfy import samplers
from comfy import model_management
from comfy import sample from comfy import sample
from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.cmd import latent_preview from comfy.cmd import latent_preview
@ -26,9 +27,8 @@ class BasicScheduler:
if denoise < 1.0: if denoise < 1.0:
total_steps = int(steps/denoise) total_steps = int(steps/denoise)
inner_model = model.patch_model(patch_weights=False) model_management.load_models_gpu([model])
sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu() sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
model.unpatch_model()
sigmas = sigmas[-(steps + 1):] sigmas = sigmas[-(steps + 1):]
return (sigmas, ) return (sigmas, )
@ -106,9 +106,8 @@ class SDTurboScheduler:
def get_sigmas(self, model, steps, denoise): def get_sigmas(self, model, steps, denoise):
start_step = 10 - int(10 * denoise) start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
inner_model = model.patch_model(patch_weights=False) model_management.load_models_gpu([model])
sigmas = inner_model.model_sampling.sigma(timesteps) sigmas = model.model.model_sampling.sigma(timesteps)
model.unpatch_model()
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, ) return (sigmas, )

View File

@ -34,7 +34,7 @@ class FreeU:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "model_patches"
def patch(self, model, b1, b2, s1, s2): def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]
@ -73,7 +73,7 @@ class FreeU_V2:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "model_patches"
def patch(self, model, b1, b2, s1, s2): def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]

View File

@ -32,29 +32,29 @@ class HyperTile:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "model_patches"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth): def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]
apply_to = set()
temp = model_channels
for x in range(max_depth + 1):
apply_to.add(temp)
temp *= 2
latent_tile_size = max(32, tile_size) // 8 latent_tile_size = max(32, tile_size) // 8
self.temp = None self.temp = None
def hypertile_in(q, k, v, extra_options): def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to: model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
for i in range(max_depth + 1):
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
if model_chans in apply_to:
shape = extra_options["original_shape"] shape = extra_options["original_shape"]
aspect_ratio = shape[-1] / shape[-2] aspect_ratio = shape[-1] / shape[-2]
hw = q.size(1) hw = q.size(1)
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size) nh = random_divisor(h, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size) nw = random_divisor(w, latent_tile_size * factor, swap_size)

View File

@ -122,10 +122,34 @@ class LatentBatch:
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,) return (samples_out,)
class LatentBatchSeedBehavior:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
samples_out = samples.copy()
latent = samples["samples"]
if seed_behavior == "random":
if 'batch_index' in samples_out:
samples_out.pop('batch_index')
elif seed_behavior == "fixed":
batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, "LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract, "LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply, "LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate, "LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch, "LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
} }

View File

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import scipy.ndimage import scipy.ndimage
import torch import torch
import comfy.utils from comfy import utils
from comfy.nodes.common import MAX_RESOLUTION from comfy.nodes.common import MAX_RESOLUTION
@ -11,7 +11,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if resize_source: if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) source = utils.repeat_to_batch_size(source, destination.shape[0])
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
@ -24,7 +24,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
else: else:
mask = mask.to(destination.device, copy=True) mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) mask = utils.repeat_to_batch_size(mask, source.shape[0])
# calculate the bounds of the source that will be overlapping the destination # calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds # this prevents the source trying to overwrite latent pixels that are out of bounds

View File

@ -1,5 +1,5 @@
import torch import torch
import comfy.utils from comfy import utils
class PatchModelAddDownscale: class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@ -27,12 +27,12 @@ class PatchModelAddDownscale:
if transformer_options["block"][1] == block_number: if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item() sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end: if sigma <= sigma_start and sigma >= sigma_end:
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") h = utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
return h return h
def output_block_patch(h, hsp, transformer_options): def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]: if h.shape[2] != hsp.shape[2]:
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") h = utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
return h, hsp return h, hsp
m = model.clone() m = model.clone()

View File

@ -1,6 +1,6 @@
from comfy import sd from comfy import sd, utils
from comfy import model_base from comfy import model_base
import comfy.model_management from comfy import model_management
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
import json import json
@ -118,6 +118,48 @@ class ModelMergeBlocks:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, ) return (m, )
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
class CheckpointSave: class CheckpointSave:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -136,46 +178,7 @@ class CheckpointSave:
CATEGORY = "advanced/model_merging" CATEGORY = "advanced/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {} return {}
class CLIPSave: class CLIPSave:
@ -205,7 +208,7 @@ class CLIPSave:
for x in extra_pnginfo: for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x]) metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()]) model_management.load_models_gpu([clip.load_model()])
clip_sd = clip.get_sd() clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]: for prefix in ["clip_l.", "clip_g.", ""]:
@ -229,9 +232,9 @@ class CLIPSave:
output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint) output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix) current_clip_sd = utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata) utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {} return {}
class VAESave: class VAESave:
@ -265,7 +268,7 @@ class VAESave:
output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint) output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {} return {}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {

View File

@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from comfy.cmd import folder_paths
from comfy import clip_model
from comfy import clip_vision
from comfy import ops
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
VISION_CONFIG_DICT = {
"hidden_size": 1024,
"image_size": 224,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"hidden_act": "quick_gelu",
}
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=ops):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = operations.LayerNorm(in_dim)
self.fc1 = operations.Linear(in_dim, hidden_dim)
self.fc2 = operations.Linear(hidden_dim, out_dim)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
residual = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
if self.use_residual:
x = x + residual
return x
class FuseModule(nn.Module):
def __init__(self, embed_dim, operations):
super().__init__()
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
self.layer_norm = operations.LayerNorm(embed_dim)
def fuse_fn(self, prompt_embeds, id_embeds):
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
stacked_id_embeds = self.mlp2(stacked_id_embeds)
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
return stacked_id_embeds
def forward(
self,
prompt_embeds,
id_embeds,
class_tokens_mask,
) -> torch.Tensor:
# id_embeds shape: [b, max_num_inputs, 1, 2048]
id_embeds = id_embeds.to(prompt_embeds.dtype)
num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
batch_size, max_num_inputs = id_embeds.shape[:2]
# seq_length: 77
seq_length = prompt_embeds.shape[1]
# flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
flat_id_embeds = id_embeds.view(
-1, id_embeds.shape[-2], id_embeds.shape[-1]
)
# valid_id_mask [b*max_num_inputs]
valid_id_mask = (
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
< num_inputs[:, None]
)
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
class_tokens_mask = class_tokens_mask.view(-1)
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
# slice out the image token embeddings
image_token_embeds = prompt_embeds[class_tokens_mask]
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
return updated_prompt_embeds
class PhotoMakerIDEncoder(clip_model.CLIPVisionModelProjection):
def __init__(self):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
dtype = comfy.model_management.text_encoder_dtype(self.load_device)
super().__init__(VISION_CONFIG_DICT, dtype, offload_device, ops.manual_cast)
self.visual_projection_2 = ops.manual_cast.Linear(1024, 1280, bias=False)
self.fuse_module = FuseModule(2048, ops.manual_cast)
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
b, num_inputs, c, h, w = id_pixel_values.shape
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
shared_id_embeds = self.vision_model(id_pixel_values)[2]
id_embeds = self.visual_projection(shared_id_embeds)
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
return updated_prompt_embeds
class PhotoMakerLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
RETURN_TYPES = ("PHOTOMAKER",)
FUNCTION = "load_photomaker_model"
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:
data = data["id_encoder"]
photomaker_model.load_state_dict(data)
return (photomaker_model,)
class PhotoMakerEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "photomaker": ("PHOTOMAKER",),
"image": ("IMAGE",),
"clip": ("CLIP", ),
"text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_photomaker"
CATEGORY = "_for_testing/photomaker"
def apply_photomaker(self, photomaker, image, clip, text):
special_token = "photomaker"
pixel_values = clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
try:
index = text.split(" ").index(special_token) + 1
except ValueError:
index = -1
tokens = clip.tokenize(text, return_word_ids=True)
out_tokens = {}
for k in tokens:
out_tokens[k] = []
for t in tokens[k]:
f = list(filter(lambda x: x[2] != index, t))
while len(f) < len(t):
f.append(t[-1])
out_tokens[k].append(f)
cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
if index > 0:
token_index = index - 1
num_id_images = 1
class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)]
out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
else:
out = cond
return ([[out, {"pooled_output": pooled}]], )
NODE_CLASS_MAPPINGS = {
"PhotoMakerLoader": PhotoMakerLoader,
"PhotoMakerEncode": PhotoMakerEncode,
}

View File

@ -33,6 +33,7 @@ class Blend:
CATEGORY = "image/postprocessing" CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image2 = image2.to(image1.device)
if image1.shape != image2.shape: if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2) image2 = image2.permute(0, 3, 1, 2)
image2 = utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') image2 = utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')

View File

@ -58,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2) attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool # Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = math.ceil(math.sqrt(lh * lw / hw1)) ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape # Reshape
@ -143,6 +143,8 @@ class SelfAttentionGuidance:
sigma = args["sigma"] sigma = args["sigma"]
model_options = args["model_options"] model_options = args["model_options"]
x = args["input"] x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
return cfg_result
# create the adversarially blurred image # create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)

View File

@ -1,5 +1,4 @@
import torch import torch
import comfy.utils
from comfy.nodes.common import MAX_RESOLUTION from comfy.nodes.common import MAX_RESOLUTION
from comfy import utils from comfy import utils
@ -49,13 +48,57 @@ class StableZero123_Conditioning:
encode_pixels = pixels[:,:,:,:3] encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
cam_embeds = camera_embeddings(elevation, azimuth) cam_embeds = camera_embeddings(elevation, azimuth)
cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1) cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
positive = [[cond, {"concat_latent_image": t}]] positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent}) return (positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = []
for i in range(batch_size):
cam_embeds.append(camera_embeddings(elevation, azimuth))
elevation += elevation_batch_increment
azimuth += azimuth_batch_increment
cam_embeds = torch.cat(cam_embeds, dim=0)
cond = torch.cat([utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning, "StableZero123_Conditioning": StableZero123_Conditioning,
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
} }

View File

@ -3,6 +3,7 @@ import torch
import comfy.utils import comfy.utils
import comfy.sd import comfy.sd
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
from . import nodes_model_merging
class ImageOnlyCheckpointLoader: class ImageOnlyCheckpointLoader:
@ -78,10 +79,26 @@ class VideoLinearCFGGuidance:
m.set_model_sampler_cfg_function(linear_cfg) m.set_model_sampler_cfg_function(linear_cfg)
return (m, ) return (m, )
class ImageOnlyCheckpointSave(nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance, "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {

View File

@ -1,3 +1,4 @@
{ {
"presets": ["@babel/preset-env"] "presets": ["@babel/preset-env"],
"plugins": ["babel-plugin-transform-import-meta"]
} }

View File

@ -11,6 +11,7 @@
"devDependencies": { "devDependencies": {
"@babel/preset-env": "^7.22.20", "@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5", "@types/jest": "^29.5.5",
"babel-plugin-transform-import-meta": "^2.2.1",
"jest": "^29.7.0", "jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0" "jest-environment-jsdom": "^29.7.0"
} }
@ -2591,6 +2592,19 @@
"@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0" "@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0"
} }
}, },
"node_modules/babel-plugin-transform-import-meta": {
"version": "2.2.1",
"resolved": "https://registry.npmjs.org/babel-plugin-transform-import-meta/-/babel-plugin-transform-import-meta-2.2.1.tgz",
"integrity": "sha512-AxNh27Pcg8Kt112RGa3Vod2QS2YXKKJ6+nSvRtv7qQTJAdx0MZa4UHZ4lnxHUWA2MNbLuZQv5FVab4P1CoLOWw==",
"dev": true,
"dependencies": {
"@babel/template": "^7.4.4",
"tslib": "^2.4.0"
},
"peerDependencies": {
"@babel/core": "^7.10.0"
}
},
"node_modules/babel-preset-current-node-syntax": { "node_modules/babel-preset-current-node-syntax": {
"version": "1.0.1", "version": "1.0.1",
"resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz", "resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz",
@ -5233,6 +5247,12 @@
"node": ">=12" "node": ">=12"
} }
}, },
"node_modules/tslib": {
"version": "2.6.2",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==",
"dev": true
},
"node_modules/type-detect": { "node_modules/type-detect": {
"version": "4.0.8", "version": "4.0.8",
"resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz",

View File

@ -24,6 +24,7 @@
"devDependencies": { "devDependencies": {
"@babel/preset-env": "^7.22.20", "@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5", "@types/jest": "^29.5.5",
"babel-plugin-transform-import-meta": "^2.2.1",
"jest": "^29.7.0", "jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0" "jest-environment-jsdom": "^29.7.0"
} }

View File

@ -0,0 +1,295 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("users", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
function expectNoUserScreen() {
// Ensure login isnt visible
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
expect(selection["style"].display).toBe("none");
const menu = document.querySelectorAll(".comfy-menu")?.[0];
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
}
describe("multi-user", () => {
function mockAddStylesheet() {
const utils = require("../../comfy/web/scripts/utils");
utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve());
}
async function waitForUserScreenShow() {
mockAddStylesheet();
// Wait for "show" to be called
const { UserSelectionScreen } = require("../../comfy/web/scripts/ui/userSelection");
let resolve, reject;
const fn = UserSelectionScreen.prototype.show;
const p = new Promise((res, rej) => {
resolve = res;
reject = rej;
});
jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => {
const res = fn(...args);
await new Promise(process.nextTick); // wait for promises to resolve
resolve();
return res;
});
// @ts-ignore
setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500);
await p;
await new Promise(process.nextTick); // wait for promises to resolve
}
async function testUserScreen(onShown, users) {
if (!users) {
users = {};
}
const starting = start({
resetEnv: true,
userConfig: { storage: "server", users },
});
// Ensure no current user
expect(localStorage["Comfy.userId"]).toBeFalsy();
expect(localStorage["Comfy.userName"]).toBeFalsy();
await waitForUserScreenShow();
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
expect(selection).toBeTruthy();
// Ensure login is visible
expect(window.getComputedStyle(selection)?.display).not.toBe("none");
// Ensure menu is hidden
const menu = document.querySelectorAll(".comfy-menu")?.[0];
expect(window.getComputedStyle(menu)?.display).toBe("none");
const isCreate = await onShown(selection);
// Submit form
selection.querySelectorAll("form")[0].submit();
await new Promise(process.nextTick); // wait for promises to resolve
// Wait for start
const s = await starting;
// Ensure login is removed
expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0);
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
// Ensure settings + templates are saved
const { api } = require("../../comfy/web/scripts/api");
expect(api.createUser).toHaveBeenCalledTimes(+isCreate);
expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate);
expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate);
if (isCreate) {
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
expect(s.app.isNewUserSession).toBeTruthy();
} else {
expect(s.app.isNewUserSession).toBeFalsy();
}
return { users, selection, ...s };
}
it("allows user creation if no users", async () => {
const { users } = await testUserScreen((selection) => {
// Ensure we have no users flag added
expect(selection.classList.contains("no-users")).toBeTruthy();
// Enter a username
const input = selection.getElementsByTagName("input")[0];
input.focus();
input.value = "Test User";
return true;
});
expect(users).toStrictEqual({
"Test User!": "Test User",
});
expect(localStorage["Comfy.userId"]).toBe("Test User!");
expect(localStorage["Comfy.userName"]).toBe("Test User");
});
it("allows user creation if no current user but other users", async () => {
const users = {
"Test User 2!": "Test User 2",
};
await testUserScreen((selection) => {
expect(selection.classList.contains("no-users")).toBeFalsy();
// Enter a username
const input = selection.getElementsByTagName("input")[0];
input.focus();
input.value = "Test User 3";
return true;
}, users);
expect(users).toStrictEqual({
"Test User 2!": "Test User 2",
"Test User 3!": "Test User 3",
});
expect(localStorage["Comfy.userId"]).toBe("Test User 3!");
expect(localStorage["Comfy.userName"]).toBe("Test User 3");
});
it("allows user selection if no current user but other users", async () => {
const users = {
"A!": "A",
"B!": "B",
"C!": "C",
};
await testUserScreen((selection) => {
expect(selection.classList.contains("no-users")).toBeFalsy();
// Check user list
const select = selection.getElementsByTagName("select")[0];
const options = select.getElementsByTagName("option");
expect(
[...options]
.filter((o) => !o.disabled)
.reduce((p, n) => {
p[n.getAttribute("value")] = n.textContent;
return p;
}, {})
).toStrictEqual(users);
// Select an option
select.focus();
select.value = options[2].value;
return false;
}, users);
expect(users).toStrictEqual(users);
expect(localStorage["Comfy.userId"]).toBe("B!");
expect(localStorage["Comfy.userName"]).toBe("B");
});
it("doesnt show user screen if current user", async () => {
const starting = start({
resetEnv: true,
userConfig: {
storage: "server",
users: {
"User!": "User",
},
},
localStorage: {
"Comfy.userId": "User!",
"Comfy.userName": "User",
},
});
await new Promise(process.nextTick); // wait for promises to resolve
expectNoUserScreen();
await starting;
});
it("allows user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: {
storage: "server",
users: {
"User!": "User",
},
},
localStorage: {
"Comfy.userId": "User!",
"Comfy.userName": "User",
},
});
// cant actually test switching user easily but can check the setting is present
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy();
});
});
describe("single-user", () => {
it("doesnt show user creation if no default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: false, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../comfy/web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(1);
expect(api.storeUserData).toHaveBeenCalledTimes(1);
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
expect(app.isNewUserSession).toBeTruthy();
});
it("doesnt show user creation if default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../comfy/web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt allow user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
});
});
describe("browser-user", () => {
it("doesnt show user creation if no default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: false, storage: "browser" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../comfy/web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt show user creation if default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../comfy/web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt allow user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "browser" },
});
expectNoUserScreen();
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
});
});
});

View File

@ -1,10 +1,18 @@
const { mockApi } = require("./setup"); const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph"); const { Ez } = require("./ezgraph");
const lg = require("./litegraph"); const lg = require("./litegraph");
const fs = require("fs");
const path = require("path");
const html = fs.readFileSync(path.resolve("../comfy/web/index.html"))
/** /**
* *
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config * @param { Parameters<typeof mockApi>[0] & {
* resetEnv?: boolean,
* preSetup?(app): Promise<void>,
* localStorage?: Record<string, string>
* } } config
* @returns * @returns
*/ */
export async function start(config = {}) { export async function start(config = {}) {
@ -12,12 +20,18 @@ export async function start(config = {}) {
jest.resetModules(); jest.resetModules();
jest.resetAllMocks(); jest.resetAllMocks();
lg.setup(global); lg.setup(global);
localStorage.clear();
sessionStorage.clear();
} }
Object.assign(localStorage, config.localStorage ?? {});
document.body.innerHTML = html;
mockApi(config); mockApi(config);
const { app } = require("../../comfy/web/scripts/app"); const { app } = require("../../comfy/web/scripts/app");
config.preSetup?.(app); config.preSetup?.(app);
await app.setup(); await app.setup();
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
} }

View File

@ -18,9 +18,21 @@ function* walkSync(dir) {
*/ */
/** /**
* @param { { mockExtensions?: string[], mockNodeDefs?: Record<string, ComfyObjectInfo> } } config * @param {{
* mockExtensions?: string[],
* mockNodeDefs?: Record<string, ComfyObjectInfo>,
* settings?: Record<string, string>
* userConfig?: {storage: "server" | "browser", users?: Record<string, any>, migrated?: boolean },
* userData?: Record<string, any>
* }} config
*/ */
export function mockApi({ mockExtensions, mockNodeDefs } = {}) { export function mockApi(config = {}) {
let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = {
userConfig,
settings: {},
userData: {},
...config,
};
if (!mockExtensions) { if (!mockExtensions) {
mockExtensions = Array.from(walkSync(path.resolve("../comfy/web/extensions/core"))) mockExtensions = Array.from(walkSync(path.resolve("../comfy/web/extensions/core")))
.filter((x) => x.endsWith(".js")) .filter((x) => x.endsWith(".js"))
@ -40,6 +52,26 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
getNodeDefs: jest.fn(() => mockNodeDefs), getNodeDefs: jest.fn(() => mockNodeDefs),
init: jest.fn(), init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x), apiURL: jest.fn((x) => "../../web/" + x),
createUser: jest.fn((username) => {
if(username in userConfig.users) {
return { status: 400, json: () => "Duplicate" }
}
userConfig.users[username + "!"] = username;
return { status: 200, json: () => username + "!" }
}),
getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }),
getSettings: jest.fn(() => settings),
storeSettings: jest.fn((v) => Object.assign(settings, v)),
getUserData: jest.fn((f) => {
if (f in userData) {
return { status: 200, json: () => userData[f] };
} else {
return { status: 404 };
}
}),
storeUserData: jest.fn((file, data) => {
userData[file] = data;
}),
}; };
jest.mock("../../comfy/web/scripts/api", () => ({ jest.mock("../../comfy/web/scripts/api", () => ({
get api() { get api() {