mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Merge with latest upstream.
This commit is contained in:
commit
82edb2ff0e
3
.gitignore
vendored
3
.gitignore
vendored
@ -172,4 +172,5 @@ dmypy.json
|
||||
cython_debug/
|
||||
.openapi-generator/
|
||||
|
||||
/tests-ui/data/object_info.json
|
||||
/tests-ui/data/object_info.json
|
||||
/user/
|
||||
@ -101,6 +101,8 @@ There is a portable standalone build for Windows that should work for running on
|
||||
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
@ -192,7 +194,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
|
||||
To run tests:
|
||||
```shell
|
||||
pytest tests/inference
|
||||
(cd tests-ui && npm ci && npm test:generate && npm test)
|
||||
(cd tests-ui && npm ci && npm run test:generate && npm test)
|
||||
```
|
||||
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
|
||||
|
||||
|
||||
0
comfy/app/__init__.py
Normal file
0
comfy/app/__init__.py
Normal file
54
comfy/app/app_settings.py
Normal file
54
comfy/app/app_settings.py
Normal file
@ -0,0 +1,54 @@
|
||||
import os
|
||||
import json
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
class AppSettings():
|
||||
def __init__(self, user_manager):
|
||||
self.user_manager = user_manager
|
||||
|
||||
def get_settings(self, request):
|
||||
file = self.user_manager.get_request_user_filepath(
|
||||
request, "comfy.settings.json")
|
||||
if os.path.isfile(file):
|
||||
with open(file) as f:
|
||||
return json.load(f)
|
||||
else:
|
||||
return {}
|
||||
|
||||
def save_settings(self, request, settings):
|
||||
file = self.user_manager.get_request_user_filepath(
|
||||
request, "comfy.settings.json")
|
||||
with open(file, "w") as f:
|
||||
f.write(json.dumps(settings, indent=4))
|
||||
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/settings")
|
||||
async def get_settings(request):
|
||||
return web.json_response(self.get_settings(request))
|
||||
|
||||
@routes.get("/settings/{id}")
|
||||
async def get_setting(request):
|
||||
value = None
|
||||
settings = self.get_settings(request)
|
||||
setting_id = request.match_info.get("id", None)
|
||||
if setting_id and setting_id in settings:
|
||||
value = settings[setting_id]
|
||||
return web.json_response(value)
|
||||
|
||||
@routes.post("/settings")
|
||||
async def post_settings(request):
|
||||
settings = self.get_settings(request)
|
||||
new_settings = await request.json()
|
||||
self.save_settings(request, {**settings, **new_settings})
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/settings/{id}")
|
||||
async def post_setting(request):
|
||||
setting_id = request.match_info.get("id", None)
|
||||
if not setting_id:
|
||||
return web.Response(status=400)
|
||||
settings = self.get_settings(request)
|
||||
settings[setting_id] = await request.json()
|
||||
self.save_settings(request, settings)
|
||||
return web.Response(status=200)
|
||||
135
comfy/app/user_manager.py
Normal file
135
comfy/app/user_manager.py
Normal file
@ -0,0 +1,135 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from aiohttp import web
|
||||
from ..cli_args import args
|
||||
from ..cmd.folder_paths import user_directory
|
||||
from .app_settings import AppSettings
|
||||
|
||||
|
||||
class UserManager():
|
||||
def __init__(self):
|
||||
self.default_user = "default"
|
||||
self.users_file = os.path.join(user_directory, "users.json")
|
||||
self.settings = AppSettings(self)
|
||||
if not os.path.exists(user_directory):
|
||||
os.mkdir(user_directory)
|
||||
if not args.multi_user:
|
||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
|
||||
if args.multi_user:
|
||||
if os.path.isfile(self.users_file):
|
||||
with open(self.users_file) as f:
|
||||
self.users = json.load(f)
|
||||
else:
|
||||
self.users = {}
|
||||
else:
|
||||
self.users = {"default": "default"}
|
||||
|
||||
def get_request_user_id(self, request):
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
user = request.headers["comfy-user"]
|
||||
|
||||
if user not in self.users:
|
||||
raise KeyError("Unknown user: " + user)
|
||||
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
else:
|
||||
raise KeyError("Unknown filepath type:" + type)
|
||||
|
||||
user = self.get_request_user_id(request)
|
||||
path = user_root = os.path.abspath(os.path.join(root_dir, user))
|
||||
|
||||
# prevent leaving /{type}
|
||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
||||
return None
|
||||
|
||||
parent = user_root
|
||||
|
||||
if file is not None:
|
||||
# prevent leaving /{type}/{user}
|
||||
path = os.path.abspath(os.path.join(user_root, file))
|
||||
if os.path.commonpath((user_root, path)) != user_root:
|
||||
return None
|
||||
|
||||
if create_dir and not os.path.exists(parent):
|
||||
os.mkdir(parent)
|
||||
|
||||
return path
|
||||
|
||||
def add_user(self, name):
|
||||
name = name.strip()
|
||||
if not name:
|
||||
raise ValueError("username not provided")
|
||||
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
|
||||
user_id = user_id + "_" + str(uuid.uuid4())
|
||||
|
||||
self.users[user_id] = name
|
||||
|
||||
with open(self.users_file, "w") as f:
|
||||
json.dump(self.users, f)
|
||||
|
||||
return user_id
|
||||
|
||||
def add_routes(self, routes):
|
||||
self.settings.add_routes(routes)
|
||||
|
||||
@routes.get("/users")
|
||||
async def get_users(request):
|
||||
if args.multi_user:
|
||||
return web.json_response({"storage": "server", "users": self.users})
|
||||
else:
|
||||
user_dir = self.get_request_user_filepath(request, None, create_dir=False)
|
||||
return web.json_response({
|
||||
"storage": "server",
|
||||
"migrated": os.path.exists(user_dir)
|
||||
})
|
||||
|
||||
@routes.post("/users")
|
||||
async def post_users(request):
|
||||
body = await request.json()
|
||||
username = body["username"]
|
||||
if username in self.users.values():
|
||||
return web.json_response({"error": "Duplicate username."}, status=400)
|
||||
|
||||
user_id = self.add_user(username)
|
||||
return web.json_response(user_id)
|
||||
|
||||
@routes.get("/userdata/{file}")
|
||||
async def getuserdata(request):
|
||||
file = request.match_info.get("file", None)
|
||||
if not file:
|
||||
return web.Response(status=400)
|
||||
|
||||
path = self.get_request_user_filepath(request, file)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
if not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
return web.FileResponse(path)
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
file = request.match_info.get("file", None)
|
||||
if not file:
|
||||
return web.Response(status=400)
|
||||
|
||||
path = self.get_request_user_filepath(request, file)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
body = await request.read()
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
|
||||
return web.Response(status=200)
|
||||
@ -112,6 +112,8 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
if options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
||||
@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module):
|
||||
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .utils import load_torch_file, transformers_convert
|
||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
@ -43,6 +43,9 @@ class ClipVisionModel():
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_image(self, image):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device)).float()
|
||||
@ -75,6 +78,9 @@ def convert_to_transformers(sd, prefix):
|
||||
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
|
||||
|
||||
sd = transformers_convert(sd, prefix, "vision_model.", 48)
|
||||
else:
|
||||
replace_prefix = {prefix: ""}
|
||||
sd = state_dict_prefix_replace(sd, replace_prefix)
|
||||
return sd
|
||||
|
||||
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import copy
|
||||
import datetime
|
||||
import heapq
|
||||
@ -15,6 +16,7 @@ from typing import Tuple
|
||||
import sys
|
||||
import gc
|
||||
import inspect
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -325,11 +327,22 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
|
||||
class PromptExecutor:
|
||||
def __init__(self, server):
|
||||
self.success = None
|
||||
self.server = server
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.outputs = {}
|
||||
self.object_storage = {}
|
||||
self.outputs_ui = {}
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
self.old_prompt = {}
|
||||
self.server = server
|
||||
|
||||
def add_message(self, event, data, broadcast: bool):
|
||||
self.status_messages.append((event, data))
|
||||
if self.server.client_id is not None or broadcast:
|
||||
self.server.send_sync(event, data, self.server.client_id)
|
||||
|
||||
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
||||
node_id = error["node_id"]
|
||||
@ -344,23 +357,22 @@ class PromptExecutor:
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
}
|
||||
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
|
||||
self.add_message("execution_interrupted", mes, broadcast=True)
|
||||
else:
|
||||
if self.server.client_id is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": error["exception_message"],
|
||||
"exception_type": error["exception_type"],
|
||||
"traceback": error["traceback"],
|
||||
"current_inputs": error["current_inputs"],
|
||||
"current_outputs": error["current_outputs"],
|
||||
}
|
||||
self.server.send_sync("execution_error", mes, self.server.client_id)
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": error["exception_message"],
|
||||
"exception_type": error["exception_type"],
|
||||
"traceback": error["traceback"],
|
||||
"current_inputs": error["current_inputs"],
|
||||
"current_outputs": error["current_outputs"],
|
||||
}
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
# Next, remove the subsequent outputs since they will not be executed
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
@ -381,8 +393,8 @@ class PromptExecutor:
|
||||
else:
|
||||
self.server.client_id = None
|
||||
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id)
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
with torch.inference_mode():
|
||||
# delete cached outputs if nodes don't exist for them
|
||||
@ -415,9 +427,9 @@ class PromptExecutor:
|
||||
del d
|
||||
|
||||
model_management.cleanup_models()
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
|
||||
self.server.client_id)
|
||||
self.add_message("execution_cached",
|
||||
{"nodes": list(current_outputs), "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
executed = set()
|
||||
output_node_id = None
|
||||
to_execute = []
|
||||
@ -434,11 +446,9 @@ class PromptExecutor:
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id,
|
||||
self.outputs_ui, self.object_storage)
|
||||
if success is not True:
|
||||
self.handle_execution_error( prompt_id,
|
||||
prompt, current_outputs, executed, error, ex)
|
||||
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
||||
if self.success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
for x in executed:
|
||||
@ -779,6 +789,7 @@ class PromptQueue:
|
||||
self.queue = []
|
||||
self.currently_running = {}
|
||||
self.history = {}
|
||||
self.flags = {}
|
||||
server.prompt_queue = self
|
||||
|
||||
def size(self) -> int:
|
||||
@ -803,15 +814,28 @@ class PromptQueue:
|
||||
self.server.queue_updated()
|
||||
return copy.deepcopy(item_with_future.queue_tuple), task_id
|
||||
|
||||
def task_done(self, item_id, outputs: dict):
|
||||
class ExecutionStatus(NamedTuple):
|
||||
status_str: Literal['success', 'error']
|
||||
completed: bool
|
||||
messages: List[str]
|
||||
|
||||
def task_done(self, item_id, outputs: dict,
|
||||
status: Optional['PromptQueue.ExecutionStatus']):
|
||||
with self.mutex:
|
||||
queue_item = self.currently_running.pop(item_id)
|
||||
prompt = queue_item.queue_tuple
|
||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||
self.history.pop(next(iter(self.history)))
|
||||
self.history[prompt[1]] = {"prompt": prompt, "outputs": {}, "timestamp": time.time()}
|
||||
for o in outputs:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
|
||||
status_dict: Optional[dict] = None
|
||||
if status is not None:
|
||||
status_dict = copy.deepcopy(status._asdict())
|
||||
|
||||
self.history[prompt[1]] = {
|
||||
"prompt": prompt,
|
||||
"outputs": copy.deepcopy(outputs),
|
||||
'status': status_dict,
|
||||
}
|
||||
self.server.queue_updated()
|
||||
if queue_item.completed:
|
||||
queue_item.completed.set_result(outputs)
|
||||
@ -880,3 +904,17 @@ class PromptQueue:
|
||||
def delete_history_item(self, id_to_delete: int):
|
||||
with self.mutex:
|
||||
self.history.pop(id_to_delete, None)
|
||||
|
||||
def set_flag(self, name, data):
|
||||
with self.mutex:
|
||||
self.flags[name] = data
|
||||
self.not_empty.notify()
|
||||
|
||||
def get_flags(self, reset=True):
|
||||
with self.mutex:
|
||||
if reset:
|
||||
ret = self.flags
|
||||
self.flags = {}
|
||||
return ret
|
||||
else:
|
||||
return self.flags.copy()
|
||||
|
||||
@ -31,11 +31,14 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
|
||||
|
||||
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
|
||||
|
||||
filename_list_cache = {}
|
||||
|
||||
@ -139,15 +142,27 @@ def recursive_search(directory, excluded_dir_names=None):
|
||||
excluded_dir_names = []
|
||||
|
||||
result = []
|
||||
dirs = {directory: os.path.getmtime(directory)}
|
||||
dirs = {}
|
||||
|
||||
# Attempt to add the initial directory to dirs with error handling
|
||||
try:
|
||||
dirs[directory] = os.path.getmtime(directory)
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Unable to access {directory}. Skipping this path.")
|
||||
|
||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||
for file_name in filenames:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
|
||||
for d in subdirs:
|
||||
path = os.path.join(dirpath, d)
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
try:
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
return result, dirs
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
|
||||
@ -89,7 +89,7 @@ def prompt_worker(q, _server):
|
||||
gc_collect_interval = 10.0
|
||||
current_time = 0.0
|
||||
while True:
|
||||
timeout = None
|
||||
timeout = 1000.0
|
||||
if need_gc:
|
||||
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
|
||||
|
||||
@ -102,14 +102,32 @@ def prompt_worker(q, _server):
|
||||
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
need_gc = True
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
q.task_done(item_id,
|
||||
e.outputs_ui,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages))
|
||||
if _server.client_id is not None:
|
||||
_server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)
|
||||
_server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
print("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
flags = q.get_flags()
|
||||
free_memory = flags.get("free_memory", False)
|
||||
|
||||
if flags.get("unload_models", free_memory):
|
||||
model_management.unload_all_models()
|
||||
need_gc = True
|
||||
last_gc_collect = 0
|
||||
|
||||
if free_memory:
|
||||
e.reset()
|
||||
need_gc = True
|
||||
last_gc_collect = 0
|
||||
|
||||
if need_gc:
|
||||
current_time = time.perf_counter()
|
||||
if (current_time - last_gc_collect) > gc_collect_interval:
|
||||
|
||||
@ -35,6 +35,7 @@ from ..vendor.appdirs import user_data_dir
|
||||
|
||||
nodes = import_all_nodes_in_workspace()
|
||||
|
||||
from ..app.user_manager import UserManager
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
@ -91,6 +92,7 @@ class PromptServer():
|
||||
mimetypes.init()
|
||||
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
||||
|
||||
self.user_manager = UserManager()
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = None
|
||||
self.loop = loop
|
||||
@ -532,6 +534,17 @@ class PromptServer():
|
||||
model_management.interrupt_current_processing()
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/free")
|
||||
async def post_free(request):
|
||||
json_data = await request.json()
|
||||
unload_models = json_data.get("unload_models", False)
|
||||
free_memory = json_data.get("free_memory", False)
|
||||
if unload_models:
|
||||
self.prompt_queue.set_flag("unload_models", unload_models)
|
||||
if free_memory:
|
||||
self.prompt_queue.set_flag("free_memory", free_memory)
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/history")
|
||||
async def post_history(request):
|
||||
json_data = await request.json()
|
||||
@ -671,6 +684,7 @@ class PromptServer():
|
||||
return web.json_response(prompt, status=200)
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.app.add_routes(self.routes)
|
||||
|
||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||
@ -768,14 +782,9 @@ class PromptServer():
|
||||
site = web.TCPSite(runner, address, port)
|
||||
await site.start()
|
||||
|
||||
address_to_print = 'localhost'
|
||||
if address == '' or address == '0.0.0.0':
|
||||
address = '0.0.0.0'
|
||||
else:
|
||||
address_to_print = address
|
||||
if verbose:
|
||||
print("Starting server\n")
|
||||
print("To see the GUI go to: http://{}:{}".format(address_to_print, port))
|
||||
print("To see the GUI go to: http://{}:{}".format(address, port))
|
||||
if call_on_start is not None:
|
||||
call_on_start(address, port)
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import enum
|
||||
import torch
|
||||
import math
|
||||
from . import utils
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import contextlib
|
||||
|
||||
from . import utils
|
||||
from . import model_management
|
||||
@ -127,7 +126,10 @@ class ControlBase:
|
||||
if o[i] is None:
|
||||
o[i] = prev_val
|
||||
else:
|
||||
o[i] += prev_val
|
||||
if o[i].shape[0] < prev_val.shape[0]:
|
||||
o[i] = prev_val + o[i]
|
||||
else:
|
||||
o[i] += prev_val
|
||||
return out
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
from functools import partial
|
||||
|
||||
|
||||
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
@ -176,6 +173,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
kv_chunk_size_min=kv_chunk_size_min,
|
||||
use_checkpoint=False,
|
||||
upcast_attention=upcast_attention,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
@ -238,6 +236,12 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
else:
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
||||
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
s1 += mask[i:end]
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
first_op_done = True
|
||||
@ -293,11 +297,14 @@ def attention_xformers(q, k, v, heads, mask=None):
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
if mask is not None:
|
||||
pad = 8 - q.shape[1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[:, :, :mask.shape[-1]] = mask
|
||||
mask = mask_out[:, :, :mask.shape[-1]]
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
@ -322,7 +329,6 @@ def attention_pytorch(q, k, v, heads, mask=None):
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
optimized_attention_masked = attention_basic
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
print("Using xformers cross attention")
|
||||
@ -338,15 +344,18 @@ else:
|
||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
if model_management.pytorch_attention_enabled():
|
||||
optimized_attention_masked = attention_pytorch
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
def optimized_attention_for_device(device, mask=False):
|
||||
if device == torch.device("cpu"): #TODO
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
if small_input:
|
||||
if model_management.pytorch_attention_enabled():
|
||||
return attention_pytorch
|
||||
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
|
||||
else:
|
||||
return attention_basic
|
||||
|
||||
if device == torch.device("cpu"):
|
||||
return attention_sub_quad
|
||||
|
||||
if mask:
|
||||
return optimized_attention_masked
|
||||
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from abc import abstractmethod
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
|
||||
from .util import (
|
||||
checkpoint,
|
||||
|
||||
@ -23,13 +23,13 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
|
||||
return x
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
def forward(self, x, noise_level=None, seed=None):
|
||||
if noise_level is None:
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
else:
|
||||
assert isinstance(noise_level, torch.Tensor)
|
||||
x = self.scale(x)
|
||||
z = self.q_sample(x, noise_level)
|
||||
z = self.q_sample(x, noise_level, seed=seed)
|
||||
z = self.unscale(z)
|
||||
noise_level = self.time_embed(noise_level)
|
||||
return z, noise_level
|
||||
|
||||
@ -61,6 +61,7 @@ def _summarize_chunk(
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
upcast_attention: bool,
|
||||
mask,
|
||||
) -> AttnChunk:
|
||||
if upcast_attention:
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
@ -84,6 +85,8 @@ def _summarize_chunk(
|
||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
attn_weights -= max_score
|
||||
if mask is not None:
|
||||
attn_weights += mask
|
||||
torch.exp(attn_weights, out=attn_weights)
|
||||
exp_weights = attn_weights.to(value.dtype)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
@ -96,11 +99,12 @@ def _query_chunk_attention(
|
||||
value: Tensor,
|
||||
summarize_chunk: SummarizeChunk,
|
||||
kv_chunk_size: int,
|
||||
mask,
|
||||
) -> Tensor:
|
||||
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
|
||||
_, _, v_channels_per_head = value.shape
|
||||
|
||||
def chunk_scanner(chunk_idx: int) -> AttnChunk:
|
||||
def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
|
||||
key_chunk = dynamic_slice(
|
||||
key_t,
|
||||
(0, 0, chunk_idx),
|
||||
@ -111,10 +115,13 @@ def _query_chunk_attention(
|
||||
(0, chunk_idx, 0),
|
||||
(batch_x_heads, kv_chunk_size, v_channels_per_head)
|
||||
)
|
||||
return summarize_chunk(query, key_chunk, value_chunk)
|
||||
if mask is not None:
|
||||
mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
|
||||
|
||||
return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
|
||||
|
||||
chunks: List[AttnChunk] = [
|
||||
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||
chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||
]
|
||||
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
|
||||
chunk_values, chunk_weights, chunk_max = acc_chunk
|
||||
@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
upcast_attention: bool,
|
||||
mask,
|
||||
) -> Tensor:
|
||||
if upcast_attention:
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
|
||||
beta=0,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
attn_scores += mask
|
||||
try:
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
@ -183,6 +193,7 @@ def efficient_dot_product_attention(
|
||||
kv_chunk_size_min: Optional[int] = None,
|
||||
use_checkpoint=True,
|
||||
upcast_attention=False,
|
||||
mask = None,
|
||||
):
|
||||
"""Computes efficient dot-product attention given query, transposed key, and value.
|
||||
This is efficient version of attention presented in
|
||||
@ -209,13 +220,22 @@ def efficient_dot_product_attention(
|
||||
if kv_chunk_size_min is not None:
|
||||
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
|
||||
|
||||
if mask is not None and len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
def get_query_chunk(chunk_idx: int) -> Tensor:
|
||||
return dynamic_slice(
|
||||
query,
|
||||
(0, chunk_idx, 0),
|
||||
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
|
||||
)
|
||||
|
||||
|
||||
def get_mask_chunk(chunk_idx: int) -> Tensor:
|
||||
if mask is None:
|
||||
return None
|
||||
chunk = min(query_chunk_size, q_tokens)
|
||||
return mask[:,chunk_idx:chunk_idx + chunk]
|
||||
|
||||
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
|
||||
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
|
||||
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
|
||||
@ -237,6 +257,7 @@ def efficient_dot_product_attention(
|
||||
query=query,
|
||||
key_t=key_t,
|
||||
value=value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
||||
@ -246,6 +267,7 @@ def efficient_dot_product_attention(
|
||||
query=get_query_chunk(i * query_chunk_size),
|
||||
key_t=key_t,
|
||||
value=value,
|
||||
mask=get_mask_chunk(i * query_chunk_size)
|
||||
) for i in range(math.ceil(q_tokens / query_chunk_size))
|
||||
], dim=1)
|
||||
return res
|
||||
|
||||
@ -6,7 +6,6 @@ from . import model_management
|
||||
from . import conds
|
||||
from . import ops
|
||||
from enum import Enum
|
||||
import contextlib
|
||||
from . import utils
|
||||
|
||||
class ModelType(Enum):
|
||||
@ -100,11 +99,29 @@ class BaseModel(torch.nn.Module):
|
||||
if self.inpaint_model:
|
||||
concat_keys = ("mask", "masked_image")
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("denoise_mask", None)
|
||||
latent_image = kwargs.get("latent_image", None)
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
concat_latent_image = kwargs.get("concat_latent_image", None)
|
||||
if concat_latent_image is None:
|
||||
concat_latent_image = kwargs.get("latent_image", None)
|
||||
else:
|
||||
concat_latent_image = self.process_latent_in(concat_latent_image)
|
||||
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if concat_latent_image.shape[1:] != noise.shape[1:]:
|
||||
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
|
||||
|
||||
if len(denoise_mask.shape) == len(noise.shape):
|
||||
denoise_mask = denoise_mask[:,:1]
|
||||
|
||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
@ -117,9 +134,9 @@ class BaseModel(torch.nn.Module):
|
||||
for ck in concat_keys:
|
||||
if denoise_mask is not None:
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask[:,:1].to(device))
|
||||
cond_concat.append(denoise_mask.to(device))
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||
else:
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
@ -159,19 +176,28 @@ class BaseModel(torch.nn.Module):
|
||||
def process_latent_out(self, latent):
|
||||
return self.latent_format.process_out(latent)
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
||||
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
extra_sds = []
|
||||
if clip_state_dict is not None:
|
||||
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
|
||||
if vae_state_dict is not None:
|
||||
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
|
||||
if clip_vision_state_dict is not None:
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
||||
|
||||
if self.get_dtype() == torch.float16:
|
||||
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
|
||||
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
|
||||
extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
unet_state_dict["v_pred"] = torch.tensor([])
|
||||
|
||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
||||
for sd in extra_sds:
|
||||
unet_state_dict.update(sd)
|
||||
|
||||
return unet_state_dict
|
||||
|
||||
def set_inpaint(self):
|
||||
self.inpaint_model = True
|
||||
@ -190,7 +216,7 @@ class BaseModel(torch.nn.Module):
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
|
||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
|
||||
adm_inputs = []
|
||||
weights = []
|
||||
noise_aug = []
|
||||
@ -199,7 +225,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
|
||||
weight = unclip_cond["strength"]
|
||||
noise_augment = unclip_cond["noise_augmentation"]
|
||||
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
|
||||
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
|
||||
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
|
||||
weights.append(weight)
|
||||
noise_aug.append(noise_augment)
|
||||
@ -225,11 +251,11 @@ class SD21UNCLIP(BaseModel):
|
||||
if unclip_conditioning is None:
|
||||
return torch.zeros((1, self.adm_channels))
|
||||
else:
|
||||
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
|
||||
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
|
||||
|
||||
def sdxl_pooled(args, noise_augmentor):
|
||||
if "unclip_conditioning" in args:
|
||||
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
|
||||
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
|
||||
else:
|
||||
return args["pooled_output"]
|
||||
|
||||
|
||||
@ -176,7 +176,7 @@ try:
|
||||
if int(torch_version[0]) >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch.cuda.is_bf16_supported():
|
||||
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
||||
VAE_DTYPE = torch.bfloat16
|
||||
if is_intel_xpu():
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
|
||||
@ -166,6 +166,26 @@ class ConditioningSetAreaPercentage:
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class ConditioningSetAreaStrength:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, strength):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['strength'] = strength
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
|
||||
class ConditioningSetMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -341,6 +361,62 @@ class VAEEncodeForInpaint:
|
||||
|
||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||
|
||||
|
||||
class InpaintModelConditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"pixels": ("IMAGE", ),
|
||||
"mask": ("MASK", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, positive, negative, pixels, vae, mask):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
|
||||
|
||||
orig_pixels = pixels
|
||||
pixels = orig_pixels.clone()
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
x_offset = (pixels.shape[1] % 8) // 2
|
||||
y_offset = (pixels.shape[2] % 8) // 2
|
||||
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
|
||||
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
|
||||
|
||||
m = (1.0 - mask.round()).squeeze(1)
|
||||
for i in range(3):
|
||||
pixels[:,:,:,i] -= 0.5
|
||||
pixels[:,:,:,i] *= m
|
||||
pixels[:,:,:,i] += 0.5
|
||||
concat_latent = vae.encode(pixels)
|
||||
orig_latent = vae.encode(orig_pixels)
|
||||
|
||||
out_latent = {}
|
||||
|
||||
out_latent["samples"] = orig_latent
|
||||
out_latent["noise_mask"] = mask
|
||||
|
||||
out = []
|
||||
for conditioning in [positive, negative]:
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d["concat_latent_image"] = concat_latent
|
||||
d["concat_mask"] = mask
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
out.append(c)
|
||||
return (out[0], out[1], out_latent)
|
||||
|
||||
|
||||
class SaveLatent:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@ -1400,6 +1476,8 @@ class LoadImage:
|
||||
output_masks = []
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = ImageOps.exif_transpose(i)
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
@ -1455,6 +1533,8 @@ class LoadImageMask:
|
||||
i = Image.open(image_path)
|
||||
i = ImageOps.exif_transpose(i)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
i = i.convert("RGBA")
|
||||
mask = None
|
||||
c = channel[0].upper()
|
||||
@ -1609,10 +1689,11 @@ class ImagePadForOutpaint:
|
||||
def expand_image(self, image, left, top, right, bottom, feathering):
|
||||
d1, d2, d3, d4 = image.size()
|
||||
|
||||
new_image = torch.zeros(
|
||||
new_image = torch.ones(
|
||||
(d1, d2 + top + bottom, d3 + left + right, d4),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
) * 0.5
|
||||
|
||||
new_image[:, top:top + d2, left:left + d3, :] = image
|
||||
|
||||
mask = torch.ones(
|
||||
@ -1678,6 +1759,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ConditioningConcat": ConditioningConcat,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
|
||||
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
|
||||
"ConditioningSetMask": ConditioningSetMask,
|
||||
"KSamplerAdvanced": KSamplerAdvanced,
|
||||
"SetLatentNoiseMask": SetLatentNoiseMask,
|
||||
@ -1704,6 +1786,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
|
||||
"GLIGENLoader": GLIGENLoader,
|
||||
"GLIGENTextBoxApply": GLIGENTextBoxApply,
|
||||
"InpaintModelConditioning": InpaintModelConditioning,
|
||||
|
||||
"CheckpointLoader": CheckpointLoader,
|
||||
"DiffusersLoader": DiffusersLoader,
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
import comfy.model_management
|
||||
from . import model_management
|
||||
|
||||
def cast_bias_weight(s, input):
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||
non_blocking = model_management.device_supports_non_blocking(input.device)
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
|
||||
@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
|
||||
13
comfy/sd.py
13
comfy/sd.py
@ -531,7 +531,14 @@ def load_unet(unet_path):
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||
model_management.load_models_gpu([model, clip.load_model()])
|
||||
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
model_management.load_models_gpu(load_models)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||
|
||||
@ -7,7 +7,6 @@ import traceback
|
||||
import zipfile
|
||||
from . import model_management
|
||||
from pkg_resources import resource_filename
|
||||
import contextlib
|
||||
from . import clip_model
|
||||
import json
|
||||
|
||||
|
||||
@ -65,6 +65,12 @@ class BASE:
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_clip_vision_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
if self.clip_vision_prefix is not None:
|
||||
replace_prefix[""] = self.clip_vision_prefix
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model.diffusion_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
import { mergeIfValid } from "./widgetInputs.js";
|
||||
import { ManageGroupDialog } from "./groupNodeManage.js";
|
||||
|
||||
const GROUP = Symbol();
|
||||
|
||||
@ -61,11 +62,7 @@ class GroupNodeBuilder {
|
||||
);
|
||||
return;
|
||||
case Workflow.InUse.Registered:
|
||||
if (
|
||||
!confirm(
|
||||
"An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?"
|
||||
)
|
||||
) {
|
||||
if (!confirm("A group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?")) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
@ -151,6 +148,8 @@ export class GroupNodeConfig {
|
||||
this.primitiveDefs = {};
|
||||
this.widgetToPrimitive = {};
|
||||
this.primitiveToWidget = {};
|
||||
this.nodeInputs = {};
|
||||
this.outputVisibility = [];
|
||||
}
|
||||
|
||||
async registerType(source = "workflow") {
|
||||
@ -158,6 +157,7 @@ export class GroupNodeConfig {
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
output_is_hidden: [],
|
||||
name: source + "/" + this.name,
|
||||
display_name: this.name,
|
||||
category: "group nodes" + ("/" + source),
|
||||
@ -277,8 +277,7 @@ export class GroupNodeConfig {
|
||||
}
|
||||
if (input.widget) {
|
||||
const targetDef = globalDefs[node.type];
|
||||
const targetWidget =
|
||||
targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
|
||||
const targetWidget = targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
|
||||
|
||||
const widget = [targetWidget[0], config];
|
||||
const res = mergeIfValid(
|
||||
@ -330,7 +329,8 @@ export class GroupNodeConfig {
|
||||
}
|
||||
|
||||
getInputConfig(node, inputName, seenInputs, config, extra) {
|
||||
let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
|
||||
const customConfig = this.nodeData.config?.[node.index]?.input?.[inputName];
|
||||
let name = customConfig?.name ?? node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
|
||||
let key = name;
|
||||
let prefix = "";
|
||||
// Special handling for primitive to include the title if it is set rather than just "value"
|
||||
@ -349,14 +349,14 @@ export class GroupNodeConfig {
|
||||
}
|
||||
if (config[0] === "IMAGEUPLOAD") {
|
||||
if (!extra) extra = {};
|
||||
extra.widget = `${prefix}${config[1]?.widget ?? "image"}`;
|
||||
extra.widget = this.oldToNewWidgetMap[node.index]?.[config[1]?.widget ?? "image"] ?? "image";
|
||||
}
|
||||
|
||||
if (extra) {
|
||||
config = [config[0], { ...config[1], ...extra }];
|
||||
}
|
||||
|
||||
return { name, config };
|
||||
return { name, config, customConfig };
|
||||
}
|
||||
|
||||
processWidgetInputs(inputs, node, inputNames, seenInputs) {
|
||||
@ -366,9 +366,7 @@ export class GroupNodeConfig {
|
||||
for (const inputName of inputNames) {
|
||||
let widgetType = app.getWidgetType(inputs[inputName], inputName);
|
||||
if (widgetType) {
|
||||
const convertedIndex = node.inputs?.findIndex(
|
||||
(inp) => inp.name === inputName && inp.widget?.name === inputName
|
||||
);
|
||||
const convertedIndex = node.inputs?.findIndex((inp) => inp.name === inputName && inp.widget?.name === inputName);
|
||||
if (convertedIndex > -1) {
|
||||
// This widget has been converted to a widget
|
||||
// We need to store this in the correct position so link ids line up
|
||||
@ -424,6 +422,7 @@ export class GroupNodeConfig {
|
||||
}
|
||||
|
||||
processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) {
|
||||
this.nodeInputs[node.index] = {};
|
||||
for (let i = 0; i < slots.length; i++) {
|
||||
const inputName = slots[i];
|
||||
if (linksTo[i]) {
|
||||
@ -432,7 +431,11 @@ export class GroupNodeConfig {
|
||||
continue;
|
||||
}
|
||||
|
||||
const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]);
|
||||
const { name, config, customConfig } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]);
|
||||
|
||||
this.nodeInputs[node.index][inputName] = name;
|
||||
if(customConfig?.visible === false) continue;
|
||||
|
||||
this.nodeDef.input.required[name] = config;
|
||||
inputMap[i] = this.inputCount++;
|
||||
}
|
||||
@ -452,6 +455,7 @@ export class GroupNodeConfig {
|
||||
const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], {
|
||||
defaultInput: true,
|
||||
});
|
||||
|
||||
this.nodeDef.input.required[name] = config;
|
||||
this.newToOldWidgetMap[name] = { node, inputName };
|
||||
|
||||
@ -477,9 +481,7 @@ export class GroupNodeConfig {
|
||||
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
|
||||
|
||||
// Converted inputs have to be processed after all other nodes as they'll be at the end of the list
|
||||
this.#convertedToProcess.push(() =>
|
||||
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
|
||||
);
|
||||
this.#convertedToProcess.push(() => this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs));
|
||||
|
||||
return inputMapping;
|
||||
}
|
||||
@ -490,8 +492,12 @@ export class GroupNodeConfig {
|
||||
// Add outputs
|
||||
for (let outputId = 0; outputId < def.output.length; outputId++) {
|
||||
const linksFrom = this.linksFrom[node.index];
|
||||
if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) {
|
||||
// This output is linked internally so we can skip it
|
||||
// If this output is linked internally we flag it to hide
|
||||
const hasLink = linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId];
|
||||
const customConfig = this.nodeData.config?.[node.index]?.output?.[outputId];
|
||||
const visible = customConfig?.visible ?? !hasLink;
|
||||
this.outputVisibility.push(visible);
|
||||
if (!visible) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -500,11 +506,15 @@ export class GroupNodeConfig {
|
||||
this.nodeDef.output.push(def.output[outputId]);
|
||||
this.nodeDef.output_is_list.push(def.output_is_list[outputId]);
|
||||
|
||||
let label = def.output_name?.[outputId] ?? def.output[outputId];
|
||||
const output = node.outputs.find((o) => o.name === label);
|
||||
if (output?.label) {
|
||||
label = output.label;
|
||||
let label = customConfig?.name;
|
||||
if (!label) {
|
||||
label = def.output_name?.[outputId] ?? def.output[outputId];
|
||||
const output = node.outputs.find((o) => o.name === label);
|
||||
if (output?.label) {
|
||||
label = output.label;
|
||||
}
|
||||
}
|
||||
|
||||
let name = label;
|
||||
if (name in seenOutputs) {
|
||||
const prefix = `${node.title ?? node.type} `;
|
||||
@ -677,6 +687,25 @@ export class GroupNodeHandler {
|
||||
return this.innerNodes;
|
||||
};
|
||||
|
||||
this.node.recreate = async () => {
|
||||
const id = this.node.id;
|
||||
const sz = this.node.size;
|
||||
const nodes = this.node.convertToNodes();
|
||||
|
||||
const groupNode = LiteGraph.createNode(this.node.type);
|
||||
groupNode.id = id;
|
||||
|
||||
// Reuse the existing nodes for this instance
|
||||
groupNode.setInnerNodes(nodes);
|
||||
groupNode[GROUP].populateWidgets();
|
||||
app.graph.add(groupNode);
|
||||
groupNode.size = [Math.max(groupNode.size[0], sz[0]), Math.max(groupNode.size[1], sz[1])];
|
||||
|
||||
// Remove all converted nodes and relink them
|
||||
groupNode[GROUP].replaceNodes(nodes);
|
||||
return groupNode;
|
||||
};
|
||||
|
||||
this.node.convertToNodes = () => {
|
||||
const addInnerNodes = () => {
|
||||
const backup = localStorage.getItem("litegrapheditor_clipboard");
|
||||
@ -769,6 +798,7 @@ export class GroupNodeHandler {
|
||||
const slot = node.inputs[groupSlotId];
|
||||
if (slot.link == null) continue;
|
||||
const link = app.graph.links[slot.link];
|
||||
if (!link) continue;
|
||||
// connect this node output to the input of another node
|
||||
const originNode = app.graph.getNodeById(link.origin_id);
|
||||
originNode.connect(link.origin_slot, newNode, +innerInputId);
|
||||
@ -806,12 +836,23 @@ export class GroupNodeHandler {
|
||||
let optionIndex = options.findIndex((o) => o.content === "Outputs");
|
||||
if (optionIndex === -1) optionIndex = options.length;
|
||||
else optionIndex++;
|
||||
options.splice(optionIndex, 0, null, {
|
||||
content: "Convert to nodes",
|
||||
callback: () => {
|
||||
return this.convertToNodes();
|
||||
options.splice(
|
||||
optionIndex,
|
||||
0,
|
||||
null,
|
||||
{
|
||||
content: "Convert to nodes",
|
||||
callback: () => {
|
||||
return this.convertToNodes();
|
||||
},
|
||||
},
|
||||
});
|
||||
{
|
||||
content: "Manage Group Node",
|
||||
callback: () => {
|
||||
new ManageGroupDialog(app).show(this.type);
|
||||
},
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
// Draw custom collapse icon to identity this as a group
|
||||
@ -843,6 +884,7 @@ export class GroupNodeHandler {
|
||||
const r = onDrawForeground?.apply?.(this, arguments);
|
||||
if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) {
|
||||
const n = groupData.nodes[this.runningInternalNodeId];
|
||||
if(!n) return;
|
||||
const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`;
|
||||
ctx.save();
|
||||
ctx.font = "12px sans-serif";
|
||||
@ -865,6 +907,28 @@ export class GroupNodeHandler {
|
||||
return onExecutionStart?.apply(this, arguments);
|
||||
};
|
||||
|
||||
const self = this;
|
||||
const onNodeCreated = this.node.onNodeCreated;
|
||||
this.node.onNodeCreated = function () {
|
||||
const config = self.groupData.nodeData.config;
|
||||
if (config) {
|
||||
for (const n in config) {
|
||||
const inputs = config[n]?.input;
|
||||
for (const w in inputs) {
|
||||
if (inputs[w].visible !== false) continue;
|
||||
const widgetName = self.groupData.oldToNewWidgetMap[n][w];
|
||||
const widget = this.widgets.find((w) => w.name === widgetName);
|
||||
if (widget) {
|
||||
widget.type = "hidden";
|
||||
widget.computeSize = () => [0, -4];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return onNodeCreated?.apply(this, arguments);
|
||||
};
|
||||
|
||||
function handleEvent(type, getId, getEvent) {
|
||||
const handler = ({ detail }) => {
|
||||
const id = getId(detail);
|
||||
@ -902,6 +966,26 @@ export class GroupNodeHandler {
|
||||
api.removeEventListener("executing", executing);
|
||||
api.removeEventListener("executed", executed);
|
||||
};
|
||||
|
||||
this.node.refreshComboInNode = (defs) => {
|
||||
// Update combo widget options
|
||||
for (const widgetName in this.groupData.newToOldWidgetMap) {
|
||||
const widget = this.node.widgets.find((w) => w.name === widgetName);
|
||||
if (widget?.type === "combo") {
|
||||
const old = this.groupData.newToOldWidgetMap[widgetName];
|
||||
const def = defs[old.node.type];
|
||||
const input = def?.input?.required?.[old.inputName] ?? def?.input?.optional?.[old.inputName];
|
||||
if (!input) continue;
|
||||
|
||||
widget.options.values = input[0];
|
||||
|
||||
if (old.inputName !== "image" && !widget.options.values.includes(widget.value)) {
|
||||
widget.value = widget.options.values[0];
|
||||
widget.callback(widget.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
updateInnerWidgets() {
|
||||
@ -927,13 +1011,15 @@ export class GroupNodeHandler {
|
||||
continue;
|
||||
} else if (innerNode.type === "Reroute") {
|
||||
const rerouteLinks = this.groupData.linksFrom[old.node.index];
|
||||
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
|
||||
const node = this.innerNodes[targetNodeId];
|
||||
const input = node.inputs[targetSlot];
|
||||
if (input.widget) {
|
||||
const widget = node.widgets?.find((w) => w.name === input.widget.name);
|
||||
if (widget) {
|
||||
widget.value = newValue;
|
||||
if (rerouteLinks) {
|
||||
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
|
||||
const node = this.innerNodes[targetNodeId];
|
||||
const input = node.inputs[targetSlot];
|
||||
if (input.widget) {
|
||||
const widget = node.widgets?.find((w) => w.name === input.widget.name);
|
||||
if (widget) {
|
||||
widget.value = newValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -975,7 +1061,7 @@ export class GroupNodeHandler {
|
||||
const [, , targetNodeId, targetNodeSlot] = link;
|
||||
const targetNode = this.groupData.nodeData.nodes[targetNodeId];
|
||||
const inputs = targetNode.inputs;
|
||||
const targetWidget = inputs?.[targetNodeSlot].widget;
|
||||
const targetWidget = inputs?.[targetNodeSlot]?.widget;
|
||||
if (!targetWidget) return;
|
||||
|
||||
const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
|
||||
@ -983,13 +1069,12 @@ export class GroupNodeHandler {
|
||||
if (v == null) return;
|
||||
|
||||
const widgetName = Object.values(map)[0];
|
||||
const widget = this.node.widgets.find(w => w.name === widgetName);
|
||||
if(widget) {
|
||||
const widget = this.node.widgets.find((w) => w.name === widgetName);
|
||||
if (widget) {
|
||||
widget.value = v;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
populateWidgets() {
|
||||
if (!this.node.widgets) return;
|
||||
|
||||
@ -1080,7 +1165,7 @@ export class GroupNodeHandler {
|
||||
}
|
||||
|
||||
static getGroupData(node) {
|
||||
return node.constructor?.nodeData?.[GROUP];
|
||||
return (node.nodeData ?? node.constructor?.nodeData)?.[GROUP];
|
||||
}
|
||||
|
||||
static isGroupNode(node) {
|
||||
@ -1112,7 +1197,7 @@ export class GroupNodeHandler {
|
||||
}
|
||||
|
||||
function addConvertToGroupOptions() {
|
||||
function addOption(options, index) {
|
||||
function addConvertOption(options, index) {
|
||||
const selected = Object.values(app.canvas.selected_nodes ?? {});
|
||||
const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n));
|
||||
options.splice(index + 1, null, {
|
||||
@ -1124,12 +1209,25 @@ function addConvertToGroupOptions() {
|
||||
});
|
||||
}
|
||||
|
||||
function addManageOption(options, index) {
|
||||
const groups = app.graph.extra?.groupNodes;
|
||||
const disabled = !groups || !Object.keys(groups).length;
|
||||
options.splice(index + 1, null, {
|
||||
content: `Manage Group Nodes`,
|
||||
disabled,
|
||||
callback: () => {
|
||||
new ManageGroupDialog(app).show();
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Add to canvas
|
||||
const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions;
|
||||
LGraphCanvas.prototype.getCanvasMenuOptions = function () {
|
||||
const options = getCanvasMenuOptions.apply(this, arguments);
|
||||
const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length;
|
||||
addOption(options, index);
|
||||
addConvertOption(options, index);
|
||||
addManageOption(options, index + 1);
|
||||
return options;
|
||||
};
|
||||
|
||||
@ -1139,7 +1237,7 @@ function addConvertToGroupOptions() {
|
||||
const options = getNodeMenuOptions.apply(this, arguments);
|
||||
if (!GroupNodeHandler.isGroupNode(node)) {
|
||||
const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1;
|
||||
addOption(options, index);
|
||||
addConvertOption(options, index);
|
||||
}
|
||||
return options;
|
||||
};
|
||||
@ -1167,6 +1265,14 @@ const ext = {
|
||||
node[GROUP] = new GroupNodeHandler(node);
|
||||
}
|
||||
},
|
||||
async refreshComboInNodes(defs) {
|
||||
// Re-register group nodes so new ones are created with the correct options
|
||||
Object.assign(globalDefs, defs);
|
||||
const nodes = app.graph.extra?.groupNodes;
|
||||
if (nodes) {
|
||||
await GroupNodeConfig.registerFromWorkflow(nodes, {});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
app.registerExtension(ext);
|
||||
|
||||
149
comfy/web/extensions/core/groupNodeManage.css
Normal file
149
comfy/web/extensions/core/groupNodeManage.css
Normal file
@ -0,0 +1,149 @@
|
||||
.comfy-group-manage {
|
||||
background: var(--bg-color);
|
||||
color: var(--fg-color);
|
||||
padding: 0;
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
border-color: black;
|
||||
margin: 20vh auto;
|
||||
max-height: 60vh;
|
||||
}
|
||||
.comfy-group-manage-outer {
|
||||
max-height: 60vh;
|
||||
min-width: 500px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.comfy-group-manage-outer > header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
justify-content: space-between;
|
||||
background: var(--comfy-menu-bg);
|
||||
padding: 15px 20px;
|
||||
}
|
||||
.comfy-group-manage-outer > header select {
|
||||
background: var(--comfy-input-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--input-text);
|
||||
padding: 5px 10px;
|
||||
border-radius: 5px;
|
||||
}
|
||||
.comfy-group-manage h2 {
|
||||
margin: 0;
|
||||
font-weight: normal;
|
||||
}
|
||||
.comfy-group-manage main {
|
||||
display: flex;
|
||||
overflow: hidden;
|
||||
}
|
||||
.comfy-group-manage .drag-handle {
|
||||
font-weight: bold;
|
||||
}
|
||||
.comfy-group-manage-list {
|
||||
border-right: 1px solid var(--comfy-menu-bg);
|
||||
}
|
||||
.comfy-group-manage-list ul {
|
||||
margin: 40px 0 0;
|
||||
padding: 0;
|
||||
list-style: none;
|
||||
}
|
||||
.comfy-group-manage-list-items {
|
||||
max-height: 70vh;
|
||||
overflow-y: scroll;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
.comfy-group-manage-list li {
|
||||
display: flex;
|
||||
padding: 10px 20px 10px 10px;
|
||||
cursor: pointer;
|
||||
align-items: center;
|
||||
gap: 5px;
|
||||
}
|
||||
.comfy-group-manage-list div {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.comfy-group-manage-list li:not(.selected):hover div {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.comfy-group-manage-list li.selected {
|
||||
background: var(--border-color);
|
||||
}
|
||||
.comfy-group-manage-list li span {
|
||||
opacity: 0.7;
|
||||
font-size: smaller;
|
||||
}
|
||||
.comfy-group-manage-node {
|
||||
flex: auto;
|
||||
background: var(--border-color);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.comfy-group-manage-node > div {
|
||||
overflow: auto;
|
||||
}
|
||||
.comfy-group-manage-node header {
|
||||
display: flex;
|
||||
background: var(--bg-color);
|
||||
height: 40px;
|
||||
}
|
||||
.comfy-group-manage-node header a {
|
||||
text-align: center;
|
||||
flex: auto;
|
||||
border-right: 1px solid var(--comfy-menu-bg);
|
||||
border-bottom: 1px solid var(--comfy-menu-bg);
|
||||
padding: 10px;
|
||||
cursor: pointer;
|
||||
font-size: 15px;
|
||||
}
|
||||
.comfy-group-manage-node header a:last-child {
|
||||
border-right: none;
|
||||
}
|
||||
.comfy-group-manage-node header a:not(.active):hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.comfy-group-manage-node header a.active {
|
||||
background: var(--border-color);
|
||||
border-bottom: none;
|
||||
}
|
||||
.comfy-group-manage-node-page {
|
||||
display: none;
|
||||
overflow: auto;
|
||||
}
|
||||
.comfy-group-manage-node-page.active {
|
||||
display: block;
|
||||
}
|
||||
.comfy-group-manage-node-page div {
|
||||
padding: 10px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
.comfy-group-manage-node-page input {
|
||||
border: none;
|
||||
color: var(--input-text);
|
||||
background: var(--comfy-input-bg);
|
||||
padding: 5px 10px;
|
||||
}
|
||||
.comfy-group-manage-node-page input[type="text"] {
|
||||
flex: auto;
|
||||
}
|
||||
.comfy-group-manage-node-page label {
|
||||
display: flex;
|
||||
gap: 5px;
|
||||
align-items: center;
|
||||
}
|
||||
.comfy-group-manage footer {
|
||||
border-top: 1px solid var(--comfy-menu-bg);
|
||||
padding: 10px;
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
}
|
||||
.comfy-group-manage footer button {
|
||||
font-size: 14px;
|
||||
padding: 5px 10px;
|
||||
border-radius: 0;
|
||||
}
|
||||
.comfy-group-manage footer button:first-child {
|
||||
margin-right: auto;
|
||||
}
|
||||
422
comfy/web/extensions/core/groupNodeManage.js
Normal file
422
comfy/web/extensions/core/groupNodeManage.js
Normal file
@ -0,0 +1,422 @@
|
||||
import { $el, ComfyDialog } from "../../scripts/ui.js";
|
||||
import { DraggableList } from "../../scripts/ui/draggableList.js";
|
||||
import { addStylesheet } from "../../scripts/utils.js";
|
||||
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
|
||||
|
||||
addStylesheet(import.meta.url);
|
||||
|
||||
const ORDER = Symbol();
|
||||
|
||||
function merge(target, source) {
|
||||
if (typeof target === "object" && typeof source === "object") {
|
||||
for (const key in source) {
|
||||
const sv = source[key];
|
||||
if (typeof sv === "object") {
|
||||
let tv = target[key];
|
||||
if (!tv) tv = target[key] = {};
|
||||
merge(tv, source[key]);
|
||||
} else {
|
||||
target[key] = sv;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return target;
|
||||
}
|
||||
|
||||
export class ManageGroupDialog extends ComfyDialog {
|
||||
/** @type { Record<"Inputs" | "Outputs" | "Widgets", {tab: HTMLAnchorElement, page: HTMLElement}> } */
|
||||
tabs = {};
|
||||
/** @type { number | null | undefined } */
|
||||
selectedNodeIndex;
|
||||
/** @type { keyof ManageGroupDialog["tabs"] } */
|
||||
selectedTab = "Inputs";
|
||||
/** @type { string | undefined } */
|
||||
selectedGroup;
|
||||
|
||||
/** @type { Record<string, Record<string, Record<string, { name?: string | undefined, visible?: boolean | undefined }>>> } */
|
||||
modifications = {};
|
||||
|
||||
get selectedNodeInnerIndex() {
|
||||
return +this.nodeItems[this.selectedNodeIndex].dataset.nodeindex;
|
||||
}
|
||||
|
||||
constructor(app) {
|
||||
super();
|
||||
this.app = app;
|
||||
this.element = $el("dialog.comfy-group-manage", {
|
||||
parent: document.body,
|
||||
});
|
||||
}
|
||||
|
||||
changeTab(tab) {
|
||||
this.tabs[this.selectedTab].tab.classList.remove("active");
|
||||
this.tabs[this.selectedTab].page.classList.remove("active");
|
||||
this.tabs[tab].tab.classList.add("active");
|
||||
this.tabs[tab].page.classList.add("active");
|
||||
this.selectedTab = tab;
|
||||
}
|
||||
|
||||
changeNode(index, force) {
|
||||
if (!force && this.selectedNodeIndex === index) return;
|
||||
|
||||
if (this.selectedNodeIndex != null) {
|
||||
this.nodeItems[this.selectedNodeIndex].classList.remove("selected");
|
||||
}
|
||||
this.nodeItems[index].classList.add("selected");
|
||||
this.selectedNodeIndex = index;
|
||||
|
||||
if (!this.buildInputsPage() && this.selectedTab === "Inputs") {
|
||||
this.changeTab("Widgets");
|
||||
}
|
||||
if (!this.buildWidgetsPage() && this.selectedTab === "Widgets") {
|
||||
this.changeTab("Outputs");
|
||||
}
|
||||
if (!this.buildOutputsPage() && this.selectedTab === "Outputs") {
|
||||
this.changeTab("Inputs");
|
||||
}
|
||||
|
||||
this.changeTab(this.selectedTab);
|
||||
}
|
||||
|
||||
getGroupData() {
|
||||
this.groupNodeType = LiteGraph.registered_node_types["workflow/" + this.selectedGroup];
|
||||
this.groupNodeDef = this.groupNodeType.nodeData;
|
||||
this.groupData = GroupNodeHandler.getGroupData(this.groupNodeType);
|
||||
}
|
||||
|
||||
changeGroup(group, reset = true) {
|
||||
this.selectedGroup = group;
|
||||
this.getGroupData();
|
||||
|
||||
const nodes = this.groupData.nodeData.nodes;
|
||||
this.nodeItems = nodes.map((n, i) =>
|
||||
$el(
|
||||
"li.draggable-item",
|
||||
{
|
||||
dataset: {
|
||||
nodeindex: n.index + "",
|
||||
},
|
||||
onclick: () => {
|
||||
this.changeNode(i);
|
||||
},
|
||||
},
|
||||
[
|
||||
$el("span.drag-handle"),
|
||||
$el(
|
||||
"div",
|
||||
{
|
||||
textContent: n.title ?? n.type,
|
||||
},
|
||||
n.title
|
||||
? $el("span", {
|
||||
textContent: n.type,
|
||||
})
|
||||
: []
|
||||
),
|
||||
]
|
||||
)
|
||||
);
|
||||
|
||||
this.innerNodesList.replaceChildren(...this.nodeItems);
|
||||
|
||||
if (reset) {
|
||||
this.selectedNodeIndex = null;
|
||||
this.changeNode(0);
|
||||
} else {
|
||||
const items = this.draggable.getAllItems();
|
||||
let index = items.findIndex(item => item.classList.contains("selected"));
|
||||
if(index === -1) index = this.selectedNodeIndex;
|
||||
this.changeNode(index, true);
|
||||
}
|
||||
|
||||
const ordered = [...nodes];
|
||||
this.draggable?.dispose();
|
||||
this.draggable = new DraggableList(this.innerNodesList, "li");
|
||||
this.draggable.addEventListener("dragend", ({ detail: { oldPosition, newPosition } }) => {
|
||||
if (oldPosition === newPosition) return;
|
||||
ordered.splice(newPosition, 0, ordered.splice(oldPosition, 1)[0]);
|
||||
for (let i = 0; i < ordered.length; i++) {
|
||||
this.storeModification({ nodeIndex: ordered[i].index, section: ORDER, prop: "order", value: i });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
storeModification({ nodeIndex, section, prop, value }) {
|
||||
const groupMod = (this.modifications[this.selectedGroup] ??= {});
|
||||
const nodesMod = (groupMod.nodes ??= {});
|
||||
const nodeMod = (nodesMod[nodeIndex ?? this.selectedNodeInnerIndex] ??= {});
|
||||
const typeMod = (nodeMod[section] ??= {});
|
||||
if (typeof value === "object") {
|
||||
const objMod = (typeMod[prop] ??= {});
|
||||
Object.assign(objMod, value);
|
||||
} else {
|
||||
typeMod[prop] = value;
|
||||
}
|
||||
}
|
||||
|
||||
getEditElement(section, prop, value, placeholder, checked, checkable = true) {
|
||||
if (value === placeholder) value = "";
|
||||
|
||||
const mods = this.modifications[this.selectedGroup]?.nodes?.[this.selectedNodeInnerIndex]?.[section]?.[prop];
|
||||
if (mods) {
|
||||
if (mods.name != null) {
|
||||
value = mods.name;
|
||||
}
|
||||
if (mods.visible != null) {
|
||||
checked = mods.visible;
|
||||
}
|
||||
}
|
||||
|
||||
return $el("div", [
|
||||
$el("input", {
|
||||
value,
|
||||
placeholder,
|
||||
type: "text",
|
||||
onchange: (e) => {
|
||||
this.storeModification({ section, prop, value: { name: e.target.value } });
|
||||
},
|
||||
}),
|
||||
$el("label", { textContent: "Visible" }, [
|
||||
$el("input", {
|
||||
type: "checkbox",
|
||||
checked,
|
||||
disabled: !checkable,
|
||||
onchange: (e) => {
|
||||
this.storeModification({ section, prop, value: { visible: !!e.target.checked } });
|
||||
},
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
}
|
||||
|
||||
buildWidgetsPage() {
|
||||
const widgets = this.groupData.oldToNewWidgetMap[this.selectedNodeInnerIndex];
|
||||
const items = Object.keys(widgets ?? {});
|
||||
const type = app.graph.extra.groupNodes[this.selectedGroup];
|
||||
const config = type.config?.[this.selectedNodeInnerIndex]?.input;
|
||||
this.widgetsPage.replaceChildren(
|
||||
...items.map((oldName) => {
|
||||
return this.getEditElement("input", oldName, widgets[oldName], oldName, config?.[oldName]?.visible !== false);
|
||||
})
|
||||
);
|
||||
return !!items.length;
|
||||
}
|
||||
|
||||
buildInputsPage() {
|
||||
const inputs = this.groupData.nodeInputs[this.selectedNodeInnerIndex];
|
||||
const items = Object.keys(inputs ?? {});
|
||||
const type = app.graph.extra.groupNodes[this.selectedGroup];
|
||||
const config = type.config?.[this.selectedNodeInnerIndex]?.input;
|
||||
this.inputsPage.replaceChildren(
|
||||
...items
|
||||
.map((oldName) => {
|
||||
let value = inputs[oldName];
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
return this.getEditElement("input", oldName, value, oldName, config?.[oldName]?.visible !== false);
|
||||
})
|
||||
.filter(Boolean)
|
||||
);
|
||||
return !!items.length;
|
||||
}
|
||||
|
||||
buildOutputsPage() {
|
||||
const nodes = this.groupData.nodeData.nodes;
|
||||
const innerNodeDef = this.groupData.getNodeDef(nodes[this.selectedNodeInnerIndex]);
|
||||
const outputs = innerNodeDef?.output ?? [];
|
||||
const groupOutputs = this.groupData.oldToNewOutputMap[this.selectedNodeInnerIndex];
|
||||
|
||||
const type = app.graph.extra.groupNodes[this.selectedGroup];
|
||||
const config = type.config?.[this.selectedNodeInnerIndex]?.output;
|
||||
const node = this.groupData.nodeData.nodes[this.selectedNodeInnerIndex];
|
||||
const checkable = node.type !== "PrimitiveNode";
|
||||
this.outputsPage.replaceChildren(
|
||||
...outputs
|
||||
.map((type, slot) => {
|
||||
const groupOutputIndex = groupOutputs?.[slot];
|
||||
const oldName = innerNodeDef.output_name?.[slot] ?? type;
|
||||
let value = config?.[slot]?.name;
|
||||
const visible = config?.[slot]?.visible || groupOutputIndex != null;
|
||||
if (!value || value === oldName) {
|
||||
value = "";
|
||||
}
|
||||
return this.getEditElement("output", slot, value, oldName, visible, checkable);
|
||||
})
|
||||
.filter(Boolean)
|
||||
);
|
||||
return !!outputs.length;
|
||||
}
|
||||
|
||||
show(type) {
|
||||
const groupNodes = Object.keys(app.graph.extra?.groupNodes ?? {}).sort((a, b) => a.localeCompare(b));
|
||||
|
||||
this.innerNodesList = $el("ul.comfy-group-manage-list-items");
|
||||
this.widgetsPage = $el("section.comfy-group-manage-node-page");
|
||||
this.inputsPage = $el("section.comfy-group-manage-node-page");
|
||||
this.outputsPage = $el("section.comfy-group-manage-node-page");
|
||||
const pages = $el("div", [this.widgetsPage, this.inputsPage, this.outputsPage]);
|
||||
|
||||
this.tabs = [
|
||||
["Inputs", this.inputsPage],
|
||||
["Widgets", this.widgetsPage],
|
||||
["Outputs", this.outputsPage],
|
||||
].reduce((p, [name, page]) => {
|
||||
p[name] = {
|
||||
tab: $el("a", {
|
||||
onclick: () => {
|
||||
this.changeTab(name);
|
||||
},
|
||||
textContent: name,
|
||||
}),
|
||||
page,
|
||||
};
|
||||
return p;
|
||||
}, {});
|
||||
|
||||
const outer = $el("div.comfy-group-manage-outer", [
|
||||
$el("header", [
|
||||
$el("h2", "Group Nodes"),
|
||||
$el(
|
||||
"select",
|
||||
{
|
||||
onchange: (e) => {
|
||||
this.changeGroup(e.target.value);
|
||||
},
|
||||
},
|
||||
groupNodes.map((g) =>
|
||||
$el("option", {
|
||||
textContent: g,
|
||||
selected: "workflow/" + g === type,
|
||||
value: g,
|
||||
})
|
||||
)
|
||||
),
|
||||
]),
|
||||
$el("main", [
|
||||
$el("section.comfy-group-manage-list", this.innerNodesList),
|
||||
$el("section.comfy-group-manage-node", [
|
||||
$el(
|
||||
"header",
|
||||
Object.values(this.tabs).map((t) => t.tab)
|
||||
),
|
||||
pages,
|
||||
]),
|
||||
]),
|
||||
$el("footer", [
|
||||
$el(
|
||||
"button.comfy-btn",
|
||||
{
|
||||
onclick: (e) => {
|
||||
const node = app.graph._nodes.find((n) => n.type === "workflow/" + this.selectedGroup);
|
||||
if (node) {
|
||||
alert("This group node is in use in the current workflow, please first remove these.");
|
||||
return;
|
||||
}
|
||||
if (confirm(`Are you sure you want to remove the node: "${this.selectedGroup}"`)) {
|
||||
delete app.graph.extra.groupNodes[this.selectedGroup];
|
||||
LiteGraph.unregisterNodeType("workflow/" + this.selectedGroup);
|
||||
}
|
||||
this.show();
|
||||
},
|
||||
},
|
||||
"Delete Group Node"
|
||||
),
|
||||
$el(
|
||||
"button.comfy-btn",
|
||||
{
|
||||
onclick: async () => {
|
||||
let nodesByType;
|
||||
let recreateNodes = [];
|
||||
const types = {};
|
||||
for (const g in this.modifications) {
|
||||
const type = app.graph.extra.groupNodes[g];
|
||||
let config = (type.config ??= {});
|
||||
|
||||
let nodeMods = this.modifications[g]?.nodes;
|
||||
if (nodeMods) {
|
||||
const keys = Object.keys(nodeMods);
|
||||
if (nodeMods[keys[0]][ORDER]) {
|
||||
// If any node is reordered, they will all need sequencing
|
||||
const orderedNodes = [];
|
||||
const orderedMods = {};
|
||||
const orderedConfig = {};
|
||||
|
||||
for (const n of keys) {
|
||||
const order = nodeMods[n][ORDER].order;
|
||||
orderedNodes[order] = type.nodes[+n];
|
||||
orderedMods[order] = nodeMods[n];
|
||||
orderedNodes[order].index = order;
|
||||
}
|
||||
|
||||
// Rewrite links
|
||||
for (const l of type.links) {
|
||||
if (l[0] != null) l[0] = type.nodes[l[0]].index;
|
||||
if (l[2] != null) l[2] = type.nodes[l[2]].index;
|
||||
}
|
||||
|
||||
// Rewrite externals
|
||||
if (type.external) {
|
||||
for (const ext of type.external) {
|
||||
ext[0] = type.nodes[ext[0]];
|
||||
}
|
||||
}
|
||||
|
||||
// Rewrite modifications
|
||||
for (const id of keys) {
|
||||
if (config[id]) {
|
||||
orderedConfig[type.nodes[id].index] = config[id];
|
||||
}
|
||||
delete config[id];
|
||||
}
|
||||
|
||||
type.nodes = orderedNodes;
|
||||
nodeMods = orderedMods;
|
||||
type.config = config = orderedConfig;
|
||||
}
|
||||
|
||||
merge(config, nodeMods);
|
||||
}
|
||||
|
||||
types[g] = type;
|
||||
|
||||
if (!nodesByType) {
|
||||
nodesByType = app.graph._nodes.reduce((p, n) => {
|
||||
p[n.type] ??= [];
|
||||
p[n.type].push(n);
|
||||
return p;
|
||||
}, {});
|
||||
}
|
||||
|
||||
const nodes = nodesByType["workflow/" + g];
|
||||
if (nodes) recreateNodes.push(...nodes);
|
||||
}
|
||||
|
||||
await GroupNodeConfig.registerFromWorkflow(types, {});
|
||||
|
||||
for (const node of recreateNodes) {
|
||||
node.recreate();
|
||||
}
|
||||
|
||||
this.modifications = {};
|
||||
this.app.graph.setDirtyCanvas(true, true);
|
||||
this.changeGroup(this.selectedGroup, false);
|
||||
},
|
||||
},
|
||||
"Save"
|
||||
),
|
||||
$el("button.comfy-btn", { onclick: () => this.element.close() }, "Close"),
|
||||
]),
|
||||
]);
|
||||
|
||||
this.element.replaceChildren(outer);
|
||||
this.changeGroup(type ? groupNodes.find((g) => "workflow/" + g === type) : groupNodes[0]);
|
||||
this.element.showModal();
|
||||
|
||||
this.element.addEventListener("close", () => {
|
||||
this.draggable?.dispose();
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,5 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
import { ComfyDialog, $el } from "../../scripts/ui.js";
|
||||
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
|
||||
|
||||
@ -20,16 +21,20 @@ import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
|
||||
// Open the manage dialog and Drag and drop elements using the "Name:" label as handle
|
||||
|
||||
const id = "Comfy.NodeTemplates";
|
||||
const file = "comfy.templates.json";
|
||||
|
||||
class ManageTemplates extends ComfyDialog {
|
||||
constructor() {
|
||||
super();
|
||||
this.load().then((v) => {
|
||||
this.templates = v;
|
||||
});
|
||||
|
||||
this.element.classList.add("comfy-manage-templates");
|
||||
this.templates = this.load();
|
||||
this.draggedEl = null;
|
||||
this.saveVisualCue = null;
|
||||
this.emptyImg = new Image();
|
||||
this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=';
|
||||
this.emptyImg.src = "data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=";
|
||||
|
||||
this.importInput = $el("input", {
|
||||
type: "file",
|
||||
@ -67,17 +72,50 @@ class ManageTemplates extends ComfyDialog {
|
||||
return btns;
|
||||
}
|
||||
|
||||
load() {
|
||||
const templates = localStorage.getItem(id);
|
||||
if (templates) {
|
||||
return JSON.parse(templates);
|
||||
async load() {
|
||||
let templates = [];
|
||||
if (app.storageLocation === "server") {
|
||||
if (app.isNewUserSession) {
|
||||
// New user so migrate existing templates
|
||||
const json = localStorage.getItem(id);
|
||||
if (json) {
|
||||
templates = JSON.parse(json);
|
||||
}
|
||||
await api.storeUserData(file, json, { stringify: false });
|
||||
} else {
|
||||
const res = await api.getUserData(file);
|
||||
if (res.status === 200) {
|
||||
try {
|
||||
templates = await res.json();
|
||||
} catch (error) {
|
||||
}
|
||||
} else if (res.status !== 404) {
|
||||
console.error(res.status + " " + res.statusText);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return [];
|
||||
const json = localStorage.getItem(id);
|
||||
if (json) {
|
||||
templates = JSON.parse(json);
|
||||
}
|
||||
}
|
||||
|
||||
return templates ?? [];
|
||||
}
|
||||
|
||||
store() {
|
||||
localStorage.setItem(id, JSON.stringify(this.templates));
|
||||
async store() {
|
||||
if(app.storageLocation === "server") {
|
||||
const templates = JSON.stringify(this.templates, undefined, 4);
|
||||
localStorage.setItem(id, templates); // Backwards compatibility
|
||||
try {
|
||||
await api.storeUserData(file, templates, { stringify: false });
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
alert(error.message);
|
||||
}
|
||||
} else {
|
||||
localStorage.setItem(id, JSON.stringify(this.templates));
|
||||
}
|
||||
}
|
||||
|
||||
async importAll() {
|
||||
@ -85,14 +123,14 @@ class ManageTemplates extends ComfyDialog {
|
||||
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||
const reader = new FileReader();
|
||||
reader.onload = async () => {
|
||||
var importFile = JSON.parse(reader.result);
|
||||
if (importFile && importFile?.templates) {
|
||||
const importFile = JSON.parse(reader.result);
|
||||
if (importFile?.templates) {
|
||||
for (const template of importFile.templates) {
|
||||
if (template?.name && template?.data) {
|
||||
this.templates.push(template);
|
||||
}
|
||||
}
|
||||
this.store();
|
||||
await this.store();
|
||||
}
|
||||
};
|
||||
await reader.readAsText(file);
|
||||
@ -159,7 +197,7 @@ class ManageTemplates extends ComfyDialog {
|
||||
e.currentTarget.style.border = "1px dashed transparent";
|
||||
e.currentTarget.removeAttribute("draggable");
|
||||
|
||||
// rearrange the elements in the localStorage
|
||||
// rearrange the elements
|
||||
this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
|
||||
var prev_i = el.dataset.id;
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js"
|
||||
|
||||
const MAX_HISTORY = 50;
|
||||
|
||||
@ -15,6 +16,7 @@ function checkState() {
|
||||
}
|
||||
activeState = clone(currentState);
|
||||
redo.length = 0;
|
||||
api.dispatchEvent(new CustomEvent("graphChanged", { detail: activeState }));
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,7 +94,7 @@ const undoRedo = async (e) => {
|
||||
};
|
||||
|
||||
const bindInput = (activeEl) => {
|
||||
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
|
||||
if (activeEl && activeEl.tagName !== "CANVAS" && activeEl.tagName !== "BODY") {
|
||||
for (const evt of ["change", "input", "blur"]) {
|
||||
if (`on${evt}` in activeEl) {
|
||||
const listener = () => {
|
||||
@ -106,15 +108,23 @@ const bindInput = (activeEl) => {
|
||||
}
|
||||
};
|
||||
|
||||
let keyIgnored = false;
|
||||
window.addEventListener(
|
||||
"keydown",
|
||||
(e) => {
|
||||
requestAnimationFrame(async () => {
|
||||
const activeEl = document.activeElement;
|
||||
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
|
||||
// Ignore events on inputs, they have their native history
|
||||
return;
|
||||
let activeEl;
|
||||
// If we are auto queue in change mode then we do want to trigger on inputs
|
||||
if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") {
|
||||
activeEl = document.activeElement;
|
||||
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
|
||||
// Ignore events on inputs, they have their native history
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta";
|
||||
if (keyIgnored) return;
|
||||
|
||||
// Check if this is a ctrl+z ctrl+y
|
||||
if (await undoRedo(e)) return;
|
||||
@ -127,11 +137,23 @@ window.addEventListener(
|
||||
true
|
||||
);
|
||||
|
||||
window.addEventListener("keyup", (e) => {
|
||||
if (keyIgnored) {
|
||||
keyIgnored = false;
|
||||
checkState();
|
||||
}
|
||||
});
|
||||
|
||||
// Handle clicking DOM elements (e.g. widgets)
|
||||
window.addEventListener("mouseup", () => {
|
||||
checkState();
|
||||
});
|
||||
|
||||
// Handle prompt queue event for dynamic widget changes
|
||||
api.addEventListener("promptQueued", () => {
|
||||
checkState();
|
||||
});
|
||||
|
||||
// Handle litegraph clicks
|
||||
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
|
||||
LGraphCanvas.prototype.processMouseUp = function (e) {
|
||||
@ -145,3 +167,11 @@ LGraphCanvas.prototype.processMouseDown = function (e) {
|
||||
checkState();
|
||||
return v;
|
||||
};
|
||||
|
||||
// Handle litegraph context menu for COMBO widgets
|
||||
const close = LiteGraph.ContextMenu.prototype.close;
|
||||
LiteGraph.ContextMenu.prototype.close = function(e) {
|
||||
const v = close.apply(this, arguments);
|
||||
checkState();
|
||||
return v;
|
||||
}
|
||||
@ -16,5 +16,33 @@
|
||||
window.graph = app.graph;
|
||||
</script>
|
||||
</head>
|
||||
<body class="litegraph"></body>
|
||||
<body class="litegraph">
|
||||
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
|
||||
<main class="comfy-user-selection-inner">
|
||||
<h1>ComfyUI</h1>
|
||||
<form>
|
||||
<section>
|
||||
<label>New user:
|
||||
<input placeholder="Enter a username" />
|
||||
</label>
|
||||
</section>
|
||||
<div class="comfy-user-existing">
|
||||
<span class="or-separator">OR</span>
|
||||
<section>
|
||||
<label>
|
||||
Existing user:
|
||||
<select>
|
||||
<option hidden disabled selected value> Select a user </option>
|
||||
</select>
|
||||
</label>
|
||||
</section>
|
||||
</div>
|
||||
<footer>
|
||||
<span class="comfy-user-error"> </span>
|
||||
<button class="comfy-btn comfy-user-button-next">Next</button>
|
||||
</footer>
|
||||
</form>
|
||||
</main>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@ -3,7 +3,8 @@
|
||||
"baseUrl": ".",
|
||||
"paths": {
|
||||
"/*": ["./*"]
|
||||
}
|
||||
},
|
||||
"lib": ["DOM", "ES2022"]
|
||||
},
|
||||
"include": ["."]
|
||||
}
|
||||
|
||||
@ -11496,7 +11496,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
}
|
||||
timeout_close = setTimeout(function() {
|
||||
dialog.close();
|
||||
}, 500);
|
||||
}, typeof options.hide_on_mouse_leave === "number" ? options.hide_on_mouse_leave : 500);
|
||||
});
|
||||
// if filtering, check focus changed to comboboxes and prevent closing
|
||||
if (options.do_type_filter){
|
||||
|
||||
@ -12,6 +12,13 @@ class ComfyApi extends EventTarget {
|
||||
}
|
||||
|
||||
fetchApi(route, options) {
|
||||
if (!options) {
|
||||
options = {};
|
||||
}
|
||||
if (!options.headers) {
|
||||
options.headers = {};
|
||||
}
|
||||
options.headers["Comfy-User"] = this.user;
|
||||
return fetch(this.apiURL(route), options);
|
||||
}
|
||||
|
||||
@ -315,6 +322,99 @@ class ComfyApi extends EventTarget {
|
||||
async interrupt() {
|
||||
await this.#postItem("interrupt", null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets user configuration data and where data should be stored
|
||||
* @returns { Promise<{ storage: "server" | "browser", users?: Promise<string, unknown>, migrated?: boolean }> }
|
||||
*/
|
||||
async getUserConfig() {
|
||||
return (await this.fetchApi("/users")).json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new user
|
||||
* @param { string } username
|
||||
* @returns The fetch response
|
||||
*/
|
||||
createUser(username) {
|
||||
return this.fetchApi("/users", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ username }),
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all setting values for the current user
|
||||
* @returns { Promise<string, unknown> } A dictionary of id -> value
|
||||
*/
|
||||
async getSettings() {
|
||||
return (await this.fetchApi("/settings")).json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a setting for the current user
|
||||
* @param { string } id The id of the setting to fetch
|
||||
* @returns { Promise<unknown> } The setting value
|
||||
*/
|
||||
async getSetting(id) {
|
||||
return (await this.fetchApi(`/settings/${encodeURIComponent(id)}`)).json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores a dictionary of settings for the current user
|
||||
* @param { Record<string, unknown> } settings Dictionary of setting id -> value to save
|
||||
* @returns { Promise<void> }
|
||||
*/
|
||||
async storeSettings(settings) {
|
||||
return this.fetchApi(`/settings`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify(settings)
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores a setting for the current user
|
||||
* @param { string } id The id of the setting to update
|
||||
* @param { unknown } value The value of the setting
|
||||
* @returns { Promise<void> }
|
||||
*/
|
||||
async storeSetting(id, value) {
|
||||
return this.fetchApi(`/settings/${encodeURIComponent(id)}`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify(value)
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a user data file for the current user
|
||||
* @param { string } file The name of the userdata file to load
|
||||
* @param { RequestInit } [options]
|
||||
* @returns { Promise<unknown> } The fetch response object
|
||||
*/
|
||||
async getUserData(file, options) {
|
||||
return this.fetchApi(`/userdata/${encodeURIComponent(file)}`, options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores a user data file for the current user
|
||||
* @param { string } file The name of the userdata file to save
|
||||
* @param { unknown } data The data to save to the file
|
||||
* @param { RequestInit & { stringify?: boolean, throwOnError?: boolean } } [options]
|
||||
* @returns { Promise<void> }
|
||||
*/
|
||||
async storeUserData(file, data, options = { stringify: true, throwOnError: true }) {
|
||||
const resp = await this.fetchApi(`/userdata/${encodeURIComponent(file)}`, {
|
||||
method: "POST",
|
||||
body: options?.stringify ? JSON.stringify(data) : data,
|
||||
...options,
|
||||
});
|
||||
if (resp.status !== 200) {
|
||||
throw new Error(`Error storing user data file '${file}': ${resp.status} ${(await resp).statusText}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const api = new ComfyApi();
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { ComfyLogging } from "./logging.js";
|
||||
import { ComfyWidgets } from "./widgets.js";
|
||||
import { ComfyWidgets, initWidgets } from "./widgets.js";
|
||||
import { ComfyUI, $el } from "./ui.js";
|
||||
import { api } from "./api.js";
|
||||
import { defaultGraph } from "./defaultGraph.js";
|
||||
@ -269,6 +269,71 @@ export class ComfyApp {
|
||||
* @param {*} node The node to add the menu handler
|
||||
*/
|
||||
#addNodeContextMenuHandler(node) {
|
||||
function getCopyImageOption(img) {
|
||||
if (typeof window.ClipboardItem === "undefined") return [];
|
||||
return [
|
||||
{
|
||||
content: "Copy Image",
|
||||
callback: async () => {
|
||||
const url = new URL(img.src);
|
||||
url.searchParams.delete("preview");
|
||||
|
||||
const writeImage = async (blob) => {
|
||||
await navigator.clipboard.write([
|
||||
new ClipboardItem({
|
||||
[blob.type]: blob,
|
||||
}),
|
||||
]);
|
||||
};
|
||||
|
||||
try {
|
||||
const data = await fetch(url);
|
||||
const blob = await data.blob();
|
||||
try {
|
||||
await writeImage(blob);
|
||||
} catch (error) {
|
||||
// Chrome seems to only support PNG on write, convert and try again
|
||||
if (blob.type !== "image/png") {
|
||||
const canvas = $el("canvas", {
|
||||
width: img.naturalWidth,
|
||||
height: img.naturalHeight,
|
||||
});
|
||||
const ctx = canvas.getContext("2d");
|
||||
let image;
|
||||
if (typeof window.createImageBitmap === "undefined") {
|
||||
image = new Image();
|
||||
const p = new Promise((resolve, reject) => {
|
||||
image.onload = resolve;
|
||||
image.onerror = reject;
|
||||
}).finally(() => {
|
||||
URL.revokeObjectURL(image.src);
|
||||
});
|
||||
image.src = URL.createObjectURL(blob);
|
||||
await p;
|
||||
} else {
|
||||
image = await createImageBitmap(blob);
|
||||
}
|
||||
try {
|
||||
ctx.drawImage(image, 0, 0);
|
||||
canvas.toBlob(writeImage, "image/png");
|
||||
} finally {
|
||||
if (typeof image.close === "function") {
|
||||
image.close();
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
} catch (error) {
|
||||
alert("Error copying image: " + (error.message ?? error));
|
||||
}
|
||||
},
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
node.prototype.getExtraMenuOptions = function (_, options) {
|
||||
if (this.imgs) {
|
||||
// If this node has images then we add an open in new tab item
|
||||
@ -286,16 +351,17 @@ export class ComfyApp {
|
||||
content: "Open Image",
|
||||
callback: () => {
|
||||
let url = new URL(img.src);
|
||||
url.searchParams.delete('preview');
|
||||
window.open(url, "_blank")
|
||||
url.searchParams.delete("preview");
|
||||
window.open(url, "_blank");
|
||||
},
|
||||
},
|
||||
...getCopyImageOption(img),
|
||||
{
|
||||
content: "Save Image",
|
||||
callback: () => {
|
||||
const a = document.createElement("a");
|
||||
let url = new URL(img.src);
|
||||
url.searchParams.delete('preview');
|
||||
url.searchParams.delete("preview");
|
||||
a.href = url;
|
||||
a.setAttribute("download", new URLSearchParams(url.search).get("filename"));
|
||||
document.body.append(a);
|
||||
@ -308,33 +374,41 @@ export class ComfyApp {
|
||||
}
|
||||
|
||||
options.push({
|
||||
content: "Bypass",
|
||||
callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); }
|
||||
});
|
||||
content: "Bypass",
|
||||
callback: (obj) => {
|
||||
if (this.mode === 4) this.mode = 0;
|
||||
else this.mode = 4;
|
||||
this.graph.change();
|
||||
},
|
||||
});
|
||||
|
||||
// prevent conflict of clipspace content
|
||||
if(!ComfyApp.clipspace_return_node) {
|
||||
if (!ComfyApp.clipspace_return_node) {
|
||||
options.push({
|
||||
content: "Copy (Clipspace)",
|
||||
callback: (obj) => { ComfyApp.copyToClipspace(this); }
|
||||
});
|
||||
content: "Copy (Clipspace)",
|
||||
callback: (obj) => {
|
||||
ComfyApp.copyToClipspace(this);
|
||||
},
|
||||
});
|
||||
|
||||
if(ComfyApp.clipspace != null) {
|
||||
if (ComfyApp.clipspace != null) {
|
||||
options.push({
|
||||
content: "Paste (Clipspace)",
|
||||
callback: () => { ComfyApp.pasteFromClipspace(this); }
|
||||
});
|
||||
content: "Paste (Clipspace)",
|
||||
callback: () => {
|
||||
ComfyApp.pasteFromClipspace(this);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if(ComfyApp.isImageNode(this)) {
|
||||
if (ComfyApp.isImageNode(this)) {
|
||||
options.push({
|
||||
content: "Open in MaskEditor",
|
||||
callback: (obj) => {
|
||||
ComfyApp.copyToClipspace(this);
|
||||
ComfyApp.clipspace_return_node = this;
|
||||
ComfyApp.open_maskeditor();
|
||||
}
|
||||
});
|
||||
content: "Open in MaskEditor",
|
||||
callback: (obj) => {
|
||||
ComfyApp.copyToClipspace(this);
|
||||
ComfyApp.clipspace_return_node = this;
|
||||
ComfyApp.open_maskeditor();
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -1291,10 +1365,92 @@ export class ComfyApp {
|
||||
await Promise.all(extensionPromises);
|
||||
}
|
||||
|
||||
async #migrateSettings() {
|
||||
this.isNewUserSession = true;
|
||||
// Store all current settings
|
||||
const settings = Object.keys(this.ui.settings).reduce((p, n) => {
|
||||
const v = localStorage[`Comfy.Settings.${n}`];
|
||||
if (v) {
|
||||
try {
|
||||
p[n] = JSON.parse(v);
|
||||
} catch (error) {}
|
||||
}
|
||||
return p;
|
||||
}, {});
|
||||
|
||||
await api.storeSettings(settings);
|
||||
}
|
||||
|
||||
async #setUser() {
|
||||
const userConfig = await api.getUserConfig();
|
||||
this.storageLocation = userConfig.storage;
|
||||
if (typeof userConfig.migrated == "boolean") {
|
||||
// Single user mode migrated true/false for if the default user is created
|
||||
if (!userConfig.migrated && this.storageLocation === "server") {
|
||||
// Default user not created yet
|
||||
await this.#migrateSettings();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
this.multiUserServer = true;
|
||||
let user = localStorage["Comfy.userId"];
|
||||
const users = userConfig.users ?? {};
|
||||
if (!user || !users[user]) {
|
||||
// This will rarely be hit so move the loading to on demand
|
||||
const { UserSelectionScreen } = await import("./ui/userSelection.js");
|
||||
|
||||
this.ui.menuContainer.style.display = "none";
|
||||
const { userId, username, created } = await new UserSelectionScreen().show(users, user);
|
||||
this.ui.menuContainer.style.display = "";
|
||||
|
||||
user = userId;
|
||||
localStorage["Comfy.userName"] = username;
|
||||
localStorage["Comfy.userId"] = user;
|
||||
|
||||
if (created) {
|
||||
api.user = user;
|
||||
await this.#migrateSettings();
|
||||
}
|
||||
}
|
||||
|
||||
api.user = user;
|
||||
|
||||
this.ui.settings.addSetting({
|
||||
id: "Comfy.SwitchUser",
|
||||
name: "Switch User",
|
||||
type: (name) => {
|
||||
let currentUser = localStorage["Comfy.userName"];
|
||||
if (currentUser) {
|
||||
currentUser = ` (${currentUser})`;
|
||||
}
|
||||
return $el("tr", [
|
||||
$el("td", [
|
||||
$el("label", {
|
||||
textContent: name,
|
||||
}),
|
||||
]),
|
||||
$el("td", [
|
||||
$el("button", {
|
||||
textContent: name + (currentUser ?? ""),
|
||||
onclick: () => {
|
||||
delete localStorage["Comfy.userId"];
|
||||
delete localStorage["Comfy.userName"];
|
||||
window.location.reload();
|
||||
},
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Set up the app on the page
|
||||
*/
|
||||
async setup() {
|
||||
await this.#setUser();
|
||||
await this.ui.settings.load();
|
||||
await this.#loadExtensions();
|
||||
|
||||
// Create and mount the LiteGraph in the DOM
|
||||
@ -1338,6 +1494,7 @@ export class ComfyApp {
|
||||
|
||||
await this.#invokeExtensionsAsync("init");
|
||||
await this.registerNodes();
|
||||
initWidgets(this);
|
||||
|
||||
// Load previous workflow
|
||||
let restored = false;
|
||||
@ -1692,6 +1849,14 @@ export class ComfyApp {
|
||||
*/
|
||||
async graphToPrompt() {
|
||||
for (const outerNode of this.graph.computeExecutionOrder(false)) {
|
||||
if (outerNode.widgets) {
|
||||
for (const widget of outerNode.widgets) {
|
||||
// Allow widgets to run callbacks before a prompt has been queued
|
||||
// e.g. random seed before every gen
|
||||
widget.beforeQueued?.();
|
||||
}
|
||||
}
|
||||
|
||||
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
|
||||
for (const node of innerNodes) {
|
||||
if (node.isVirtualNode) {
|
||||
@ -1903,6 +2068,7 @@ export class ComfyApp {
|
||||
} finally {
|
||||
this.#processingQueue = false;
|
||||
}
|
||||
api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } }));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -2046,6 +2212,8 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await this.#invokeExtensionsAsync("refreshComboInNodes", defs);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -239,7 +239,8 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
|
||||
node.flags?.collapsed ||
|
||||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
|
||||
widget.computedHeight <= 0 ||
|
||||
widget.type === "converted-widget";
|
||||
widget.type === "converted-widget"||
|
||||
widget.type === "hidden";
|
||||
element.hidden = hidden;
|
||||
element.style.display = hidden ? "none" : null;
|
||||
if (hidden) {
|
||||
|
||||
@ -269,6 +269,9 @@ export class ComfyLogging {
|
||||
id: settingId,
|
||||
name: settingId,
|
||||
defaultValue: true,
|
||||
onChange: (value) => {
|
||||
this.enabled = value;
|
||||
},
|
||||
type: (name, setter, value) => {
|
||||
return $el("tr", [
|
||||
$el("td", [
|
||||
@ -283,7 +286,7 @@ export class ComfyLogging {
|
||||
type: "checkbox",
|
||||
checked: value,
|
||||
onchange: (event) => {
|
||||
setter((this.enabled = event.target.checked));
|
||||
setter(event.target.checked);
|
||||
},
|
||||
}),
|
||||
$el("button", {
|
||||
|
||||
@ -1,5 +1,23 @@
|
||||
import {api} from "./api.js";
|
||||
import { api } from "./api.js";
|
||||
import { ComfyDialog as _ComfyDialog } from "./ui/dialog.js";
|
||||
import { toggleSwitch } from "./ui/toggleSwitch.js";
|
||||
import { ComfySettingsDialog } from "./ui/settings.js";
|
||||
|
||||
export const ComfyDialog = _ComfyDialog;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param { string } tag HTML Element Tag and optional classes e.g. div.class1.class2
|
||||
* @param { string | Element | Element[] | {
|
||||
* parent?: Element,
|
||||
* $?: (el: Element) => void,
|
||||
* dataset?: DOMStringMap,
|
||||
* style?: CSSStyleDeclaration,
|
||||
* for?: string
|
||||
* } | undefined } propsOrChildren
|
||||
* @param { Element[] | undefined } [children]
|
||||
* @returns
|
||||
*/
|
||||
export function $el(tag, propsOrChildren, children) {
|
||||
const split = tag.split(".");
|
||||
const element = document.createElement(split.shift());
|
||||
@ -8,6 +26,11 @@ export function $el(tag, propsOrChildren, children) {
|
||||
}
|
||||
|
||||
if (propsOrChildren) {
|
||||
if (typeof propsOrChildren === "string") {
|
||||
propsOrChildren = { textContent: propsOrChildren };
|
||||
} else if (propsOrChildren instanceof Element) {
|
||||
propsOrChildren = [propsOrChildren];
|
||||
}
|
||||
if (Array.isArray(propsOrChildren)) {
|
||||
element.append(...propsOrChildren);
|
||||
} else {
|
||||
@ -31,7 +54,7 @@ export function $el(tag, propsOrChildren, children) {
|
||||
|
||||
Object.assign(element, propsOrChildren);
|
||||
if (children) {
|
||||
element.append(...children);
|
||||
element.append(...(children instanceof Array ? children : [children]));
|
||||
}
|
||||
|
||||
if (parent) {
|
||||
@ -167,267 +190,6 @@ function dragElement(dragEl, settings) {
|
||||
}
|
||||
}
|
||||
|
||||
export class ComfyDialog {
|
||||
constructor() {
|
||||
this.element = $el("div.comfy-modal", {parent: document.body}, [
|
||||
$el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]),
|
||||
]);
|
||||
}
|
||||
|
||||
createButtons() {
|
||||
return [
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: "Close",
|
||||
onclick: () => this.close(),
|
||||
}),
|
||||
];
|
||||
}
|
||||
|
||||
close() {
|
||||
this.element.style.display = "none";
|
||||
}
|
||||
|
||||
show(html) {
|
||||
if (typeof html === "string") {
|
||||
this.textElement.innerHTML = html;
|
||||
} else {
|
||||
this.textElement.replaceChildren(html);
|
||||
}
|
||||
this.element.style.display = "flex";
|
||||
}
|
||||
}
|
||||
|
||||
class ComfySettingsDialog extends ComfyDialog {
|
||||
constructor() {
|
||||
super();
|
||||
this.element = $el("dialog", {
|
||||
id: "comfy-settings-dialog",
|
||||
parent: document.body,
|
||||
}, [
|
||||
$el("table.comfy-modal-content.comfy-table", [
|
||||
$el("caption", {textContent: "Settings"}),
|
||||
$el("tbody", {$: (tbody) => (this.textElement = tbody)}),
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: "Close",
|
||||
style: {
|
||||
cursor: "pointer",
|
||||
},
|
||||
onclick: () => {
|
||||
this.element.close();
|
||||
},
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
this.settings = [];
|
||||
}
|
||||
|
||||
getSettingValue(id, defaultValue) {
|
||||
const settingId = "Comfy.Settings." + id;
|
||||
const v = localStorage[settingId];
|
||||
return v == null ? defaultValue : JSON.parse(v);
|
||||
}
|
||||
|
||||
setSettingValue(id, value) {
|
||||
const settingId = "Comfy.Settings." + id;
|
||||
localStorage[settingId] = JSON.stringify(value);
|
||||
}
|
||||
|
||||
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
|
||||
if (!id) {
|
||||
throw new Error("Settings must have an ID");
|
||||
}
|
||||
|
||||
if (this.settings.find((s) => s.id === id)) {
|
||||
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
|
||||
}
|
||||
|
||||
const settingId = `Comfy.Settings.${id}`;
|
||||
const v = localStorage[settingId];
|
||||
let value = v == null ? defaultValue : JSON.parse(v);
|
||||
|
||||
// Trigger initial setting of value
|
||||
if (onChange) {
|
||||
onChange(value, undefined);
|
||||
}
|
||||
|
||||
this.settings.push({
|
||||
render: () => {
|
||||
const setter = (v) => {
|
||||
if (onChange) {
|
||||
onChange(v, value);
|
||||
}
|
||||
localStorage[settingId] = JSON.stringify(v);
|
||||
value = v;
|
||||
};
|
||||
value = this.getSettingValue(id, defaultValue);
|
||||
|
||||
let element;
|
||||
const htmlID = id.replaceAll(".", "-");
|
||||
|
||||
const labelCell = $el("td", [
|
||||
$el("label", {
|
||||
for: htmlID,
|
||||
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
|
||||
textContent: name,
|
||||
})
|
||||
]);
|
||||
|
||||
if (typeof type === "function") {
|
||||
element = type(name, setter, value, attrs);
|
||||
} else {
|
||||
switch (type) {
|
||||
case "boolean":
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el("input", {
|
||||
id: htmlID,
|
||||
type: "checkbox",
|
||||
checked: value,
|
||||
onchange: (event) => {
|
||||
const isChecked = event.target.checked;
|
||||
if (onChange !== undefined) {
|
||||
onChange(isChecked)
|
||||
}
|
||||
this.setSettingValue(id, isChecked);
|
||||
},
|
||||
}),
|
||||
]),
|
||||
])
|
||||
break;
|
||||
case "number":
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el("input", {
|
||||
type,
|
||||
value,
|
||||
id: htmlID,
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
},
|
||||
...attrs
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
case "slider":
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el("div", {
|
||||
style: {
|
||||
display: "grid",
|
||||
gridAutoFlow: "column",
|
||||
},
|
||||
}, [
|
||||
$el("input", {
|
||||
...attrs,
|
||||
value,
|
||||
type: "range",
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
e.target.nextElementSibling.value = e.target.value;
|
||||
},
|
||||
}),
|
||||
$el("input", {
|
||||
...attrs,
|
||||
value,
|
||||
id: htmlID,
|
||||
type: "number",
|
||||
style: {maxWidth: "4rem"},
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
e.target.previousElementSibling.value = e.target.value;
|
||||
},
|
||||
}),
|
||||
]),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
case "combo":
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el(
|
||||
"select",
|
||||
{
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
},
|
||||
},
|
||||
(typeof options === "function" ? options(value) : options || []).map((opt) => {
|
||||
if (typeof opt === "string") {
|
||||
opt = { text: opt };
|
||||
}
|
||||
const v = opt.value ?? opt.text;
|
||||
return $el("option", {
|
||||
value: v,
|
||||
textContent: opt.text,
|
||||
selected: value + "" === v + "",
|
||||
});
|
||||
})
|
||||
),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
case "text":
|
||||
default:
|
||||
if (type !== "text") {
|
||||
console.warn(`Unsupported setting type '${type}, defaulting to text`);
|
||||
}
|
||||
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el("input", {
|
||||
value,
|
||||
id: htmlID,
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
},
|
||||
...attrs,
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (tooltip) {
|
||||
element.title = tooltip;
|
||||
}
|
||||
|
||||
return element;
|
||||
},
|
||||
});
|
||||
|
||||
const self = this;
|
||||
return {
|
||||
get value() {
|
||||
return self.getSettingValue(id, defaultValue);
|
||||
},
|
||||
set value(v) {
|
||||
self.setSettingValue(id, v);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
show() {
|
||||
this.textElement.replaceChildren(
|
||||
$el("tr", {
|
||||
style: {display: "none"},
|
||||
}, [
|
||||
$el("th"),
|
||||
$el("th", {style: {width: "33%"}})
|
||||
]),
|
||||
...this.settings.map((s) => s.render()),
|
||||
)
|
||||
this.element.showModal();
|
||||
}
|
||||
}
|
||||
|
||||
class ComfyList {
|
||||
#type;
|
||||
#text;
|
||||
@ -526,7 +288,7 @@ export class ComfyUI {
|
||||
constructor(app) {
|
||||
this.app = app;
|
||||
this.dialog = new ComfyDialog();
|
||||
this.settings = new ComfySettingsDialog();
|
||||
this.settings = new ComfySettingsDialog(app);
|
||||
|
||||
this.batchCount = 1;
|
||||
this.lastQueueSize = 0;
|
||||
@ -607,6 +369,31 @@ export class ComfyUI {
|
||||
},
|
||||
});
|
||||
|
||||
const autoQueueModeEl = toggleSwitch(
|
||||
"autoQueueMode",
|
||||
[
|
||||
{ text: "instant", tooltip: "A new prompt will be queued as soon as the queue reaches 0" },
|
||||
{ text: "change", tooltip: "A new prompt will be queued when the queue is at 0 and the graph is/has changed" },
|
||||
],
|
||||
{
|
||||
onChange: (value) => {
|
||||
this.autoQueueMode = value.item.value;
|
||||
},
|
||||
}
|
||||
);
|
||||
autoQueueModeEl.style.display = "none";
|
||||
|
||||
api.addEventListener("graphChanged", () => {
|
||||
if (this.autoQueueMode === "change" && this.autoQueueEnabled === true) {
|
||||
if (this.lastQueueSize === 0) {
|
||||
this.graphHasChanged = false;
|
||||
app.queuePrompt(0, this.batchCount);
|
||||
} else {
|
||||
this.graphHasChanged = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [
|
||||
$el("div.drag-handle", {
|
||||
style: {
|
||||
@ -633,6 +420,7 @@ export class ComfyUI {
|
||||
document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none";
|
||||
this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1;
|
||||
document.getElementById("autoQueueCheckbox").checked = false;
|
||||
this.autoQueueEnabled = false;
|
||||
},
|
||||
}),
|
||||
]),
|
||||
@ -664,20 +452,22 @@ export class ComfyUI {
|
||||
},
|
||||
}),
|
||||
]),
|
||||
|
||||
$el("div",[
|
||||
$el("label",{
|
||||
for:"autoQueueCheckbox",
|
||||
innerHTML: "Auto Queue"
|
||||
// textContent: "Auto Queue"
|
||||
}),
|
||||
$el("input", {
|
||||
id: "autoQueueCheckbox",
|
||||
type: "checkbox",
|
||||
checked: false,
|
||||
title: "Automatically queue prompt when the queue size hits 0",
|
||||
|
||||
onchange: (e) => {
|
||||
this.autoQueueEnabled = e.target.checked;
|
||||
autoQueueModeEl.style.display = this.autoQueueEnabled ? "" : "none";
|
||||
}
|
||||
}),
|
||||
autoQueueModeEl
|
||||
])
|
||||
]),
|
||||
$el("div.comfy-menu-btns", [
|
||||
@ -811,10 +601,13 @@ export class ComfyUI {
|
||||
if (
|
||||
this.lastQueueSize != 0 &&
|
||||
status.exec_info.queue_remaining == 0 &&
|
||||
document.getElementById("autoQueueCheckbox").checked &&
|
||||
! app.lastExecutionError
|
||||
this.autoQueueEnabled &&
|
||||
(this.autoQueueMode === "instant" || this.graphHasChanged) &&
|
||||
!app.lastExecutionError
|
||||
) {
|
||||
app.queuePrompt(0, this.batchCount);
|
||||
status.exec_info.queue_remaining += this.batchCount;
|
||||
this.graphHasChanged = false;
|
||||
}
|
||||
this.lastQueueSize = status.exec_info.queue_remaining;
|
||||
}
|
||||
|
||||
32
comfy/web/scripts/ui/dialog.js
Normal file
32
comfy/web/scripts/ui/dialog.js
Normal file
@ -0,0 +1,32 @@
|
||||
import { $el } from "../ui.js";
|
||||
|
||||
export class ComfyDialog {
|
||||
constructor() {
|
||||
this.element = $el("div.comfy-modal", { parent: document.body }, [
|
||||
$el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]),
|
||||
]);
|
||||
}
|
||||
|
||||
createButtons() {
|
||||
return [
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: "Close",
|
||||
onclick: () => this.close(),
|
||||
}),
|
||||
];
|
||||
}
|
||||
|
||||
close() {
|
||||
this.element.style.display = "none";
|
||||
}
|
||||
|
||||
show(html) {
|
||||
if (typeof html === "string") {
|
||||
this.textElement.innerHTML = html;
|
||||
} else {
|
||||
this.textElement.replaceChildren(html);
|
||||
}
|
||||
this.element.style.display = "flex";
|
||||
}
|
||||
}
|
||||
287
comfy/web/scripts/ui/draggableList.js
Normal file
287
comfy/web/scripts/ui/draggableList.js
Normal file
@ -0,0 +1,287 @@
|
||||
// @ts-check
|
||||
/*
|
||||
Original implementation:
|
||||
https://github.com/TahaSh/drag-to-reorder
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Taha Shashtari
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
*/
|
||||
|
||||
import { $el } from "../ui.js";
|
||||
|
||||
$el("style", {
|
||||
parent: document.head,
|
||||
textContent: `
|
||||
.draggable-item {
|
||||
position: relative;
|
||||
will-change: transform;
|
||||
user-select: none;
|
||||
}
|
||||
.draggable-item.is-idle {
|
||||
transition: 0.25s ease transform;
|
||||
}
|
||||
.draggable-item.is-draggable {
|
||||
z-index: 10;
|
||||
}
|
||||
`
|
||||
});
|
||||
|
||||
export class DraggableList extends EventTarget {
|
||||
listContainer;
|
||||
draggableItem;
|
||||
pointerStartX;
|
||||
pointerStartY;
|
||||
scrollYMax;
|
||||
itemsGap = 0;
|
||||
items = [];
|
||||
itemSelector;
|
||||
handleClass = "drag-handle";
|
||||
off = [];
|
||||
offDrag = [];
|
||||
|
||||
constructor(element, itemSelector) {
|
||||
super();
|
||||
this.listContainer = element;
|
||||
this.itemSelector = itemSelector;
|
||||
|
||||
if (!this.listContainer) return;
|
||||
|
||||
this.off.push(this.on(this.listContainer, "mousedown", this.dragStart));
|
||||
this.off.push(this.on(this.listContainer, "touchstart", this.dragStart));
|
||||
this.off.push(this.on(document, "mouseup", this.dragEnd));
|
||||
this.off.push(this.on(document, "touchend", this.dragEnd));
|
||||
}
|
||||
|
||||
getAllItems() {
|
||||
if (!this.items?.length) {
|
||||
this.items = Array.from(this.listContainer.querySelectorAll(this.itemSelector));
|
||||
this.items.forEach((element) => {
|
||||
element.classList.add("is-idle");
|
||||
});
|
||||
}
|
||||
return this.items;
|
||||
}
|
||||
|
||||
getIdleItems() {
|
||||
return this.getAllItems().filter((item) => item.classList.contains("is-idle"));
|
||||
}
|
||||
|
||||
isItemAbove(item) {
|
||||
return item.hasAttribute("data-is-above");
|
||||
}
|
||||
|
||||
isItemToggled(item) {
|
||||
return item.hasAttribute("data-is-toggled");
|
||||
}
|
||||
|
||||
on(source, event, listener, options) {
|
||||
listener = listener.bind(this);
|
||||
source.addEventListener(event, listener, options);
|
||||
return () => source.removeEventListener(event, listener);
|
||||
}
|
||||
|
||||
dragStart(e) {
|
||||
if (e.target.classList.contains(this.handleClass)) {
|
||||
this.draggableItem = e.target.closest(this.itemSelector);
|
||||
}
|
||||
|
||||
if (!this.draggableItem) return;
|
||||
|
||||
this.pointerStartX = e.clientX || e.touches[0].clientX;
|
||||
this.pointerStartY = e.clientY || e.touches[0].clientY;
|
||||
this.scrollYMax = this.listContainer.scrollHeight - this.listContainer.clientHeight;
|
||||
|
||||
this.setItemsGap();
|
||||
this.initDraggableItem();
|
||||
this.initItemsState();
|
||||
|
||||
this.offDrag.push(this.on(document, "mousemove", this.drag));
|
||||
this.offDrag.push(this.on(document, "touchmove", this.drag, { passive: false }));
|
||||
|
||||
this.dispatchEvent(
|
||||
new CustomEvent("dragstart", {
|
||||
detail: { element: this.draggableItem, position: this.getAllItems().indexOf(this.draggableItem) },
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
setItemsGap() {
|
||||
if (this.getIdleItems().length <= 1) {
|
||||
this.itemsGap = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const item1 = this.getIdleItems()[0];
|
||||
const item2 = this.getIdleItems()[1];
|
||||
|
||||
const item1Rect = item1.getBoundingClientRect();
|
||||
const item2Rect = item2.getBoundingClientRect();
|
||||
|
||||
this.itemsGap = Math.abs(item1Rect.bottom - item2Rect.top);
|
||||
}
|
||||
|
||||
initItemsState() {
|
||||
this.getIdleItems().forEach((item, i) => {
|
||||
if (this.getAllItems().indexOf(this.draggableItem) > i) {
|
||||
item.dataset.isAbove = "";
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
initDraggableItem() {
|
||||
this.draggableItem.classList.remove("is-idle");
|
||||
this.draggableItem.classList.add("is-draggable");
|
||||
}
|
||||
|
||||
drag(e) {
|
||||
if (!this.draggableItem) return;
|
||||
|
||||
e.preventDefault();
|
||||
|
||||
const clientX = e.clientX || e.touches[0].clientX;
|
||||
const clientY = e.clientY || e.touches[0].clientY;
|
||||
|
||||
const listRect = this.listContainer.getBoundingClientRect();
|
||||
|
||||
if (clientY > listRect.bottom) {
|
||||
if (this.listContainer.scrollTop < this.scrollYMax) {
|
||||
this.listContainer.scrollBy(0, 10);
|
||||
this.pointerStartY -= 10;
|
||||
}
|
||||
} else if (clientY < listRect.top && this.listContainer.scrollTop > 0) {
|
||||
this.pointerStartY += 10;
|
||||
this.listContainer.scrollBy(0, -10);
|
||||
}
|
||||
|
||||
const pointerOffsetX = clientX - this.pointerStartX;
|
||||
const pointerOffsetY = clientY - this.pointerStartY;
|
||||
|
||||
this.updateIdleItemsStateAndPosition();
|
||||
this.draggableItem.style.transform = `translate(${pointerOffsetX}px, ${pointerOffsetY}px)`;
|
||||
}
|
||||
|
||||
updateIdleItemsStateAndPosition() {
|
||||
const draggableItemRect = this.draggableItem.getBoundingClientRect();
|
||||
const draggableItemY = draggableItemRect.top + draggableItemRect.height / 2;
|
||||
|
||||
// Update state
|
||||
this.getIdleItems().forEach((item) => {
|
||||
const itemRect = item.getBoundingClientRect();
|
||||
const itemY = itemRect.top + itemRect.height / 2;
|
||||
if (this.isItemAbove(item)) {
|
||||
if (draggableItemY <= itemY) {
|
||||
item.dataset.isToggled = "";
|
||||
} else {
|
||||
delete item.dataset.isToggled;
|
||||
}
|
||||
} else {
|
||||
if (draggableItemY >= itemY) {
|
||||
item.dataset.isToggled = "";
|
||||
} else {
|
||||
delete item.dataset.isToggled;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Update position
|
||||
this.getIdleItems().forEach((item) => {
|
||||
if (this.isItemToggled(item)) {
|
||||
const direction = this.isItemAbove(item) ? 1 : -1;
|
||||
item.style.transform = `translateY(${direction * (draggableItemRect.height + this.itemsGap)}px)`;
|
||||
} else {
|
||||
item.style.transform = "";
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
dragEnd() {
|
||||
if (!this.draggableItem) return;
|
||||
|
||||
this.applyNewItemsOrder();
|
||||
this.cleanup();
|
||||
}
|
||||
|
||||
applyNewItemsOrder() {
|
||||
const reorderedItems = [];
|
||||
|
||||
let oldPosition = -1;
|
||||
this.getAllItems().forEach((item, index) => {
|
||||
if (item === this.draggableItem) {
|
||||
oldPosition = index;
|
||||
return;
|
||||
}
|
||||
if (!this.isItemToggled(item)) {
|
||||
reorderedItems[index] = item;
|
||||
return;
|
||||
}
|
||||
const newIndex = this.isItemAbove(item) ? index + 1 : index - 1;
|
||||
reorderedItems[newIndex] = item;
|
||||
});
|
||||
|
||||
for (let index = 0; index < this.getAllItems().length; index++) {
|
||||
const item = reorderedItems[index];
|
||||
if (typeof item === "undefined") {
|
||||
reorderedItems[index] = this.draggableItem;
|
||||
}
|
||||
}
|
||||
|
||||
reorderedItems.forEach((item) => {
|
||||
this.listContainer.appendChild(item);
|
||||
});
|
||||
|
||||
this.items = reorderedItems;
|
||||
|
||||
this.dispatchEvent(
|
||||
new CustomEvent("dragend", {
|
||||
detail: { element: this.draggableItem, oldPosition, newPosition: reorderedItems.indexOf(this.draggableItem) },
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
this.itemsGap = 0;
|
||||
this.items = [];
|
||||
this.unsetDraggableItem();
|
||||
this.unsetItemState();
|
||||
|
||||
this.offDrag.forEach((f) => f());
|
||||
this.offDrag = [];
|
||||
}
|
||||
|
||||
unsetDraggableItem() {
|
||||
this.draggableItem.style = null;
|
||||
this.draggableItem.classList.remove("is-draggable");
|
||||
this.draggableItem.classList.add("is-idle");
|
||||
this.draggableItem = null;
|
||||
}
|
||||
|
||||
unsetItemState() {
|
||||
this.getIdleItems().forEach((item, i) => {
|
||||
delete item.dataset.isAbove;
|
||||
delete item.dataset.isToggled;
|
||||
item.style.transform = "";
|
||||
});
|
||||
}
|
||||
|
||||
dispose() {
|
||||
this.off.forEach((f) => f());
|
||||
}
|
||||
}
|
||||
307
comfy/web/scripts/ui/settings.js
Normal file
307
comfy/web/scripts/ui/settings.js
Normal file
@ -0,0 +1,307 @@
|
||||
import { $el } from "../ui.js";
|
||||
import { api } from "../api.js";
|
||||
import { ComfyDialog } from "./dialog.js";
|
||||
|
||||
export class ComfySettingsDialog extends ComfyDialog {
|
||||
constructor(app) {
|
||||
super();
|
||||
this.app = app;
|
||||
this.settingsValues = {};
|
||||
this.settingsLookup = {};
|
||||
this.element = $el(
|
||||
"dialog",
|
||||
{
|
||||
id: "comfy-settings-dialog",
|
||||
parent: document.body,
|
||||
},
|
||||
[
|
||||
$el("table.comfy-modal-content.comfy-table", [
|
||||
$el("caption", { textContent: "Settings" }),
|
||||
$el("tbody", { $: (tbody) => (this.textElement = tbody) }),
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: "Close",
|
||||
style: {
|
||||
cursor: "pointer",
|
||||
},
|
||||
onclick: () => {
|
||||
this.element.close();
|
||||
},
|
||||
}),
|
||||
]),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
get settings() {
|
||||
return Object.values(this.settingsLookup);
|
||||
}
|
||||
|
||||
async load() {
|
||||
if (this.app.storageLocation === "browser") {
|
||||
this.settingsValues = localStorage;
|
||||
} else {
|
||||
this.settingsValues = await api.getSettings();
|
||||
}
|
||||
|
||||
// Trigger onChange for any settings added before load
|
||||
for (const id in this.settingsLookup) {
|
||||
this.settingsLookup[id].onChange?.(this.settingsValues[this.getId(id)]);
|
||||
}
|
||||
}
|
||||
|
||||
getId(id) {
|
||||
if (this.app.storageLocation === "browser") {
|
||||
id = "Comfy.Settings." + id;
|
||||
}
|
||||
return id;
|
||||
}
|
||||
|
||||
getSettingValue(id, defaultValue) {
|
||||
let value = this.settingsValues[this.getId(id)];
|
||||
if(value != null) {
|
||||
if(this.app.storageLocation === "browser") {
|
||||
try {
|
||||
value = JSON.parse(value);
|
||||
} catch (error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return value ?? defaultValue;
|
||||
}
|
||||
|
||||
async setSettingValueAsync(id, value) {
|
||||
const json = JSON.stringify(value);
|
||||
localStorage["Comfy.Settings." + id] = json; // backwards compatibility for extensions keep setting in storage
|
||||
|
||||
let oldValue = this.getSettingValue(id, undefined);
|
||||
this.settingsValues[this.getId(id)] = value;
|
||||
|
||||
if (id in this.settingsLookup) {
|
||||
this.settingsLookup[id].onChange?.(value, oldValue);
|
||||
}
|
||||
|
||||
await api.storeSetting(id, value);
|
||||
}
|
||||
|
||||
setSettingValue(id, value) {
|
||||
this.setSettingValueAsync(id, value).catch((err) => {
|
||||
alert(`Error saving setting '${id}'`);
|
||||
console.error(err);
|
||||
});
|
||||
}
|
||||
|
||||
addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined }) {
|
||||
if (!id) {
|
||||
throw new Error("Settings must have an ID");
|
||||
}
|
||||
|
||||
if (id in this.settingsLookup) {
|
||||
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
|
||||
}
|
||||
|
||||
let skipOnChange = false;
|
||||
let value = this.getSettingValue(id);
|
||||
if (value == null) {
|
||||
if (this.app.isNewUserSession) {
|
||||
// Check if we have a localStorage value but not a setting value and we are a new user
|
||||
const localValue = localStorage["Comfy.Settings." + id];
|
||||
if (localValue) {
|
||||
value = JSON.parse(localValue);
|
||||
this.setSettingValue(id, value); // Store on the server
|
||||
}
|
||||
}
|
||||
if (value == null) {
|
||||
value = defaultValue;
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger initial setting of value
|
||||
if (!skipOnChange) {
|
||||
onChange?.(value, undefined);
|
||||
}
|
||||
|
||||
this.settingsLookup[id] = {
|
||||
id,
|
||||
onChange,
|
||||
name,
|
||||
render: () => {
|
||||
const setter = (v) => {
|
||||
if (onChange) {
|
||||
onChange(v, value);
|
||||
}
|
||||
|
||||
this.setSettingValue(id, v);
|
||||
value = v;
|
||||
};
|
||||
value = this.getSettingValue(id, defaultValue);
|
||||
|
||||
let element;
|
||||
const htmlID = id.replaceAll(".", "-");
|
||||
|
||||
const labelCell = $el("td", [
|
||||
$el("label", {
|
||||
for: htmlID,
|
||||
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
|
||||
textContent: name,
|
||||
}),
|
||||
]);
|
||||
|
||||
if (typeof type === "function") {
|
||||
element = type(name, setter, value, attrs);
|
||||
} else {
|
||||
switch (type) {
|
||||
case "boolean":
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el("input", {
|
||||
id: htmlID,
|
||||
type: "checkbox",
|
||||
checked: value,
|
||||
onchange: (event) => {
|
||||
const isChecked = event.target.checked;
|
||||
if (onChange !== undefined) {
|
||||
onChange(isChecked);
|
||||
}
|
||||
this.setSettingValue(id, isChecked);
|
||||
},
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
case "number":
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el("input", {
|
||||
type,
|
||||
value,
|
||||
id: htmlID,
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
},
|
||||
...attrs,
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
case "slider":
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el(
|
||||
"div",
|
||||
{
|
||||
style: {
|
||||
display: "grid",
|
||||
gridAutoFlow: "column",
|
||||
},
|
||||
},
|
||||
[
|
||||
$el("input", {
|
||||
...attrs,
|
||||
value,
|
||||
type: "range",
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
e.target.nextElementSibling.value = e.target.value;
|
||||
},
|
||||
}),
|
||||
$el("input", {
|
||||
...attrs,
|
||||
value,
|
||||
id: htmlID,
|
||||
type: "number",
|
||||
style: { maxWidth: "4rem" },
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
e.target.previousElementSibling.value = e.target.value;
|
||||
},
|
||||
}),
|
||||
]
|
||||
),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
case "combo":
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el(
|
||||
"select",
|
||||
{
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
},
|
||||
},
|
||||
(typeof options === "function" ? options(value) : options || []).map((opt) => {
|
||||
if (typeof opt === "string") {
|
||||
opt = { text: opt };
|
||||
}
|
||||
const v = opt.value ?? opt.text;
|
||||
return $el("option", {
|
||||
value: v,
|
||||
textContent: opt.text,
|
||||
selected: value + "" === v + "",
|
||||
});
|
||||
})
|
||||
),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
case "text":
|
||||
default:
|
||||
if (type !== "text") {
|
||||
console.warn(`Unsupported setting type '${type}, defaulting to text`);
|
||||
}
|
||||
|
||||
element = $el("tr", [
|
||||
labelCell,
|
||||
$el("td", [
|
||||
$el("input", {
|
||||
value,
|
||||
id: htmlID,
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
},
|
||||
...attrs,
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (tooltip) {
|
||||
element.title = tooltip;
|
||||
}
|
||||
|
||||
return element;
|
||||
},
|
||||
};
|
||||
|
||||
const self = this;
|
||||
return {
|
||||
get value() {
|
||||
return self.getSettingValue(id, defaultValue);
|
||||
},
|
||||
set value(v) {
|
||||
self.setSettingValue(id, v);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
show() {
|
||||
this.textElement.replaceChildren(
|
||||
$el(
|
||||
"tr",
|
||||
{
|
||||
style: { display: "none" },
|
||||
},
|
||||
[$el("th"), $el("th", { style: { width: "33%" } })]
|
||||
),
|
||||
...this.settings.sort((a, b) => a.name.localeCompare(b.name)).map((s) => s.render())
|
||||
);
|
||||
this.element.showModal();
|
||||
}
|
||||
}
|
||||
34
comfy/web/scripts/ui/spinner.css
Normal file
34
comfy/web/scripts/ui/spinner.css
Normal file
@ -0,0 +1,34 @@
|
||||
.lds-ring {
|
||||
display: inline-block;
|
||||
position: relative;
|
||||
width: 1em;
|
||||
height: 1em;
|
||||
}
|
||||
.lds-ring div {
|
||||
box-sizing: border-box;
|
||||
display: block;
|
||||
position: absolute;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
border: 0.15em solid #fff;
|
||||
border-radius: 50%;
|
||||
animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite;
|
||||
border-color: #fff transparent transparent transparent;
|
||||
}
|
||||
.lds-ring div:nth-child(1) {
|
||||
animation-delay: -0.45s;
|
||||
}
|
||||
.lds-ring div:nth-child(2) {
|
||||
animation-delay: -0.3s;
|
||||
}
|
||||
.lds-ring div:nth-child(3) {
|
||||
animation-delay: -0.15s;
|
||||
}
|
||||
@keyframes lds-ring {
|
||||
0% {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
9
comfy/web/scripts/ui/spinner.js
Normal file
9
comfy/web/scripts/ui/spinner.js
Normal file
@ -0,0 +1,9 @@
|
||||
import { addStylesheet } from "../utils.js";
|
||||
|
||||
addStylesheet(import.meta.url);
|
||||
|
||||
export function createSpinner() {
|
||||
const div = document.createElement("div");
|
||||
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
|
||||
return div.firstElementChild;
|
||||
}
|
||||
60
comfy/web/scripts/ui/toggleSwitch.js
Normal file
60
comfy/web/scripts/ui/toggleSwitch.js
Normal file
@ -0,0 +1,60 @@
|
||||
import { $el } from "../ui.js";
|
||||
|
||||
/**
|
||||
* @typedef { { text: string, value?: string, tooltip?: string } } ToggleSwitchItem
|
||||
*/
|
||||
/**
|
||||
* Creates a toggle switch element
|
||||
* @param { string } name
|
||||
* @param { Array<string | ToggleSwitchItem } items
|
||||
* @param { Object } [opts]
|
||||
* @param { (e: { item: ToggleSwitchItem, prev?: ToggleSwitchItem }) => void } [opts.onChange]
|
||||
*/
|
||||
export function toggleSwitch(name, items, { onChange } = {}) {
|
||||
let selectedIndex;
|
||||
let elements;
|
||||
|
||||
function updateSelected(index) {
|
||||
if (selectedIndex != null) {
|
||||
elements[selectedIndex].classList.remove("comfy-toggle-selected");
|
||||
}
|
||||
onChange?.({ item: items[index], prev: selectedIndex == null ? undefined : items[selectedIndex] });
|
||||
selectedIndex = index;
|
||||
elements[selectedIndex].classList.add("comfy-toggle-selected");
|
||||
}
|
||||
|
||||
elements = items.map((item, i) => {
|
||||
if (typeof item === "string") item = { text: item };
|
||||
if (!item.value) item.value = item.text;
|
||||
|
||||
const toggle = $el(
|
||||
"label",
|
||||
{
|
||||
textContent: item.text,
|
||||
title: item.tooltip ?? "",
|
||||
},
|
||||
$el("input", {
|
||||
name,
|
||||
type: "radio",
|
||||
value: item.value ?? item.text,
|
||||
checked: item.selected,
|
||||
onchange: () => {
|
||||
updateSelected(i);
|
||||
},
|
||||
})
|
||||
);
|
||||
if (item.selected) {
|
||||
updateSelected(i);
|
||||
}
|
||||
return toggle;
|
||||
});
|
||||
|
||||
const container = $el("div.comfy-toggle-switch", elements);
|
||||
|
||||
if (selectedIndex == null) {
|
||||
elements[0].children[0].checked = true;
|
||||
updateSelected(0);
|
||||
}
|
||||
|
||||
return container;
|
||||
}
|
||||
135
comfy/web/scripts/ui/userSelection.css
Normal file
135
comfy/web/scripts/ui/userSelection.css
Normal file
@ -0,0 +1,135 @@
|
||||
.comfy-user-selection {
|
||||
width: 100vw;
|
||||
height: 100vh;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
z-index: 999;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-family: sans-serif;
|
||||
background: linear-gradient(var(--tr-even-bg-color), var(--tr-odd-bg-color));
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner {
|
||||
background: var(--comfy-menu-bg);
|
||||
margin-top: -30vh;
|
||||
padding: 20px 40px;
|
||||
border-radius: 10px;
|
||||
min-width: 365px;
|
||||
position: relative;
|
||||
box-shadow: 0 0 20px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner form {
|
||||
width: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner h1 {
|
||||
margin: 10px 0 30px 0;
|
||||
font-weight: normal;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner label {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.comfy-user-selection input,
|
||||
.comfy-user-selection select {
|
||||
background-color: var(--comfy-input-bg);
|
||||
color: var(--input-text);
|
||||
border: 0;
|
||||
border-radius: 5px;
|
||||
padding: 5px;
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.comfy-user-selection input::placeholder {
|
||||
color: var(--descrip-text);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.comfy-user-existing {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.no-users .comfy-user-existing {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner .or-separator {
|
||||
margin: 10px 0;
|
||||
padding: 10px;
|
||||
display: block;
|
||||
text-align: center;
|
||||
width: 100%;
|
||||
color: var(--descrip-text);
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner .or-separator {
|
||||
overflow: hidden;
|
||||
text-align: center;
|
||||
margin-left: -10px;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner .or-separator::before,
|
||||
.comfy-user-selection-inner .or-separator::after {
|
||||
content: "";
|
||||
background-color: var(--border-color);
|
||||
position: relative;
|
||||
height: 1px;
|
||||
vertical-align: middle;
|
||||
display: inline-block;
|
||||
width: calc(50% - 20px);
|
||||
top: -1px;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner .or-separator::before {
|
||||
right: 10px;
|
||||
margin-left: -50%;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner .or-separator::after {
|
||||
left: 10px;
|
||||
margin-right: -50%;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner section {
|
||||
width: 100%;
|
||||
padding: 10px;
|
||||
margin: -10px;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner section.selected {
|
||||
background: var(--border-color);
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner footer {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.comfy-user-selection-inner .comfy-user-error {
|
||||
color: var(--error-text);
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.comfy-user-button-next {
|
||||
font-size: 16px;
|
||||
padding: 6px 10px;
|
||||
width: 100px;
|
||||
display: flex;
|
||||
gap: 5px;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
114
comfy/web/scripts/ui/userSelection.js
Normal file
114
comfy/web/scripts/ui/userSelection.js
Normal file
@ -0,0 +1,114 @@
|
||||
import { api } from "../api.js";
|
||||
import { $el } from "../ui.js";
|
||||
import { addStylesheet } from "../utils.js";
|
||||
import { createSpinner } from "./spinner.js";
|
||||
|
||||
export class UserSelectionScreen {
|
||||
async show(users, user) {
|
||||
// This will rarely be hit so move the loading to on demand
|
||||
await addStylesheet(import.meta.url);
|
||||
const userSelection = document.getElementById("comfy-user-selection");
|
||||
userSelection.style.display = "";
|
||||
return new Promise((resolve) => {
|
||||
const input = userSelection.getElementsByTagName("input")[0];
|
||||
const select = userSelection.getElementsByTagName("select")[0];
|
||||
const inputSection = input.closest("section");
|
||||
const selectSection = select.closest("section");
|
||||
const form = userSelection.getElementsByTagName("form")[0];
|
||||
const error = userSelection.getElementsByClassName("comfy-user-error")[0];
|
||||
const button = userSelection.getElementsByClassName("comfy-user-button-next")[0];
|
||||
|
||||
let inputActive = null;
|
||||
input.addEventListener("focus", () => {
|
||||
inputSection.classList.add("selected");
|
||||
selectSection.classList.remove("selected");
|
||||
inputActive = true;
|
||||
});
|
||||
select.addEventListener("focus", () => {
|
||||
inputSection.classList.remove("selected");
|
||||
selectSection.classList.add("selected");
|
||||
inputActive = false;
|
||||
select.style.color = "";
|
||||
});
|
||||
select.addEventListener("blur", () => {
|
||||
if (!select.value) {
|
||||
select.style.color = "var(--descrip-text)";
|
||||
}
|
||||
});
|
||||
|
||||
form.addEventListener("submit", async (e) => {
|
||||
e.preventDefault();
|
||||
if (inputActive == null) {
|
||||
error.textContent = "Please enter a username or select an existing user.";
|
||||
} else if (inputActive) {
|
||||
const username = input.value.trim();
|
||||
if (!username) {
|
||||
error.textContent = "Please enter a username.";
|
||||
return;
|
||||
}
|
||||
|
||||
// Create new user
|
||||
input.disabled = select.disabled = input.readonly = select.readonly = true;
|
||||
const spinner = createSpinner();
|
||||
button.prepend(spinner);
|
||||
try {
|
||||
const resp = await api.createUser(username);
|
||||
if (resp.status >= 300) {
|
||||
let message = "Error creating user: " + resp.status + " " + resp.statusText;
|
||||
try {
|
||||
const res = await resp.json();
|
||||
if(res.error) {
|
||||
message = res.error;
|
||||
}
|
||||
} catch (error) {
|
||||
}
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
resolve({ username, userId: await resp.json(), created: true });
|
||||
} catch (err) {
|
||||
spinner.remove();
|
||||
error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred.";
|
||||
input.disabled = select.disabled = input.readonly = select.readonly = false;
|
||||
return;
|
||||
}
|
||||
} else if (!select.value) {
|
||||
error.textContent = "Please select an existing user.";
|
||||
return;
|
||||
} else {
|
||||
resolve({ username: users[select.value], userId: select.value, created: false });
|
||||
}
|
||||
});
|
||||
|
||||
if (user) {
|
||||
const name = localStorage["Comfy.userName"];
|
||||
if (name) {
|
||||
input.value = name;
|
||||
}
|
||||
}
|
||||
if (input.value) {
|
||||
// Focus the input, do this separately as sometimes browsers like to fill in the value
|
||||
input.focus();
|
||||
}
|
||||
|
||||
const userIds = Object.keys(users ?? {});
|
||||
if (userIds.length) {
|
||||
for (const u of userIds) {
|
||||
$el("option", { textContent: users[u], value: u, parent: select });
|
||||
}
|
||||
select.style.color = "var(--descrip-text)";
|
||||
|
||||
if (select.value) {
|
||||
// Focus the select, do this separately as sometimes browsers like to fill in the value
|
||||
select.focus();
|
||||
}
|
||||
} else {
|
||||
userSelection.classList.add("no-users");
|
||||
input.focus();
|
||||
}
|
||||
}).then((r) => {
|
||||
userSelection.remove();
|
||||
return r;
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,5 @@
|
||||
import { $el } from "./ui.js";
|
||||
|
||||
// Simple date formatter
|
||||
const parts = {
|
||||
d: (d) => d.getDate(),
|
||||
@ -65,3 +67,22 @@ export function applyTextReplacements(app, value) {
|
||||
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
|
||||
});
|
||||
}
|
||||
|
||||
export async function addStylesheet(urlOrFile, relativeTo) {
|
||||
return new Promise((res, rej) => {
|
||||
let url;
|
||||
if (urlOrFile.endsWith(".js")) {
|
||||
url = urlOrFile.substr(0, urlOrFile.length - 2) + "css";
|
||||
} else {
|
||||
url = new URL(urlOrFile, relativeTo ?? `${window.location.protocol}//${window.location.host}`).toString();
|
||||
}
|
||||
$el("link", {
|
||||
parent: document.head,
|
||||
rel: "stylesheet",
|
||||
type: "text/css",
|
||||
href: url,
|
||||
onload: res,
|
||||
onerror: rej,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -1,6 +1,19 @@
|
||||
import { api } from "./api.js"
|
||||
import "./domWidget.js";
|
||||
|
||||
let controlValueRunBefore = false;
|
||||
export function updateControlWidgetLabel(widget) {
|
||||
let replacement = "after";
|
||||
let find = "before";
|
||||
if (controlValueRunBefore) {
|
||||
[find, replacement] = [replacement, find]
|
||||
}
|
||||
widget.label = (widget.label ?? widget.name).replace(find, replacement);
|
||||
}
|
||||
|
||||
const IS_CONTROL_WIDGET = Symbol();
|
||||
const HAS_EXECUTED = Symbol();
|
||||
|
||||
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
|
||||
let defaultVal = inputData[1]["default"];
|
||||
let { min, max, step, round} = inputData[1];
|
||||
@ -62,6 +75,8 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
|
||||
serialize: false, // Don't include this in prompt.
|
||||
}
|
||||
);
|
||||
valueControl[IS_CONTROL_WIDGET] = true;
|
||||
updateControlWidgetLabel(valueControl);
|
||||
widgets.push(valueControl);
|
||||
|
||||
const isCombo = targetWidget.type === "combo";
|
||||
@ -76,10 +91,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
|
||||
serialize: false, // Don't include this in prompt.
|
||||
}
|
||||
);
|
||||
updateControlWidgetLabel(comboFilter);
|
||||
|
||||
widgets.push(comboFilter);
|
||||
}
|
||||
|
||||
valueControl.afterQueued = () => {
|
||||
const applyWidgetControl = () => {
|
||||
var v = valueControl.value;
|
||||
|
||||
if (isCombo && v !== "fixed") {
|
||||
@ -159,6 +176,23 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
|
||||
targetWidget.callback(targetWidget.value);
|
||||
}
|
||||
};
|
||||
|
||||
valueControl.beforeQueued = () => {
|
||||
if (controlValueRunBefore) {
|
||||
// Don't run on first execution
|
||||
if (valueControl[HAS_EXECUTED]) {
|
||||
applyWidgetControl();
|
||||
}
|
||||
}
|
||||
valueControl[HAS_EXECUTED] = true;
|
||||
};
|
||||
|
||||
valueControl.afterQueued = () => {
|
||||
if (!controlValueRunBefore) {
|
||||
applyWidgetControl();
|
||||
}
|
||||
};
|
||||
|
||||
return widgets;
|
||||
};
|
||||
|
||||
@ -224,6 +258,34 @@ function isSlider(display, app) {
|
||||
return (display==="slider") ? "slider" : "number"
|
||||
}
|
||||
|
||||
export function initWidgets(app) {
|
||||
app.ui.settings.addSetting({
|
||||
id: "Comfy.WidgetControlMode",
|
||||
name: "Widget Value Control Mode",
|
||||
type: "combo",
|
||||
defaultValue: "after",
|
||||
options: ["before", "after"],
|
||||
tooltip: "Controls when widget values are updated (randomize/increment/decrement), either before the prompt is queued or after.",
|
||||
onChange(value) {
|
||||
controlValueRunBefore = value === "before";
|
||||
for (const n of app.graph._nodes) {
|
||||
if (!n.widgets) continue;
|
||||
for (const w of n.widgets) {
|
||||
if (w[IS_CONTROL_WIDGET]) {
|
||||
updateControlWidgetLabel(w);
|
||||
if (w.linkedWidgets) {
|
||||
for (const l of w.linkedWidgets) {
|
||||
updateControlWidgetLabel(l);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
app.graph.setDirtyCanvas(true);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export const ComfyWidgets = {
|
||||
"INT:seed": seedWidget,
|
||||
"INT:noise_seed": seedWidget,
|
||||
|
||||
@ -121,6 +121,8 @@ body {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch,
|
||||
.comfy-btn,
|
||||
.comfy-menu > button,
|
||||
.comfy-menu-btns button,
|
||||
.comfy-menu .comfy-list button,
|
||||
@ -133,6 +135,7 @@ body {
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.comfy-btn:hover:not(:disabled),
|
||||
.comfy-menu > button:hover,
|
||||
.comfy-menu-btns button:hover,
|
||||
.comfy-menu .comfy-list button:hover,
|
||||
@ -143,6 +146,12 @@ body {
|
||||
}
|
||||
|
||||
.comfy-menu span.drag-handle {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
span.drag-handle {
|
||||
width: 10px;
|
||||
height: 20px;
|
||||
display: inline-block;
|
||||
@ -158,12 +167,9 @@ body {
|
||||
letter-spacing: 2px;
|
||||
color: var(--drag-text);
|
||||
text-shadow: 1px 0 1px black;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.comfy-menu span.drag-handle::after {
|
||||
span.drag-handle::after {
|
||||
content: '.. .. ..';
|
||||
}
|
||||
|
||||
@ -429,6 +435,43 @@ dialog::backdrop {
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch {
|
||||
border-width: 2px;
|
||||
display: flex;
|
||||
background-color: var(--comfy-input-bg);
|
||||
margin: 2px 0;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch label {
|
||||
padding: 2px 0px 3px 6px;
|
||||
flex: auto;
|
||||
border-radius: 8px;
|
||||
align-items: center;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch label:first-child {
|
||||
border-top-left-radius: 8px;
|
||||
border-bottom-left-radius: 8px;
|
||||
}
|
||||
.comfy-toggle-switch label:last-child {
|
||||
border-top-right-radius: 8px;
|
||||
border-bottom-right-radius: 8px;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch .comfy-toggle-selected {
|
||||
background-color: var(--comfy-menu-bg);
|
||||
}
|
||||
|
||||
#extraOptions {
|
||||
padding: 4px;
|
||||
background-color: var(--bg-color);
|
||||
margin-bottom: 4px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Search box */
|
||||
|
||||
.litegraph.litesearchbox {
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from comfy import samplers
|
||||
from comfy import model_management
|
||||
from comfy import sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
from comfy.cmd import latent_preview
|
||||
@ -26,9 +27,8 @@ class BasicScheduler:
|
||||
if denoise < 1.0:
|
||||
total_steps = int(steps/denoise)
|
||||
|
||||
inner_model = model.patch_model(patch_weights=False)
|
||||
sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
|
||||
model.unpatch_model()
|
||||
model_management.load_models_gpu([model])
|
||||
sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
|
||||
sigmas = sigmas[-(steps + 1):]
|
||||
return (sigmas, )
|
||||
|
||||
@ -106,9 +106,8 @@ class SDTurboScheduler:
|
||||
def get_sigmas(self, model, steps, denoise):
|
||||
start_step = 10 - int(10 * denoise)
|
||||
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
|
||||
inner_model = model.patch_model(patch_weights=False)
|
||||
sigmas = inner_model.model_sampling.sigma(timesteps)
|
||||
model.unpatch_model()
|
||||
model_management.load_models_gpu([model])
|
||||
sigmas = model.model.model_sampling.sigma(timesteps)
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
return (sigmas, )
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ class FreeU:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "model_patches"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
@ -73,7 +73,7 @@ class FreeU_V2:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "model_patches"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
|
||||
@ -32,29 +32,29 @@ class HyperTile:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "model_patches"
|
||||
|
||||
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
|
||||
apply_to = set()
|
||||
temp = model_channels
|
||||
for x in range(max_depth + 1):
|
||||
apply_to.add(temp)
|
||||
temp *= 2
|
||||
|
||||
latent_tile_size = max(32, tile_size) // 8
|
||||
self.temp = None
|
||||
|
||||
def hypertile_in(q, k, v, extra_options):
|
||||
if q.shape[-1] in apply_to:
|
||||
model_chans = q.shape[-2]
|
||||
orig_shape = extra_options['original_shape']
|
||||
apply_to = []
|
||||
for i in range(max_depth + 1):
|
||||
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
|
||||
|
||||
if model_chans in apply_to:
|
||||
shape = extra_options["original_shape"]
|
||||
aspect_ratio = shape[-1] / shape[-2]
|
||||
|
||||
hw = q.size(1)
|
||||
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||
|
||||
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
|
||||
factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
|
||||
nh = random_divisor(h, latent_tile_size * factor, swap_size)
|
||||
nw = random_divisor(w, latent_tile_size * factor, swap_size)
|
||||
|
||||
|
||||
@ -122,10 +122,34 @@ class LatentBatch:
|
||||
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
||||
return (samples_out,)
|
||||
|
||||
class LatentBatchSeedBehavior:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"seed_behavior": (["random", "fixed"],),}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples, seed_behavior):
|
||||
samples_out = samples.copy()
|
||||
latent = samples["samples"]
|
||||
if seed_behavior == "random":
|
||||
if 'batch_index' in samples_out:
|
||||
samples_out.pop('batch_index')
|
||||
elif seed_behavior == "fixed":
|
||||
batch_number = samples_out.get("batch_index", [0])[0]
|
||||
samples_out["batch_index"] = [batch_number] * latent.shape[0]
|
||||
|
||||
return (samples_out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
"LatentMultiply": LatentMultiply,
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
"LatentBatch": LatentBatch,
|
||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import scipy.ndimage
|
||||
import torch
|
||||
import comfy.utils
|
||||
from comfy import utils
|
||||
|
||||
from comfy.nodes.common import MAX_RESOLUTION
|
||||
|
||||
@ -11,7 +11,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||
if resize_source:
|
||||
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
|
||||
|
||||
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
|
||||
source = utils.repeat_to_batch_size(source, destination.shape[0])
|
||||
|
||||
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
|
||||
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
|
||||
@ -24,7 +24,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||
else:
|
||||
mask = mask.to(destination.device, copy=True)
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
|
||||
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
|
||||
mask = utils.repeat_to_batch_size(mask, source.shape[0])
|
||||
|
||||
# calculate the bounds of the source that will be overlapping the destination
|
||||
# this prevents the source trying to overwrite latent pixels that are out of bounds
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
from comfy import utils
|
||||
|
||||
class PatchModelAddDownscale:
|
||||
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||
@ -27,12 +27,12 @@ class PatchModelAddDownscale:
|
||||
if transformer_options["block"][1] == block_number:
|
||||
sigma = transformer_options["sigmas"][0].item()
|
||||
if sigma <= sigma_start and sigma >= sigma_end:
|
||||
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
|
||||
h = utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
|
||||
return h
|
||||
|
||||
def output_block_patch(h, hsp, transformer_options):
|
||||
if h.shape[2] != hsp.shape[2]:
|
||||
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
|
||||
h = utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
|
||||
return h, hsp
|
||||
|
||||
m = model.clone()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from comfy import sd
|
||||
from comfy import sd, utils
|
||||
from comfy import model_base
|
||||
import comfy.model_management
|
||||
from comfy import model_management
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
import json
|
||||
@ -118,6 +118,48 @@ class ModelMergeBlocks:
|
||||
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
||||
return (m, )
|
||||
|
||||
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
|
||||
prompt_info = ""
|
||||
if prompt is not None:
|
||||
prompt_info = json.dumps(prompt)
|
||||
|
||||
metadata = {}
|
||||
|
||||
enable_modelspec = True
|
||||
if isinstance(model.model, model_base.SDXL):
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
|
||||
elif isinstance(model.model, model_base.SDXLRefiner):
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
|
||||
else:
|
||||
enable_modelspec = False
|
||||
|
||||
if enable_modelspec:
|
||||
metadata["modelspec.sai_model_spec"] = "1.0.0"
|
||||
metadata["modelspec.implementation"] = "sgm"
|
||||
metadata["modelspec.title"] = "{} {}".format(filename, counter)
|
||||
|
||||
#TODO:
|
||||
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
|
||||
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
||||
# "v2-inpainting"
|
||||
|
||||
if model.model.model_type == model_base.ModelType.EPS:
|
||||
metadata["modelspec.predict_key"] = "epsilon"
|
||||
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
|
||||
metadata["modelspec.predict_key"] = "v"
|
||||
|
||||
if not args.disable_metadata:
|
||||
metadata["prompt"] = prompt_info
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
|
||||
|
||||
class CheckpointSave:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@ -136,46 +178,7 @@ class CheckpointSave:
|
||||
CATEGORY = "advanced/model_merging"
|
||||
|
||||
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
prompt_info = ""
|
||||
if prompt is not None:
|
||||
prompt_info = json.dumps(prompt)
|
||||
|
||||
metadata = {}
|
||||
|
||||
enable_modelspec = True
|
||||
if isinstance(model.model, model_base.SDXL):
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
|
||||
elif isinstance(model.model, model_base.SDXLRefiner):
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
|
||||
else:
|
||||
enable_modelspec = False
|
||||
|
||||
if enable_modelspec:
|
||||
metadata["modelspec.sai_model_spec"] = "1.0.0"
|
||||
metadata["modelspec.implementation"] = "sgm"
|
||||
metadata["modelspec.title"] = "{} {}".format(filename, counter)
|
||||
|
||||
#TODO:
|
||||
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
|
||||
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
||||
# "v2-inpainting"
|
||||
|
||||
if model.model.model_type == model_base.ModelType.EPS:
|
||||
metadata["modelspec.predict_key"] = "epsilon"
|
||||
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
|
||||
metadata["modelspec.predict_key"] = "v"
|
||||
|
||||
if not args.disable_metadata:
|
||||
metadata["prompt"] = prompt_info
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
|
||||
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
||||
return {}
|
||||
|
||||
class CLIPSave:
|
||||
@ -205,7 +208,7 @@ class CLIPSave:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
comfy.model_management.load_models_gpu([clip.load_model()])
|
||||
model_management.load_models_gpu([clip.load_model()])
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
for prefix in ["clip_l.", "clip_g.", ""]:
|
||||
@ -229,9 +232,9 @@ class CLIPSave:
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
|
||||
current_clip_sd = utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
|
||||
|
||||
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
|
||||
utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
|
||||
return {}
|
||||
|
||||
class VAESave:
|
||||
@ -265,7 +268,7 @@ class VAESave:
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
||||
utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
||||
return {}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
187
comfy_extras/nodes/nodes_photomaker.py
Normal file
187
comfy_extras/nodes/nodes_photomaker.py
Normal file
@ -0,0 +1,187 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy import clip_model
|
||||
from comfy import clip_vision
|
||||
from comfy import ops
|
||||
|
||||
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
|
||||
VISION_CONFIG_DICT = {
|
||||
"hidden_size": 1024,
|
||||
"image_size": 224,
|
||||
"intermediate_size": 4096,
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 768,
|
||||
"hidden_act": "quick_gelu",
|
||||
}
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=ops):
|
||||
super().__init__()
|
||||
if use_residual:
|
||||
assert in_dim == out_dim
|
||||
self.layernorm = operations.LayerNorm(in_dim)
|
||||
self.fc1 = operations.Linear(in_dim, hidden_dim)
|
||||
self.fc2 = operations.Linear(hidden_dim, out_dim)
|
||||
self.use_residual = use_residual
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
if self.use_residual:
|
||||
x = x + residual
|
||||
return x
|
||||
|
||||
|
||||
class FuseModule(nn.Module):
|
||||
def __init__(self, embed_dim, operations):
|
||||
super().__init__()
|
||||
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
|
||||
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
|
||||
self.layer_norm = operations.LayerNorm(embed_dim)
|
||||
|
||||
def fuse_fn(self, prompt_embeds, id_embeds):
|
||||
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
|
||||
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
|
||||
stacked_id_embeds = self.mlp2(stacked_id_embeds)
|
||||
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
|
||||
return stacked_id_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompt_embeds,
|
||||
id_embeds,
|
||||
class_tokens_mask,
|
||||
) -> torch.Tensor:
|
||||
# id_embeds shape: [b, max_num_inputs, 1, 2048]
|
||||
id_embeds = id_embeds.to(prompt_embeds.dtype)
|
||||
num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
|
||||
batch_size, max_num_inputs = id_embeds.shape[:2]
|
||||
# seq_length: 77
|
||||
seq_length = prompt_embeds.shape[1]
|
||||
# flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
|
||||
flat_id_embeds = id_embeds.view(
|
||||
-1, id_embeds.shape[-2], id_embeds.shape[-1]
|
||||
)
|
||||
# valid_id_mask [b*max_num_inputs]
|
||||
valid_id_mask = (
|
||||
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
|
||||
< num_inputs[:, None]
|
||||
)
|
||||
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
|
||||
|
||||
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
|
||||
class_tokens_mask = class_tokens_mask.view(-1)
|
||||
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
|
||||
# slice out the image token embeddings
|
||||
image_token_embeds = prompt_embeds[class_tokens_mask]
|
||||
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
|
||||
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
|
||||
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
|
||||
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
|
||||
return updated_prompt_embeds
|
||||
|
||||
class PhotoMakerIDEncoder(clip_model.CLIPVisionModelProjection):
|
||||
def __init__(self):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
|
||||
super().__init__(VISION_CONFIG_DICT, dtype, offload_device, ops.manual_cast)
|
||||
self.visual_projection_2 = ops.manual_cast.Linear(1024, 1280, bias=False)
|
||||
self.fuse_module = FuseModule(2048, ops.manual_cast)
|
||||
|
||||
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
|
||||
b, num_inputs, c, h, w = id_pixel_values.shape
|
||||
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
|
||||
|
||||
shared_id_embeds = self.vision_model(id_pixel_values)[2]
|
||||
id_embeds = self.visual_projection(shared_id_embeds)
|
||||
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
|
||||
|
||||
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
|
||||
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
|
||||
|
||||
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
|
||||
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
|
||||
|
||||
return updated_prompt_embeds
|
||||
|
||||
|
||||
class PhotoMakerLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
|
||||
|
||||
RETURN_TYPES = ("PHOTOMAKER",)
|
||||
FUNCTION = "load_photomaker_model"
|
||||
|
||||
CATEGORY = "_for_testing/photomaker"
|
||||
|
||||
def load_photomaker_model(self, photomaker_model_name):
|
||||
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
|
||||
photomaker_model = PhotoMakerIDEncoder()
|
||||
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
||||
if "id_encoder" in data:
|
||||
data = data["id_encoder"]
|
||||
photomaker_model.load_state_dict(data)
|
||||
return (photomaker_model,)
|
||||
|
||||
|
||||
class PhotoMakerEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "photomaker": ("PHOTOMAKER",),
|
||||
"image": ("IMAGE",),
|
||||
"clip": ("CLIP", ),
|
||||
"text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_photomaker"
|
||||
|
||||
CATEGORY = "_for_testing/photomaker"
|
||||
|
||||
def apply_photomaker(self, photomaker, image, clip, text):
|
||||
special_token = "photomaker"
|
||||
pixel_values = clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
|
||||
try:
|
||||
index = text.split(" ").index(special_token) + 1
|
||||
except ValueError:
|
||||
index = -1
|
||||
tokens = clip.tokenize(text, return_word_ids=True)
|
||||
out_tokens = {}
|
||||
for k in tokens:
|
||||
out_tokens[k] = []
|
||||
for t in tokens[k]:
|
||||
f = list(filter(lambda x: x[2] != index, t))
|
||||
while len(f) < len(t):
|
||||
f.append(t[-1])
|
||||
out_tokens[k].append(f)
|
||||
|
||||
cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
|
||||
|
||||
if index > 0:
|
||||
token_index = index - 1
|
||||
num_id_images = 1
|
||||
class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)]
|
||||
out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
|
||||
class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
|
||||
else:
|
||||
out = cond
|
||||
|
||||
return ([[out, {"pooled_output": pooled}]], )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PhotoMakerLoader": PhotoMakerLoader,
|
||||
"PhotoMakerEncode": PhotoMakerEncode,
|
||||
}
|
||||
|
||||
@ -33,6 +33,7 @@ class Blend:
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
|
||||
image2 = image2.to(image1.device)
|
||||
if image1.shape != image2.shape:
|
||||
image2 = image2.permute(0, 3, 1, 2)
|
||||
image2 = utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
|
||||
|
||||
@ -58,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
ratio = math.ceil(math.sqrt(lh * lw / hw1))
|
||||
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
|
||||
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
||||
|
||||
# Reshape
|
||||
@ -143,6 +143,8 @@ class SelfAttentionGuidance:
|
||||
sigma = args["sigma"]
|
||||
model_options = args["model_options"]
|
||||
x = args["input"]
|
||||
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
|
||||
return cfg_result
|
||||
|
||||
# create the adversarially blurred image
|
||||
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.nodes.common import MAX_RESOLUTION
|
||||
from comfy import utils
|
||||
@ -49,13 +48,57 @@ class StableZero123_Conditioning:
|
||||
encode_pixels = pixels[:,:,:,:3]
|
||||
t = vae.encode(encode_pixels)
|
||||
cam_embeds = camera_embeddings(elevation, azimuth)
|
||||
cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
|
||||
cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
|
||||
|
||||
positive = [[cond, {"concat_latent_image": t}]]
|
||||
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||
return (positive, negative, {"samples":latent})
|
||||
|
||||
class StableZero123_Conditioning_Batched:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_vision": ("CLIP_VISION",),
|
||||
"init_image": ("IMAGE",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/3d_models"
|
||||
|
||||
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
|
||||
output = clip_vision.encode_image(init_image)
|
||||
pooled = output.image_embeds.unsqueeze(0)
|
||||
pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
|
||||
encode_pixels = pixels[:,:,:,:3]
|
||||
t = vae.encode(encode_pixels)
|
||||
|
||||
cam_embeds = []
|
||||
for i in range(batch_size):
|
||||
cam_embeds.append(camera_embeddings(elevation, azimuth))
|
||||
elevation += elevation_batch_increment
|
||||
azimuth += azimuth_batch_increment
|
||||
|
||||
cam_embeds = torch.cat(cam_embeds, dim=0)
|
||||
cond = torch.cat([utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
|
||||
|
||||
positive = [[cond, {"concat_latent_image": t}]]
|
||||
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StableZero123_Conditioning": StableZero123_Conditioning,
|
||||
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@ import torch
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
from comfy.cmd import folder_paths
|
||||
from . import nodes_model_merging
|
||||
|
||||
|
||||
class ImageOnlyCheckpointLoader:
|
||||
@ -78,10 +79,26 @@ class VideoLinearCFGGuidance:
|
||||
m.set_model_sampler_cfg_function(linear_cfg)
|
||||
return (m, )
|
||||
|
||||
class ImageOnlyCheckpointSave(nodes_model_merging.CheckpointSave):
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"clip_vision": ("CLIP_VISION",),
|
||||
"vae": ("VAE",),
|
||||
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||
|
||||
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
||||
return {}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
||||
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
||||
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
||||
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
|
||||
0
models/photomaker/put_photomaker_models_here
Normal file
0
models/photomaker/put_photomaker_models_here
Normal file
@ -1,3 +1,4 @@
|
||||
{
|
||||
"presets": ["@babel/preset-env"]
|
||||
"presets": ["@babel/preset-env"],
|
||||
"plugins": ["babel-plugin-transform-import-meta"]
|
||||
}
|
||||
|
||||
20
tests-ui/package-lock.json
generated
20
tests-ui/package-lock.json
generated
@ -11,6 +11,7 @@
|
||||
"devDependencies": {
|
||||
"@babel/preset-env": "^7.22.20",
|
||||
"@types/jest": "^29.5.5",
|
||||
"babel-plugin-transform-import-meta": "^2.2.1",
|
||||
"jest": "^29.7.0",
|
||||
"jest-environment-jsdom": "^29.7.0"
|
||||
}
|
||||
@ -2591,6 +2592,19 @@
|
||||
"@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/babel-plugin-transform-import-meta": {
|
||||
"version": "2.2.1",
|
||||
"resolved": "https://registry.npmjs.org/babel-plugin-transform-import-meta/-/babel-plugin-transform-import-meta-2.2.1.tgz",
|
||||
"integrity": "sha512-AxNh27Pcg8Kt112RGa3Vod2QS2YXKKJ6+nSvRtv7qQTJAdx0MZa4UHZ4lnxHUWA2MNbLuZQv5FVab4P1CoLOWw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@babel/template": "^7.4.4",
|
||||
"tslib": "^2.4.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@babel/core": "^7.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/babel-preset-current-node-syntax": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz",
|
||||
@ -5233,6 +5247,12 @@
|
||||
"node": ">=12"
|
||||
}
|
||||
},
|
||||
"node_modules/tslib": {
|
||||
"version": "2.6.2",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
|
||||
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/type-detect": {
|
||||
"version": "4.0.8",
|
||||
"resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz",
|
||||
|
||||
@ -24,6 +24,7 @@
|
||||
"devDependencies": {
|
||||
"@babel/preset-env": "^7.22.20",
|
||||
"@types/jest": "^29.5.5",
|
||||
"babel-plugin-transform-import-meta": "^2.2.1",
|
||||
"jest": "^29.7.0",
|
||||
"jest-environment-jsdom": "^29.7.0"
|
||||
}
|
||||
|
||||
295
tests-ui/tests/users.test.js
Normal file
295
tests-ui/tests/users.test.js
Normal file
@ -0,0 +1,295 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
const { start } = require("../utils");
|
||||
const lg = require("../utils/litegraph");
|
||||
|
||||
describe("users", () => {
|
||||
beforeEach(() => {
|
||||
lg.setup(global);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
lg.teardown(global);
|
||||
});
|
||||
|
||||
function expectNoUserScreen() {
|
||||
// Ensure login isnt visible
|
||||
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
|
||||
expect(selection["style"].display).toBe("none");
|
||||
const menu = document.querySelectorAll(".comfy-menu")?.[0];
|
||||
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
|
||||
}
|
||||
|
||||
describe("multi-user", () => {
|
||||
function mockAddStylesheet() {
|
||||
const utils = require("../../comfy/web/scripts/utils");
|
||||
utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve());
|
||||
}
|
||||
|
||||
async function waitForUserScreenShow() {
|
||||
mockAddStylesheet();
|
||||
|
||||
// Wait for "show" to be called
|
||||
const { UserSelectionScreen } = require("../../comfy/web/scripts/ui/userSelection");
|
||||
let resolve, reject;
|
||||
const fn = UserSelectionScreen.prototype.show;
|
||||
const p = new Promise((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => {
|
||||
const res = fn(...args);
|
||||
await new Promise(process.nextTick); // wait for promises to resolve
|
||||
resolve();
|
||||
return res;
|
||||
});
|
||||
// @ts-ignore
|
||||
setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500);
|
||||
await p;
|
||||
await new Promise(process.nextTick); // wait for promises to resolve
|
||||
}
|
||||
|
||||
async function testUserScreen(onShown, users) {
|
||||
if (!users) {
|
||||
users = {};
|
||||
}
|
||||
const starting = start({
|
||||
resetEnv: true,
|
||||
userConfig: { storage: "server", users },
|
||||
});
|
||||
|
||||
// Ensure no current user
|
||||
expect(localStorage["Comfy.userId"]).toBeFalsy();
|
||||
expect(localStorage["Comfy.userName"]).toBeFalsy();
|
||||
|
||||
await waitForUserScreenShow();
|
||||
|
||||
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
|
||||
expect(selection).toBeTruthy();
|
||||
|
||||
// Ensure login is visible
|
||||
expect(window.getComputedStyle(selection)?.display).not.toBe("none");
|
||||
// Ensure menu is hidden
|
||||
const menu = document.querySelectorAll(".comfy-menu")?.[0];
|
||||
expect(window.getComputedStyle(menu)?.display).toBe("none");
|
||||
|
||||
const isCreate = await onShown(selection);
|
||||
|
||||
// Submit form
|
||||
selection.querySelectorAll("form")[0].submit();
|
||||
await new Promise(process.nextTick); // wait for promises to resolve
|
||||
|
||||
// Wait for start
|
||||
const s = await starting;
|
||||
|
||||
// Ensure login is removed
|
||||
expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0);
|
||||
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
|
||||
|
||||
// Ensure settings + templates are saved
|
||||
const { api } = require("../../comfy/web/scripts/api");
|
||||
expect(api.createUser).toHaveBeenCalledTimes(+isCreate);
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate);
|
||||
if (isCreate) {
|
||||
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
|
||||
expect(s.app.isNewUserSession).toBeTruthy();
|
||||
} else {
|
||||
expect(s.app.isNewUserSession).toBeFalsy();
|
||||
}
|
||||
|
||||
return { users, selection, ...s };
|
||||
}
|
||||
|
||||
it("allows user creation if no users", async () => {
|
||||
const { users } = await testUserScreen((selection) => {
|
||||
// Ensure we have no users flag added
|
||||
expect(selection.classList.contains("no-users")).toBeTruthy();
|
||||
|
||||
// Enter a username
|
||||
const input = selection.getElementsByTagName("input")[0];
|
||||
input.focus();
|
||||
input.value = "Test User";
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
expect(users).toStrictEqual({
|
||||
"Test User!": "Test User",
|
||||
});
|
||||
|
||||
expect(localStorage["Comfy.userId"]).toBe("Test User!");
|
||||
expect(localStorage["Comfy.userName"]).toBe("Test User");
|
||||
});
|
||||
it("allows user creation if no current user but other users", async () => {
|
||||
const users = {
|
||||
"Test User 2!": "Test User 2",
|
||||
};
|
||||
|
||||
await testUserScreen((selection) => {
|
||||
expect(selection.classList.contains("no-users")).toBeFalsy();
|
||||
|
||||
// Enter a username
|
||||
const input = selection.getElementsByTagName("input")[0];
|
||||
input.focus();
|
||||
input.value = "Test User 3";
|
||||
return true;
|
||||
}, users);
|
||||
|
||||
expect(users).toStrictEqual({
|
||||
"Test User 2!": "Test User 2",
|
||||
"Test User 3!": "Test User 3",
|
||||
});
|
||||
|
||||
expect(localStorage["Comfy.userId"]).toBe("Test User 3!");
|
||||
expect(localStorage["Comfy.userName"]).toBe("Test User 3");
|
||||
});
|
||||
it("allows user selection if no current user but other users", async () => {
|
||||
const users = {
|
||||
"A!": "A",
|
||||
"B!": "B",
|
||||
"C!": "C",
|
||||
};
|
||||
|
||||
await testUserScreen((selection) => {
|
||||
expect(selection.classList.contains("no-users")).toBeFalsy();
|
||||
|
||||
// Check user list
|
||||
const select = selection.getElementsByTagName("select")[0];
|
||||
const options = select.getElementsByTagName("option");
|
||||
expect(
|
||||
[...options]
|
||||
.filter((o) => !o.disabled)
|
||||
.reduce((p, n) => {
|
||||
p[n.getAttribute("value")] = n.textContent;
|
||||
return p;
|
||||
}, {})
|
||||
).toStrictEqual(users);
|
||||
|
||||
// Select an option
|
||||
select.focus();
|
||||
select.value = options[2].value;
|
||||
|
||||
return false;
|
||||
}, users);
|
||||
|
||||
expect(users).toStrictEqual(users);
|
||||
|
||||
expect(localStorage["Comfy.userId"]).toBe("B!");
|
||||
expect(localStorage["Comfy.userName"]).toBe("B");
|
||||
});
|
||||
it("doesnt show user screen if current user", async () => {
|
||||
const starting = start({
|
||||
resetEnv: true,
|
||||
userConfig: {
|
||||
storage: "server",
|
||||
users: {
|
||||
"User!": "User",
|
||||
},
|
||||
},
|
||||
localStorage: {
|
||||
"Comfy.userId": "User!",
|
||||
"Comfy.userName": "User",
|
||||
},
|
||||
});
|
||||
await new Promise(process.nextTick); // wait for promises to resolve
|
||||
|
||||
expectNoUserScreen();
|
||||
|
||||
await starting;
|
||||
});
|
||||
it("allows user switching", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: {
|
||||
storage: "server",
|
||||
users: {
|
||||
"User!": "User",
|
||||
},
|
||||
},
|
||||
localStorage: {
|
||||
"Comfy.userId": "User!",
|
||||
"Comfy.userName": "User",
|
||||
},
|
||||
});
|
||||
|
||||
// cant actually test switching user easily but can check the setting is present
|
||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy();
|
||||
});
|
||||
});
|
||||
describe("single-user", () => {
|
||||
it("doesnt show user creation if no default user", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: false, storage: "server" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
// It should store the settings
|
||||
const { api } = require("../../comfy/web/scripts/api");
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(1);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(1);
|
||||
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
|
||||
expect(app.isNewUserSession).toBeTruthy();
|
||||
});
|
||||
it("doesnt show user creation if default user", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: true, storage: "server" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
// It should store the settings
|
||||
const { api } = require("../../comfy/web/scripts/api");
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
||||
expect(app.isNewUserSession).toBeFalsy();
|
||||
});
|
||||
it("doesnt allow user switching", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: true, storage: "server" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
|
||||
});
|
||||
});
|
||||
describe("browser-user", () => {
|
||||
it("doesnt show user creation if no default user", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: false, storage: "browser" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
// It should store the settings
|
||||
const { api } = require("../../comfy/web/scripts/api");
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
||||
expect(app.isNewUserSession).toBeFalsy();
|
||||
});
|
||||
it("doesnt show user creation if default user", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: true, storage: "server" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
// It should store the settings
|
||||
const { api } = require("../../comfy/web/scripts/api");
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
||||
expect(app.isNewUserSession).toBeFalsy();
|
||||
});
|
||||
it("doesnt allow user switching", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: true, storage: "browser" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -1,10 +1,18 @@
|
||||
const { mockApi } = require("./setup");
|
||||
const { Ez } = require("./ezgraph");
|
||||
const lg = require("./litegraph");
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
|
||||
const html = fs.readFileSync(path.resolve("../comfy/web/index.html"))
|
||||
|
||||
/**
|
||||
*
|
||||
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config
|
||||
* @param { Parameters<typeof mockApi>[0] & {
|
||||
* resetEnv?: boolean,
|
||||
* preSetup?(app): Promise<void>,
|
||||
* localStorage?: Record<string, string>
|
||||
* } } config
|
||||
* @returns
|
||||
*/
|
||||
export async function start(config = {}) {
|
||||
@ -12,12 +20,18 @@ export async function start(config = {}) {
|
||||
jest.resetModules();
|
||||
jest.resetAllMocks();
|
||||
lg.setup(global);
|
||||
localStorage.clear();
|
||||
sessionStorage.clear();
|
||||
}
|
||||
|
||||
Object.assign(localStorage, config.localStorage ?? {});
|
||||
document.body.innerHTML = html;
|
||||
|
||||
mockApi(config);
|
||||
const { app } = require("../../comfy/web/scripts/app");
|
||||
config.preSetup?.(app);
|
||||
await app.setup();
|
||||
|
||||
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
|
||||
}
|
||||
|
||||
|
||||
@ -18,9 +18,21 @@ function* walkSync(dir) {
|
||||
*/
|
||||
|
||||
/**
|
||||
* @param { { mockExtensions?: string[], mockNodeDefs?: Record<string, ComfyObjectInfo> } } config
|
||||
* @param {{
|
||||
* mockExtensions?: string[],
|
||||
* mockNodeDefs?: Record<string, ComfyObjectInfo>,
|
||||
* settings?: Record<string, string>
|
||||
* userConfig?: {storage: "server" | "browser", users?: Record<string, any>, migrated?: boolean },
|
||||
* userData?: Record<string, any>
|
||||
* }} config
|
||||
*/
|
||||
export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
|
||||
export function mockApi(config = {}) {
|
||||
let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = {
|
||||
userConfig,
|
||||
settings: {},
|
||||
userData: {},
|
||||
...config,
|
||||
};
|
||||
if (!mockExtensions) {
|
||||
mockExtensions = Array.from(walkSync(path.resolve("../comfy/web/extensions/core")))
|
||||
.filter((x) => x.endsWith(".js"))
|
||||
@ -40,6 +52,26 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
|
||||
getNodeDefs: jest.fn(() => mockNodeDefs),
|
||||
init: jest.fn(),
|
||||
apiURL: jest.fn((x) => "../../web/" + x),
|
||||
createUser: jest.fn((username) => {
|
||||
if(username in userConfig.users) {
|
||||
return { status: 400, json: () => "Duplicate" }
|
||||
}
|
||||
userConfig.users[username + "!"] = username;
|
||||
return { status: 200, json: () => username + "!" }
|
||||
}),
|
||||
getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }),
|
||||
getSettings: jest.fn(() => settings),
|
||||
storeSettings: jest.fn((v) => Object.assign(settings, v)),
|
||||
getUserData: jest.fn((f) => {
|
||||
if (f in userData) {
|
||||
return { status: 200, json: () => userData[f] };
|
||||
} else {
|
||||
return { status: 404 };
|
||||
}
|
||||
}),
|
||||
storeUserData: jest.fn((file, data) => {
|
||||
userData[file] = data;
|
||||
}),
|
||||
};
|
||||
jest.mock("../../comfy/web/scripts/api", () => ({
|
||||
get api() {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user