Merge with latest upstream.

This commit is contained in:
doctorpangloss 2024-01-29 14:41:30 -08:00
commit 82edb2ff0e
70 changed files with 3546 additions and 536 deletions

3
.gitignore vendored
View File

@ -172,4 +172,5 @@ dmypy.json
cython_debug/
.openapi-generator/
/tests-ui/data/object_info.json
/tests-ui/data/object_info.json
/user/

View File

@ -101,6 +101,8 @@ There is a portable standalone build for Windows that should work for running on
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
If you have trouble extracting it, right click the file -> properties -> unblock
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -192,7 +194,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
To run tests:
```shell
pytest tests/inference
(cd tests-ui && npm ci && npm test:generate && npm test)
(cd tests-ui && npm ci && npm run test:generate && npm test)
```
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.

0
comfy/app/__init__.py Normal file
View File

54
comfy/app/app_settings.py Normal file
View File

@ -0,0 +1,54 @@
import os
import json
from aiohttp import web
class AppSettings():
def __init__(self, user_manager):
self.user_manager = user_manager
def get_settings(self, request):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
if os.path.isfile(file):
with open(file) as f:
return json.load(f)
else:
return {}
def save_settings(self, request, settings):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
with open(file, "w") as f:
f.write(json.dumps(settings, indent=4))
def add_routes(self, routes):
@routes.get("/settings")
async def get_settings(request):
return web.json_response(self.get_settings(request))
@routes.get("/settings/{id}")
async def get_setting(request):
value = None
settings = self.get_settings(request)
setting_id = request.match_info.get("id", None)
if setting_id and setting_id in settings:
value = settings[setting_id]
return web.json_response(value)
@routes.post("/settings")
async def post_settings(request):
settings = self.get_settings(request)
new_settings = await request.json()
self.save_settings(request, {**settings, **new_settings})
return web.Response(status=200)
@routes.post("/settings/{id}")
async def post_setting(request):
setting_id = request.match_info.get("id", None)
if not setting_id:
return web.Response(status=400)
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)

135
comfy/app/user_manager.py Normal file
View File

@ -0,0 +1,135 @@
import json
import os
import re
import uuid
from aiohttp import web
from ..cli_args import args
from ..cmd.folder_paths import user_directory
from .app_settings import AppSettings
class UserManager():
def __init__(self):
self.default_user = "default"
self.users_file = os.path.join(user_directory, "users.json")
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(self.users_file):
with open(self.users_file) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
if user not in self.users:
raise KeyError("Unknown user: " + user)
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
if type == "userdata":
root_dir = user_directory
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
return None
parent = user_root
if file is not None:
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
return None
if create_dir and not os.path.exists(parent):
os.mkdir(parent)
return path
def add_user(self, name):
name = name.strip()
if not name:
raise ValueError("username not provided")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
with open(self.users_file, "w") as f:
json.dump(self.users, f)
return user_id
def add_routes(self, routes):
self.settings.add_routes(routes)
@routes.get("/users")
async def get_users(request):
if args.multi_user:
return web.json_response({"storage": "server", "users": self.users})
else:
user_dir = self.get_request_user_filepath(request, None, create_dir=False)
return web.json_response({
"storage": "server",
"migrated": os.path.exists(user_dir)
})
@routes.post("/users")
async def post_users(request):
body = await request.json()
username = body["username"]
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
user_id = self.add_user(username)
return web.json_response(user_id)
@routes.get("/userdata/{file}")
async def getuserdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if not os.path.exists(path):
return web.Response(status=404)
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
body = await request.read()
with open(path, "wb") as f:
f.write(body)
return web.Response(status=200)

View File

@ -112,6 +112,8 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
if options.args_parsing:
args = parser.parse_args()
else:

View File

@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module):
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:

View File

@ -1,4 +1,4 @@
from .utils import load_torch_file, transformers_convert
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os
import torch
import json
@ -43,6 +43,9 @@ class ClipVisionModel():
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device)).float()
@ -75,6 +78,9 @@ def convert_to_transformers(sd, prefix):
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
sd = transformers_convert(sd, prefix, "vision_model.", 48)
else:
replace_prefix = {prefix: ""}
sd = state_dict_prefix_replace(sd, replace_prefix)
return sd
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import sys
import copy
import datetime
import heapq
@ -15,6 +16,7 @@ from typing import Tuple
import sys
import gc
import inspect
from typing import List, Literal, NamedTuple, Optional
import torch
@ -325,11 +327,22 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor:
def __init__(self, server):
self.success = None
self.server = server
self.reset()
def reset(self):
self.outputs = {}
self.object_storage = {}
self.outputs_ui = {}
self.status_messages = []
self.success = True
self.old_prompt = {}
self.server = server
def add_message(self, event, data, broadcast: bool):
self.status_messages.append((event, data))
if self.server.client_id is not None or broadcast:
self.server.send_sync(event, data, self.server.client_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"]
@ -344,23 +357,22 @@ class PromptExecutor:
"node_type": class_type,
"executed": list(executed),
}
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
self.add_message("execution_interrupted", mes, broadcast=True)
else:
if self.server.client_id is not None:
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
}
self.server.send_sync("execution_error", mes, self.server.client_id)
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
}
self.add_message("execution_error", mes, broadcast=False)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
@ -381,8 +393,8 @@ class PromptExecutor:
else:
self.server.client_id = None
if self.server.client_id is not None:
self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id)
self.status_messages = []
self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
# delete cached outputs if nodes don't exist for them
@ -415,9 +427,9 @@ class PromptExecutor:
del d
model_management.cleanup_models()
if self.server.client_id is not None:
self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
self.server.client_id)
self.add_message("execution_cached",
{"nodes": list(current_outputs), "prompt_id": prompt_id},
broadcast=False)
executed = set()
output_node_id = None
to_execute = []
@ -434,11 +446,9 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id,
self.outputs_ui, self.object_storage)
if success is not True:
self.handle_execution_error( prompt_id,
prompt, current_outputs, executed, error, ex)
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break
for x in executed:
@ -779,6 +789,7 @@ class PromptQueue:
self.queue = []
self.currently_running = {}
self.history = {}
self.flags = {}
server.prompt_queue = self
def size(self) -> int:
@ -803,15 +814,28 @@ class PromptQueue:
self.server.queue_updated()
return copy.deepcopy(item_with_future.queue_tuple), task_id
def task_done(self, item_id, outputs: dict):
class ExecutionStatus(NamedTuple):
status_str: Literal['success', 'error']
completed: bool
messages: List[str]
def task_done(self, item_id, outputs: dict,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex:
queue_item = self.currently_running.pop(item_id)
prompt = queue_item.queue_tuple
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = {"prompt": prompt, "outputs": {}, "timestamp": time.time()}
for o in outputs:
self.history[prompt[1]]["outputs"][o] = outputs[o]
status_dict: Optional[dict] = None
if status is not None:
status_dict = copy.deepcopy(status._asdict())
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
'status': status_dict,
}
self.server.queue_updated()
if queue_item.completed:
queue_item.completed.set_result(outputs)
@ -880,3 +904,17 @@ class PromptQueue:
def delete_history_item(self, id_to_delete: int):
with self.mutex:
self.history.pop(id_to_delete, None)
def set_flag(self, name, data):
with self.mutex:
self.flags[name] = data
self.not_empty.notify()
def get_flags(self, reset=True):
with self.mutex:
if reset:
ret = self.flags
self.flags = {}
return ret
else:
return self.flags.copy()

View File

@ -31,11 +31,14 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
filename_list_cache = {}
@ -139,15 +142,27 @@ def recursive_search(directory, excluded_dir_names=None):
excluded_dir_names = []
result = []
dirs = {directory: os.path.getmtime(directory)}
dirs = {}
# Attempt to add the initial directory to dirs with error handling
try:
dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError:
print(f"Warning: Unable to access {directory}. Skipping this path.")
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
for d in subdirs:
path = os.path.join(dirpath, d)
dirs[path] = os.path.getmtime(path)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
print(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs
def filter_files_extensions(files, extensions):

View File

@ -89,7 +89,7 @@ def prompt_worker(q, _server):
gc_collect_interval = 10.0
current_time = 0.0
while True:
timeout = None
timeout = 1000.0
if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
@ -102,14 +102,32 @@ def prompt_worker(q, _server):
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id, e.outputs_ui)
q.task_done(item_id,
e.outputs_ui,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
if _server.client_id is not None:
_server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)
_server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)
if flags.get("unload_models", free_memory):
model_management.unload_all_models()
need_gc = True
last_gc_collect = 0
if free_memory:
e.reset()
need_gc = True
last_gc_collect = 0
if need_gc:
current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval:

View File

@ -35,6 +35,7 @@ from ..vendor.appdirs import user_data_dir
nodes = import_all_nodes_in_workspace()
from ..app.user_manager import UserManager
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@ -91,6 +92,7 @@ class PromptServer():
mimetypes.init()
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.loop = loop
@ -532,6 +534,17 @@ class PromptServer():
model_management.interrupt_current_processing()
return web.Response(status=200)
@routes.post("/free")
async def post_free(request):
json_data = await request.json()
unload_models = json_data.get("unload_models", False)
free_memory = json_data.get("free_memory", False)
if unload_models:
self.prompt_queue.set_flag("unload_models", unload_models)
if free_memory:
self.prompt_queue.set_flag("free_memory", free_memory)
return web.Response(status=200)
@routes.post("/history")
async def post_history(request):
json_data = await request.json()
@ -671,6 +684,7 @@ class PromptServer():
return web.json_response(prompt, status=200)
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.app.add_routes(self.routes)
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
@ -768,14 +782,9 @@ class PromptServer():
site = web.TCPSite(runner, address, port)
await site.start()
address_to_print = 'localhost'
if address == '' or address == '0.0.0.0':
address = '0.0.0.0'
else:
address_to_print = address
if verbose:
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address_to_print, port))
print("To see the GUI go to: http://{}:{}".format(address, port))
if call_on_start is not None:
call_on_start(address, port)

View File

@ -1,4 +1,3 @@
import enum
import torch
import math
from . import utils

View File

@ -1,7 +1,6 @@
import torch
import math
import os
import contextlib
from . import utils
from . import model_management
@ -127,7 +126,10 @@ class ControlBase:
if o[i] is None:
o[i] = prev_val
else:
o[i] += prev_val
if o[i].shape[0] < prev_val.shape[0]:
o[i] = prev_val + o[i]
else:
o[i] += prev_val
return out
class ControlNet(ControlBase):

View File

@ -1,7 +1,7 @@
import math
import torch
from torch import nn, einsum
from torch import nn
from .ldm.modules.attention import CrossAttention
from inspect import isfunction

View File

@ -1,12 +1,9 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@ -176,6 +173,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=False,
upcast_attention=upcast_attention,
mask=mask,
)
hidden_states = hidden_states.to(dtype)
@ -238,6 +236,12 @@ def attention_split(q, k, v, heads, mask=None):
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
if mask is not None:
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
first_op_done = True
@ -293,11 +297,14 @@ def attention_xformers(q, k, v, heads, mask=None):
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
if mask is not None:
pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
@ -322,7 +329,6 @@ def attention_pytorch(q, k, v, heads, mask=None):
optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled():
print("Using xformers cross attention")
@ -338,15 +344,18 @@ else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch
optimized_attention_masked = optimized_attention
def optimized_attention_for_device(device, mask=False):
if device == torch.device("cpu"): #TODO
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
return attention_pytorch
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
else:
return attention_basic
if device == torch.device("cpu"):
return attention_sub_quad
if mask:
return optimized_attention_masked

View File

@ -1,12 +1,9 @@
from abc import abstractmethod
import math
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial
from .util import (
checkpoint,

View File

@ -23,13 +23,13 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x
def forward(self, x, noise_level=None):
def forward(self, x, noise_level=None, seed=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
x = self.scale(x)
z = self.q_sample(x, noise_level)
z = self.q_sample(x, noise_level, seed=seed)
z = self.unscale(z)
noise_level = self.time_embed(noise_level)
return z, noise_level

View File

@ -61,6 +61,7 @@ def _summarize_chunk(
value: Tensor,
scale: float,
upcast_attention: bool,
mask,
) -> AttnChunk:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
@ -84,6 +85,8 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
attn_weights -= max_score
if mask is not None:
attn_weights += mask
torch.exp(attn_weights, out=attn_weights)
exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value)
@ -96,11 +99,12 @@ def _query_chunk_attention(
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
mask,
) -> Tensor:
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
_, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk:
def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
key_chunk = dynamic_slice(
key_t,
(0, 0, chunk_idx),
@ -111,10 +115,13 @@ def _query_chunk_attention(
(0, chunk_idx, 0),
(batch_x_heads, kv_chunk_size, v_channels_per_head)
)
return summarize_chunk(query, key_chunk, value_chunk)
if mask is not None:
mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
chunks: List[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk
@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
value: Tensor,
scale: float,
upcast_attention: bool,
mask,
) -> Tensor:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
beta=0,
)
if mask is not None:
attn_scores += mask
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
@ -183,6 +193,7 @@ def efficient_dot_product_attention(
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
upcast_attention=False,
mask = None,
):
"""Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in
@ -209,13 +220,22 @@ def efficient_dot_product_attention(
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
if mask is not None and len(mask.shape) == 2:
mask = mask.unsqueeze(0)
def get_query_chunk(chunk_idx: int) -> Tensor:
return dynamic_slice(
query,
(0, chunk_idx, 0),
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
)
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@ -237,6 +257,7 @@ def efficient_dot_product_attention(
query=query,
key_t=key_t,
value=value,
mask=mask,
)
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@ -246,6 +267,7 @@ def efficient_dot_product_attention(
query=get_query_chunk(i * query_chunk_size),
key_t=key_t,
value=value,
mask=get_mask_chunk(i * query_chunk_size)
) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1)
return res

View File

@ -6,7 +6,6 @@ from . import model_management
from . import conds
from . import ops
from enum import Enum
import contextlib
from . import utils
class ModelType(Enum):
@ -100,11 +99,29 @@ class BaseModel(torch.nn.Module):
if self.inpaint_model:
concat_keys = ("mask", "masked_image")
cond_concat = []
denoise_mask = kwargs.get("denoise_mask", None)
latent_image = kwargs.get("latent_image", None)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
concat_latent_image = kwargs.get("concat_latent_image", None)
if concat_latent_image is None:
concat_latent_image = kwargs.get("latent_image", None)
else:
concat_latent_image = self.process_latent_in(concat_latent_image)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
@ -117,9 +134,9 @@ class BaseModel(torch.nn.Module):
for ck in concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1].to(device))
cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image":
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
@ -159,19 +176,28 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
if vae_state_dict is not None:
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
for sd in extra_sds:
unet_state_dict.update(sd)
return unet_state_dict
def set_inpaint(self):
self.inpaint_model = True
@ -190,7 +216,7 @@ class BaseModel(torch.nn.Module):
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = []
weights = []
noise_aug = []
@ -199,7 +225,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
@ -225,11 +251,11 @@ class SD21UNCLIP(BaseModel):
if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels))
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
else:
return args["pooled_output"]

View File

@ -176,7 +176,7 @@ try:
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported():
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
VAE_DTYPE = torch.bfloat16
if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:

View File

@ -166,6 +166,26 @@ class ConditioningSetAreaPercentage:
c.append(n)
return (c, )
class ConditioningSetAreaStrength:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['strength'] = strength
c.append(n)
return (c, )
class ConditioningSetMask:
@classmethod
def INPUT_TYPES(s):
@ -341,6 +361,62 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class InpaintModelConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
"mask": ("MASK", ),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, positive, negative, pixels, vae, mask):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
orig_pixels = pixels
pixels = orig_pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
m = (1.0 - mask.round()).squeeze(1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5
concat_latent = vae.encode(pixels)
orig_latent = vae.encode(orig_pixels)
out_latent = {}
out_latent["samples"] = orig_latent
out_latent["noise_mask"] = mask
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
d["concat_mask"] = mask
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1], out_latent)
class SaveLatent:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -1400,6 +1476,8 @@ class LoadImage:
output_masks = []
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
@ -1455,6 +1533,8 @@ class LoadImageMask:
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA")
mask = None
c = channel[0].upper()
@ -1609,10 +1689,11 @@ class ImagePadForOutpaint:
def expand_image(self, image, left, top, right, bottom, feathering):
d1, d2, d3, d4 = image.size()
new_image = torch.zeros(
new_image = torch.ones(
(d1, d2 + top + bottom, d3 + left + right, d4),
dtype=torch.float32,
)
) * 0.5
new_image[:, top:top + d2, left:left + d3, :] = image
mask = torch.ones(
@ -1678,6 +1759,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningConcat": ConditioningConcat,
"ConditioningSetArea": ConditioningSetArea,
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
"ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,
@ -1704,6 +1786,7 @@ NODE_CLASS_MAPPINGS = {
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply,
"InpaintModelConditioning": InpaintModelConditioning,
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,

View File

@ -1,10 +1,9 @@
import torch
from contextlib import contextmanager
import comfy.model_management
from . import model_management
def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
non_blocking = model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)

View File

@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)

View File

@ -531,7 +531,14 @@ def load_unet(unet_path):
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()])
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
clip_sd = None
load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
utils.save_torch_file(sd, output_path, metadata=metadata)

View File

@ -7,7 +7,6 @@ import traceback
import zipfile
from . import model_management
from pkg_resources import resource_filename
import contextlib
from . import clip_model
import json

View File

@ -65,6 +65,12 @@ class BASE:
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_clip_vision_state_dict_for_saving(self, state_dict):
replace_prefix = {}
if self.clip_vision_prefix is not None:
replace_prefix[""] = self.clip_vision_prefix
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.diffusion_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -1,6 +1,7 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { mergeIfValid } from "./widgetInputs.js";
import { ManageGroupDialog } from "./groupNodeManage.js";
const GROUP = Symbol();
@ -61,11 +62,7 @@ class GroupNodeBuilder {
);
return;
case Workflow.InUse.Registered:
if (
!confirm(
"An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?"
)
) {
if (!confirm("A group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?")) {
return;
}
break;
@ -151,6 +148,8 @@ export class GroupNodeConfig {
this.primitiveDefs = {};
this.widgetToPrimitive = {};
this.primitiveToWidget = {};
this.nodeInputs = {};
this.outputVisibility = [];
}
async registerType(source = "workflow") {
@ -158,6 +157,7 @@ export class GroupNodeConfig {
output: [],
output_name: [],
output_is_list: [],
output_is_hidden: [],
name: source + "/" + this.name,
display_name: this.name,
category: "group nodes" + ("/" + source),
@ -277,8 +277,7 @@ export class GroupNodeConfig {
}
if (input.widget) {
const targetDef = globalDefs[node.type];
const targetWidget =
targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
const targetWidget = targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
const widget = [targetWidget[0], config];
const res = mergeIfValid(
@ -330,7 +329,8 @@ export class GroupNodeConfig {
}
getInputConfig(node, inputName, seenInputs, config, extra) {
let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
const customConfig = this.nodeData.config?.[node.index]?.input?.[inputName];
let name = customConfig?.name ?? node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
let key = name;
let prefix = "";
// Special handling for primitive to include the title if it is set rather than just "value"
@ -349,14 +349,14 @@ export class GroupNodeConfig {
}
if (config[0] === "IMAGEUPLOAD") {
if (!extra) extra = {};
extra.widget = `${prefix}${config[1]?.widget ?? "image"}`;
extra.widget = this.oldToNewWidgetMap[node.index]?.[config[1]?.widget ?? "image"] ?? "image";
}
if (extra) {
config = [config[0], { ...config[1], ...extra }];
}
return { name, config };
return { name, config, customConfig };
}
processWidgetInputs(inputs, node, inputNames, seenInputs) {
@ -366,9 +366,7 @@ export class GroupNodeConfig {
for (const inputName of inputNames) {
let widgetType = app.getWidgetType(inputs[inputName], inputName);
if (widgetType) {
const convertedIndex = node.inputs?.findIndex(
(inp) => inp.name === inputName && inp.widget?.name === inputName
);
const convertedIndex = node.inputs?.findIndex((inp) => inp.name === inputName && inp.widget?.name === inputName);
if (convertedIndex > -1) {
// This widget has been converted to a widget
// We need to store this in the correct position so link ids line up
@ -424,6 +422,7 @@ export class GroupNodeConfig {
}
processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) {
this.nodeInputs[node.index] = {};
for (let i = 0; i < slots.length; i++) {
const inputName = slots[i];
if (linksTo[i]) {
@ -432,7 +431,11 @@ export class GroupNodeConfig {
continue;
}
const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]);
const { name, config, customConfig } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]);
this.nodeInputs[node.index][inputName] = name;
if(customConfig?.visible === false) continue;
this.nodeDef.input.required[name] = config;
inputMap[i] = this.inputCount++;
}
@ -452,6 +455,7 @@ export class GroupNodeConfig {
const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], {
defaultInput: true,
});
this.nodeDef.input.required[name] = config;
this.newToOldWidgetMap[name] = { node, inputName };
@ -477,9 +481,7 @@ export class GroupNodeConfig {
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
// Converted inputs have to be processed after all other nodes as they'll be at the end of the list
this.#convertedToProcess.push(() =>
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
);
this.#convertedToProcess.push(() => this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs));
return inputMapping;
}
@ -490,8 +492,12 @@ export class GroupNodeConfig {
// Add outputs
for (let outputId = 0; outputId < def.output.length; outputId++) {
const linksFrom = this.linksFrom[node.index];
if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) {
// This output is linked internally so we can skip it
// If this output is linked internally we flag it to hide
const hasLink = linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId];
const customConfig = this.nodeData.config?.[node.index]?.output?.[outputId];
const visible = customConfig?.visible ?? !hasLink;
this.outputVisibility.push(visible);
if (!visible) {
continue;
}
@ -500,11 +506,15 @@ export class GroupNodeConfig {
this.nodeDef.output.push(def.output[outputId]);
this.nodeDef.output_is_list.push(def.output_is_list[outputId]);
let label = def.output_name?.[outputId] ?? def.output[outputId];
const output = node.outputs.find((o) => o.name === label);
if (output?.label) {
label = output.label;
let label = customConfig?.name;
if (!label) {
label = def.output_name?.[outputId] ?? def.output[outputId];
const output = node.outputs.find((o) => o.name === label);
if (output?.label) {
label = output.label;
}
}
let name = label;
if (name in seenOutputs) {
const prefix = `${node.title ?? node.type} `;
@ -677,6 +687,25 @@ export class GroupNodeHandler {
return this.innerNodes;
};
this.node.recreate = async () => {
const id = this.node.id;
const sz = this.node.size;
const nodes = this.node.convertToNodes();
const groupNode = LiteGraph.createNode(this.node.type);
groupNode.id = id;
// Reuse the existing nodes for this instance
groupNode.setInnerNodes(nodes);
groupNode[GROUP].populateWidgets();
app.graph.add(groupNode);
groupNode.size = [Math.max(groupNode.size[0], sz[0]), Math.max(groupNode.size[1], sz[1])];
// Remove all converted nodes and relink them
groupNode[GROUP].replaceNodes(nodes);
return groupNode;
};
this.node.convertToNodes = () => {
const addInnerNodes = () => {
const backup = localStorage.getItem("litegrapheditor_clipboard");
@ -769,6 +798,7 @@ export class GroupNodeHandler {
const slot = node.inputs[groupSlotId];
if (slot.link == null) continue;
const link = app.graph.links[slot.link];
if (!link) continue;
// connect this node output to the input of another node
const originNode = app.graph.getNodeById(link.origin_id);
originNode.connect(link.origin_slot, newNode, +innerInputId);
@ -806,12 +836,23 @@ export class GroupNodeHandler {
let optionIndex = options.findIndex((o) => o.content === "Outputs");
if (optionIndex === -1) optionIndex = options.length;
else optionIndex++;
options.splice(optionIndex, 0, null, {
content: "Convert to nodes",
callback: () => {
return this.convertToNodes();
options.splice(
optionIndex,
0,
null,
{
content: "Convert to nodes",
callback: () => {
return this.convertToNodes();
},
},
});
{
content: "Manage Group Node",
callback: () => {
new ManageGroupDialog(app).show(this.type);
},
}
);
};
// Draw custom collapse icon to identity this as a group
@ -843,6 +884,7 @@ export class GroupNodeHandler {
const r = onDrawForeground?.apply?.(this, arguments);
if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) {
const n = groupData.nodes[this.runningInternalNodeId];
if(!n) return;
const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`;
ctx.save();
ctx.font = "12px sans-serif";
@ -865,6 +907,28 @@ export class GroupNodeHandler {
return onExecutionStart?.apply(this, arguments);
};
const self = this;
const onNodeCreated = this.node.onNodeCreated;
this.node.onNodeCreated = function () {
const config = self.groupData.nodeData.config;
if (config) {
for (const n in config) {
const inputs = config[n]?.input;
for (const w in inputs) {
if (inputs[w].visible !== false) continue;
const widgetName = self.groupData.oldToNewWidgetMap[n][w];
const widget = this.widgets.find((w) => w.name === widgetName);
if (widget) {
widget.type = "hidden";
widget.computeSize = () => [0, -4];
}
}
}
}
return onNodeCreated?.apply(this, arguments);
};
function handleEvent(type, getId, getEvent) {
const handler = ({ detail }) => {
const id = getId(detail);
@ -902,6 +966,26 @@ export class GroupNodeHandler {
api.removeEventListener("executing", executing);
api.removeEventListener("executed", executed);
};
this.node.refreshComboInNode = (defs) => {
// Update combo widget options
for (const widgetName in this.groupData.newToOldWidgetMap) {
const widget = this.node.widgets.find((w) => w.name === widgetName);
if (widget?.type === "combo") {
const old = this.groupData.newToOldWidgetMap[widgetName];
const def = defs[old.node.type];
const input = def?.input?.required?.[old.inputName] ?? def?.input?.optional?.[old.inputName];
if (!input) continue;
widget.options.values = input[0];
if (old.inputName !== "image" && !widget.options.values.includes(widget.value)) {
widget.value = widget.options.values[0];
widget.callback(widget.value);
}
}
}
};
}
updateInnerWidgets() {
@ -927,13 +1011,15 @@ export class GroupNodeHandler {
continue;
} else if (innerNode.type === "Reroute") {
const rerouteLinks = this.groupData.linksFrom[old.node.index];
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
const node = this.innerNodes[targetNodeId];
const input = node.inputs[targetSlot];
if (input.widget) {
const widget = node.widgets?.find((w) => w.name === input.widget.name);
if (widget) {
widget.value = newValue;
if (rerouteLinks) {
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
const node = this.innerNodes[targetNodeId];
const input = node.inputs[targetSlot];
if (input.widget) {
const widget = node.widgets?.find((w) => w.name === input.widget.name);
if (widget) {
widget.value = newValue;
}
}
}
}
@ -975,7 +1061,7 @@ export class GroupNodeHandler {
const [, , targetNodeId, targetNodeSlot] = link;
const targetNode = this.groupData.nodeData.nodes[targetNodeId];
const inputs = targetNode.inputs;
const targetWidget = inputs?.[targetNodeSlot].widget;
const targetWidget = inputs?.[targetNodeSlot]?.widget;
if (!targetWidget) return;
const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
@ -983,13 +1069,12 @@ export class GroupNodeHandler {
if (v == null) return;
const widgetName = Object.values(map)[0];
const widget = this.node.widgets.find(w => w.name === widgetName);
if(widget) {
const widget = this.node.widgets.find((w) => w.name === widgetName);
if (widget) {
widget.value = v;
}
}
populateWidgets() {
if (!this.node.widgets) return;
@ -1080,7 +1165,7 @@ export class GroupNodeHandler {
}
static getGroupData(node) {
return node.constructor?.nodeData?.[GROUP];
return (node.nodeData ?? node.constructor?.nodeData)?.[GROUP];
}
static isGroupNode(node) {
@ -1112,7 +1197,7 @@ export class GroupNodeHandler {
}
function addConvertToGroupOptions() {
function addOption(options, index) {
function addConvertOption(options, index) {
const selected = Object.values(app.canvas.selected_nodes ?? {});
const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n));
options.splice(index + 1, null, {
@ -1124,12 +1209,25 @@ function addConvertToGroupOptions() {
});
}
function addManageOption(options, index) {
const groups = app.graph.extra?.groupNodes;
const disabled = !groups || !Object.keys(groups).length;
options.splice(index + 1, null, {
content: `Manage Group Nodes`,
disabled,
callback: () => {
new ManageGroupDialog(app).show();
},
});
}
// Add to canvas
const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions;
LGraphCanvas.prototype.getCanvasMenuOptions = function () {
const options = getCanvasMenuOptions.apply(this, arguments);
const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length;
addOption(options, index);
addConvertOption(options, index);
addManageOption(options, index + 1);
return options;
};
@ -1139,7 +1237,7 @@ function addConvertToGroupOptions() {
const options = getNodeMenuOptions.apply(this, arguments);
if (!GroupNodeHandler.isGroupNode(node)) {
const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1;
addOption(options, index);
addConvertOption(options, index);
}
return options;
};
@ -1167,6 +1265,14 @@ const ext = {
node[GROUP] = new GroupNodeHandler(node);
}
},
async refreshComboInNodes(defs) {
// Re-register group nodes so new ones are created with the correct options
Object.assign(globalDefs, defs);
const nodes = app.graph.extra?.groupNodes;
if (nodes) {
await GroupNodeConfig.registerFromWorkflow(nodes, {});
}
}
};
app.registerExtension(ext);

View File

@ -0,0 +1,149 @@
.comfy-group-manage {
background: var(--bg-color);
color: var(--fg-color);
padding: 0;
font-family: Arial, Helvetica, sans-serif;
border-color: black;
margin: 20vh auto;
max-height: 60vh;
}
.comfy-group-manage-outer {
max-height: 60vh;
min-width: 500px;
display: flex;
flex-direction: column;
}
.comfy-group-manage-outer > header {
display: flex;
align-items: center;
gap: 10px;
justify-content: space-between;
background: var(--comfy-menu-bg);
padding: 15px 20px;
}
.comfy-group-manage-outer > header select {
background: var(--comfy-input-bg);
border: 1px solid var(--border-color);
color: var(--input-text);
padding: 5px 10px;
border-radius: 5px;
}
.comfy-group-manage h2 {
margin: 0;
font-weight: normal;
}
.comfy-group-manage main {
display: flex;
overflow: hidden;
}
.comfy-group-manage .drag-handle {
font-weight: bold;
}
.comfy-group-manage-list {
border-right: 1px solid var(--comfy-menu-bg);
}
.comfy-group-manage-list ul {
margin: 40px 0 0;
padding: 0;
list-style: none;
}
.comfy-group-manage-list-items {
max-height: 70vh;
overflow-y: scroll;
overflow-x: hidden;
}
.comfy-group-manage-list li {
display: flex;
padding: 10px 20px 10px 10px;
cursor: pointer;
align-items: center;
gap: 5px;
}
.comfy-group-manage-list div {
display: flex;
flex-direction: column;
}
.comfy-group-manage-list li:not(.selected):hover div {
text-decoration: underline;
}
.comfy-group-manage-list li.selected {
background: var(--border-color);
}
.comfy-group-manage-list li span {
opacity: 0.7;
font-size: smaller;
}
.comfy-group-manage-node {
flex: auto;
background: var(--border-color);
display: flex;
flex-direction: column;
}
.comfy-group-manage-node > div {
overflow: auto;
}
.comfy-group-manage-node header {
display: flex;
background: var(--bg-color);
height: 40px;
}
.comfy-group-manage-node header a {
text-align: center;
flex: auto;
border-right: 1px solid var(--comfy-menu-bg);
border-bottom: 1px solid var(--comfy-menu-bg);
padding: 10px;
cursor: pointer;
font-size: 15px;
}
.comfy-group-manage-node header a:last-child {
border-right: none;
}
.comfy-group-manage-node header a:not(.active):hover {
text-decoration: underline;
}
.comfy-group-manage-node header a.active {
background: var(--border-color);
border-bottom: none;
}
.comfy-group-manage-node-page {
display: none;
overflow: auto;
}
.comfy-group-manage-node-page.active {
display: block;
}
.comfy-group-manage-node-page div {
padding: 10px;
display: flex;
align-items: center;
gap: 10px;
}
.comfy-group-manage-node-page input {
border: none;
color: var(--input-text);
background: var(--comfy-input-bg);
padding: 5px 10px;
}
.comfy-group-manage-node-page input[type="text"] {
flex: auto;
}
.comfy-group-manage-node-page label {
display: flex;
gap: 5px;
align-items: center;
}
.comfy-group-manage footer {
border-top: 1px solid var(--comfy-menu-bg);
padding: 10px;
display: flex;
gap: 10px;
}
.comfy-group-manage footer button {
font-size: 14px;
padding: 5px 10px;
border-radius: 0;
}
.comfy-group-manage footer button:first-child {
margin-right: auto;
}

View File

@ -0,0 +1,422 @@
import { $el, ComfyDialog } from "../../scripts/ui.js";
import { DraggableList } from "../../scripts/ui/draggableList.js";
import { addStylesheet } from "../../scripts/utils.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
addStylesheet(import.meta.url);
const ORDER = Symbol();
function merge(target, source) {
if (typeof target === "object" && typeof source === "object") {
for (const key in source) {
const sv = source[key];
if (typeof sv === "object") {
let tv = target[key];
if (!tv) tv = target[key] = {};
merge(tv, source[key]);
} else {
target[key] = sv;
}
}
}
return target;
}
export class ManageGroupDialog extends ComfyDialog {
/** @type { Record<"Inputs" | "Outputs" | "Widgets", {tab: HTMLAnchorElement, page: HTMLElement}> } */
tabs = {};
/** @type { number | null | undefined } */
selectedNodeIndex;
/** @type { keyof ManageGroupDialog["tabs"] } */
selectedTab = "Inputs";
/** @type { string | undefined } */
selectedGroup;
/** @type { Record<string, Record<string, Record<string, { name?: string | undefined, visible?: boolean | undefined }>>> } */
modifications = {};
get selectedNodeInnerIndex() {
return +this.nodeItems[this.selectedNodeIndex].dataset.nodeindex;
}
constructor(app) {
super();
this.app = app;
this.element = $el("dialog.comfy-group-manage", {
parent: document.body,
});
}
changeTab(tab) {
this.tabs[this.selectedTab].tab.classList.remove("active");
this.tabs[this.selectedTab].page.classList.remove("active");
this.tabs[tab].tab.classList.add("active");
this.tabs[tab].page.classList.add("active");
this.selectedTab = tab;
}
changeNode(index, force) {
if (!force && this.selectedNodeIndex === index) return;
if (this.selectedNodeIndex != null) {
this.nodeItems[this.selectedNodeIndex].classList.remove("selected");
}
this.nodeItems[index].classList.add("selected");
this.selectedNodeIndex = index;
if (!this.buildInputsPage() && this.selectedTab === "Inputs") {
this.changeTab("Widgets");
}
if (!this.buildWidgetsPage() && this.selectedTab === "Widgets") {
this.changeTab("Outputs");
}
if (!this.buildOutputsPage() && this.selectedTab === "Outputs") {
this.changeTab("Inputs");
}
this.changeTab(this.selectedTab);
}
getGroupData() {
this.groupNodeType = LiteGraph.registered_node_types["workflow/" + this.selectedGroup];
this.groupNodeDef = this.groupNodeType.nodeData;
this.groupData = GroupNodeHandler.getGroupData(this.groupNodeType);
}
changeGroup(group, reset = true) {
this.selectedGroup = group;
this.getGroupData();
const nodes = this.groupData.nodeData.nodes;
this.nodeItems = nodes.map((n, i) =>
$el(
"li.draggable-item",
{
dataset: {
nodeindex: n.index + "",
},
onclick: () => {
this.changeNode(i);
},
},
[
$el("span.drag-handle"),
$el(
"div",
{
textContent: n.title ?? n.type,
},
n.title
? $el("span", {
textContent: n.type,
})
: []
),
]
)
);
this.innerNodesList.replaceChildren(...this.nodeItems);
if (reset) {
this.selectedNodeIndex = null;
this.changeNode(0);
} else {
const items = this.draggable.getAllItems();
let index = items.findIndex(item => item.classList.contains("selected"));
if(index === -1) index = this.selectedNodeIndex;
this.changeNode(index, true);
}
const ordered = [...nodes];
this.draggable?.dispose();
this.draggable = new DraggableList(this.innerNodesList, "li");
this.draggable.addEventListener("dragend", ({ detail: { oldPosition, newPosition } }) => {
if (oldPosition === newPosition) return;
ordered.splice(newPosition, 0, ordered.splice(oldPosition, 1)[0]);
for (let i = 0; i < ordered.length; i++) {
this.storeModification({ nodeIndex: ordered[i].index, section: ORDER, prop: "order", value: i });
}
});
}
storeModification({ nodeIndex, section, prop, value }) {
const groupMod = (this.modifications[this.selectedGroup] ??= {});
const nodesMod = (groupMod.nodes ??= {});
const nodeMod = (nodesMod[nodeIndex ?? this.selectedNodeInnerIndex] ??= {});
const typeMod = (nodeMod[section] ??= {});
if (typeof value === "object") {
const objMod = (typeMod[prop] ??= {});
Object.assign(objMod, value);
} else {
typeMod[prop] = value;
}
}
getEditElement(section, prop, value, placeholder, checked, checkable = true) {
if (value === placeholder) value = "";
const mods = this.modifications[this.selectedGroup]?.nodes?.[this.selectedNodeInnerIndex]?.[section]?.[prop];
if (mods) {
if (mods.name != null) {
value = mods.name;
}
if (mods.visible != null) {
checked = mods.visible;
}
}
return $el("div", [
$el("input", {
value,
placeholder,
type: "text",
onchange: (e) => {
this.storeModification({ section, prop, value: { name: e.target.value } });
},
}),
$el("label", { textContent: "Visible" }, [
$el("input", {
type: "checkbox",
checked,
disabled: !checkable,
onchange: (e) => {
this.storeModification({ section, prop, value: { visible: !!e.target.checked } });
},
}),
]),
]);
}
buildWidgetsPage() {
const widgets = this.groupData.oldToNewWidgetMap[this.selectedNodeInnerIndex];
const items = Object.keys(widgets ?? {});
const type = app.graph.extra.groupNodes[this.selectedGroup];
const config = type.config?.[this.selectedNodeInnerIndex]?.input;
this.widgetsPage.replaceChildren(
...items.map((oldName) => {
return this.getEditElement("input", oldName, widgets[oldName], oldName, config?.[oldName]?.visible !== false);
})
);
return !!items.length;
}
buildInputsPage() {
const inputs = this.groupData.nodeInputs[this.selectedNodeInnerIndex];
const items = Object.keys(inputs ?? {});
const type = app.graph.extra.groupNodes[this.selectedGroup];
const config = type.config?.[this.selectedNodeInnerIndex]?.input;
this.inputsPage.replaceChildren(
...items
.map((oldName) => {
let value = inputs[oldName];
if (!value) {
return;
}
return this.getEditElement("input", oldName, value, oldName, config?.[oldName]?.visible !== false);
})
.filter(Boolean)
);
return !!items.length;
}
buildOutputsPage() {
const nodes = this.groupData.nodeData.nodes;
const innerNodeDef = this.groupData.getNodeDef(nodes[this.selectedNodeInnerIndex]);
const outputs = innerNodeDef?.output ?? [];
const groupOutputs = this.groupData.oldToNewOutputMap[this.selectedNodeInnerIndex];
const type = app.graph.extra.groupNodes[this.selectedGroup];
const config = type.config?.[this.selectedNodeInnerIndex]?.output;
const node = this.groupData.nodeData.nodes[this.selectedNodeInnerIndex];
const checkable = node.type !== "PrimitiveNode";
this.outputsPage.replaceChildren(
...outputs
.map((type, slot) => {
const groupOutputIndex = groupOutputs?.[slot];
const oldName = innerNodeDef.output_name?.[slot] ?? type;
let value = config?.[slot]?.name;
const visible = config?.[slot]?.visible || groupOutputIndex != null;
if (!value || value === oldName) {
value = "";
}
return this.getEditElement("output", slot, value, oldName, visible, checkable);
})
.filter(Boolean)
);
return !!outputs.length;
}
show(type) {
const groupNodes = Object.keys(app.graph.extra?.groupNodes ?? {}).sort((a, b) => a.localeCompare(b));
this.innerNodesList = $el("ul.comfy-group-manage-list-items");
this.widgetsPage = $el("section.comfy-group-manage-node-page");
this.inputsPage = $el("section.comfy-group-manage-node-page");
this.outputsPage = $el("section.comfy-group-manage-node-page");
const pages = $el("div", [this.widgetsPage, this.inputsPage, this.outputsPage]);
this.tabs = [
["Inputs", this.inputsPage],
["Widgets", this.widgetsPage],
["Outputs", this.outputsPage],
].reduce((p, [name, page]) => {
p[name] = {
tab: $el("a", {
onclick: () => {
this.changeTab(name);
},
textContent: name,
}),
page,
};
return p;
}, {});
const outer = $el("div.comfy-group-manage-outer", [
$el("header", [
$el("h2", "Group Nodes"),
$el(
"select",
{
onchange: (e) => {
this.changeGroup(e.target.value);
},
},
groupNodes.map((g) =>
$el("option", {
textContent: g,
selected: "workflow/" + g === type,
value: g,
})
)
),
]),
$el("main", [
$el("section.comfy-group-manage-list", this.innerNodesList),
$el("section.comfy-group-manage-node", [
$el(
"header",
Object.values(this.tabs).map((t) => t.tab)
),
pages,
]),
]),
$el("footer", [
$el(
"button.comfy-btn",
{
onclick: (e) => {
const node = app.graph._nodes.find((n) => n.type === "workflow/" + this.selectedGroup);
if (node) {
alert("This group node is in use in the current workflow, please first remove these.");
return;
}
if (confirm(`Are you sure you want to remove the node: "${this.selectedGroup}"`)) {
delete app.graph.extra.groupNodes[this.selectedGroup];
LiteGraph.unregisterNodeType("workflow/" + this.selectedGroup);
}
this.show();
},
},
"Delete Group Node"
),
$el(
"button.comfy-btn",
{
onclick: async () => {
let nodesByType;
let recreateNodes = [];
const types = {};
for (const g in this.modifications) {
const type = app.graph.extra.groupNodes[g];
let config = (type.config ??= {});
let nodeMods = this.modifications[g]?.nodes;
if (nodeMods) {
const keys = Object.keys(nodeMods);
if (nodeMods[keys[0]][ORDER]) {
// If any node is reordered, they will all need sequencing
const orderedNodes = [];
const orderedMods = {};
const orderedConfig = {};
for (const n of keys) {
const order = nodeMods[n][ORDER].order;
orderedNodes[order] = type.nodes[+n];
orderedMods[order] = nodeMods[n];
orderedNodes[order].index = order;
}
// Rewrite links
for (const l of type.links) {
if (l[0] != null) l[0] = type.nodes[l[0]].index;
if (l[2] != null) l[2] = type.nodes[l[2]].index;
}
// Rewrite externals
if (type.external) {
for (const ext of type.external) {
ext[0] = type.nodes[ext[0]];
}
}
// Rewrite modifications
for (const id of keys) {
if (config[id]) {
orderedConfig[type.nodes[id].index] = config[id];
}
delete config[id];
}
type.nodes = orderedNodes;
nodeMods = orderedMods;
type.config = config = orderedConfig;
}
merge(config, nodeMods);
}
types[g] = type;
if (!nodesByType) {
nodesByType = app.graph._nodes.reduce((p, n) => {
p[n.type] ??= [];
p[n.type].push(n);
return p;
}, {});
}
const nodes = nodesByType["workflow/" + g];
if (nodes) recreateNodes.push(...nodes);
}
await GroupNodeConfig.registerFromWorkflow(types, {});
for (const node of recreateNodes) {
node.recreate();
}
this.modifications = {};
this.app.graph.setDirtyCanvas(true, true);
this.changeGroup(this.selectedGroup, false);
},
},
"Save"
),
$el("button.comfy-btn", { onclick: () => this.element.close() }, "Close"),
]),
]);
this.element.replaceChildren(outer);
this.changeGroup(type ? groupNodes.find((g) => "workflow/" + g === type) : groupNodes[0]);
this.element.showModal();
this.element.addEventListener("close", () => {
this.draggable?.dispose();
});
}
}

View File

@ -1,4 +1,5 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { ComfyDialog, $el } from "../../scripts/ui.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
@ -20,16 +21,20 @@ import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
// Open the manage dialog and Drag and drop elements using the "Name:" label as handle
const id = "Comfy.NodeTemplates";
const file = "comfy.templates.json";
class ManageTemplates extends ComfyDialog {
constructor() {
super();
this.load().then((v) => {
this.templates = v;
});
this.element.classList.add("comfy-manage-templates");
this.templates = this.load();
this.draggedEl = null;
this.saveVisualCue = null;
this.emptyImg = new Image();
this.emptyImg.src = '';
this.emptyImg.src = "";
this.importInput = $el("input", {
type: "file",
@ -67,17 +72,50 @@ class ManageTemplates extends ComfyDialog {
return btns;
}
load() {
const templates = localStorage.getItem(id);
if (templates) {
return JSON.parse(templates);
async load() {
let templates = [];
if (app.storageLocation === "server") {
if (app.isNewUserSession) {
// New user so migrate existing templates
const json = localStorage.getItem(id);
if (json) {
templates = JSON.parse(json);
}
await api.storeUserData(file, json, { stringify: false });
} else {
const res = await api.getUserData(file);
if (res.status === 200) {
try {
templates = await res.json();
} catch (error) {
}
} else if (res.status !== 404) {
console.error(res.status + " " + res.statusText);
}
}
} else {
return [];
const json = localStorage.getItem(id);
if (json) {
templates = JSON.parse(json);
}
}
return templates ?? [];
}
store() {
localStorage.setItem(id, JSON.stringify(this.templates));
async store() {
if(app.storageLocation === "server") {
const templates = JSON.stringify(this.templates, undefined, 4);
localStorage.setItem(id, templates); // Backwards compatibility
try {
await api.storeUserData(file, templates, { stringify: false });
} catch (error) {
console.error(error);
alert(error.message);
}
} else {
localStorage.setItem(id, JSON.stringify(this.templates));
}
}
async importAll() {
@ -85,14 +123,14 @@ class ManageTemplates extends ComfyDialog {
if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader();
reader.onload = async () => {
var importFile = JSON.parse(reader.result);
if (importFile && importFile?.templates) {
const importFile = JSON.parse(reader.result);
if (importFile?.templates) {
for (const template of importFile.templates) {
if (template?.name && template?.data) {
this.templates.push(template);
}
}
this.store();
await this.store();
}
};
await reader.readAsText(file);
@ -159,7 +197,7 @@ class ManageTemplates extends ComfyDialog {
e.currentTarget.style.border = "1px dashed transparent";
e.currentTarget.removeAttribute("draggable");
// rearrange the elements in the localStorage
// rearrange the elements
this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
var prev_i = el.dataset.id;

View File

@ -1,4 +1,5 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js"
const MAX_HISTORY = 50;
@ -15,6 +16,7 @@ function checkState() {
}
activeState = clone(currentState);
redo.length = 0;
api.dispatchEvent(new CustomEvent("graphChanged", { detail: activeState }));
}
}
@ -92,7 +94,7 @@ const undoRedo = async (e) => {
};
const bindInput = (activeEl) => {
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
if (activeEl && activeEl.tagName !== "CANVAS" && activeEl.tagName !== "BODY") {
for (const evt of ["change", "input", "blur"]) {
if (`on${evt}` in activeEl) {
const listener = () => {
@ -106,15 +108,23 @@ const bindInput = (activeEl) => {
}
};
let keyIgnored = false;
window.addEventListener(
"keydown",
(e) => {
requestAnimationFrame(async () => {
const activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
// Ignore events on inputs, they have their native history
return;
let activeEl;
// If we are auto queue in change mode then we do want to trigger on inputs
if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") {
activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
}
keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta";
if (keyIgnored) return;
// Check if this is a ctrl+z ctrl+y
if (await undoRedo(e)) return;
@ -127,11 +137,23 @@ window.addEventListener(
true
);
window.addEventListener("keyup", (e) => {
if (keyIgnored) {
keyIgnored = false;
checkState();
}
});
// Handle clicking DOM elements (e.g. widgets)
window.addEventListener("mouseup", () => {
checkState();
});
// Handle prompt queue event for dynamic widget changes
api.addEventListener("promptQueued", () => {
checkState();
});
// Handle litegraph clicks
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
LGraphCanvas.prototype.processMouseUp = function (e) {
@ -145,3 +167,11 @@ LGraphCanvas.prototype.processMouseDown = function (e) {
checkState();
return v;
};
// Handle litegraph context menu for COMBO widgets
const close = LiteGraph.ContextMenu.prototype.close;
LiteGraph.ContextMenu.prototype.close = function(e) {
const v = close.apply(this, arguments);
checkState();
return v;
}

View File

@ -16,5 +16,33 @@
window.graph = app.graph;
</script>
</head>
<body class="litegraph"></body>
<body class="litegraph">
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
<main class="comfy-user-selection-inner">
<h1>ComfyUI</h1>
<form>
<section>
<label>New user:
<input placeholder="Enter a username" />
</label>
</section>
<div class="comfy-user-existing">
<span class="or-separator">OR</span>
<section>
<label>
Existing user:
<select>
<option hidden disabled selected value> Select a user </option>
</select>
</label>
</section>
</div>
<footer>
<span class="comfy-user-error">&nbsp;</span>
<button class="comfy-btn comfy-user-button-next">Next</button>
</footer>
</form>
</main>
</div>
</body>
</html>

View File

@ -3,7 +3,8 @@
"baseUrl": ".",
"paths": {
"/*": ["./*"]
}
},
"lib": ["DOM", "ES2022"]
},
"include": ["."]
}

View File

@ -11496,7 +11496,7 @@ LGraphNode.prototype.executeAction = function(action)
}
timeout_close = setTimeout(function() {
dialog.close();
}, 500);
}, typeof options.hide_on_mouse_leave === "number" ? options.hide_on_mouse_leave : 500);
});
// if filtering, check focus changed to comboboxes and prevent closing
if (options.do_type_filter){

View File

@ -12,6 +12,13 @@ class ComfyApi extends EventTarget {
}
fetchApi(route, options) {
if (!options) {
options = {};
}
if (!options.headers) {
options.headers = {};
}
options.headers["Comfy-User"] = this.user;
return fetch(this.apiURL(route), options);
}
@ -315,6 +322,99 @@ class ComfyApi extends EventTarget {
async interrupt() {
await this.#postItem("interrupt", null);
}
/**
* Gets user configuration data and where data should be stored
* @returns { Promise<{ storage: "server" | "browser", users?: Promise<string, unknown>, migrated?: boolean }> }
*/
async getUserConfig() {
return (await this.fetchApi("/users")).json();
}
/**
* Creates a new user
* @param { string } username
* @returns The fetch response
*/
createUser(username) {
return this.fetchApi("/users", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ username }),
});
}
/**
* Gets all setting values for the current user
* @returns { Promise<string, unknown> } A dictionary of id -> value
*/
async getSettings() {
return (await this.fetchApi("/settings")).json();
}
/**
* Gets a setting for the current user
* @param { string } id The id of the setting to fetch
* @returns { Promise<unknown> } The setting value
*/
async getSetting(id) {
return (await this.fetchApi(`/settings/${encodeURIComponent(id)}`)).json();
}
/**
* Stores a dictionary of settings for the current user
* @param { Record<string, unknown> } settings Dictionary of setting id -> value to save
* @returns { Promise<void> }
*/
async storeSettings(settings) {
return this.fetchApi(`/settings`, {
method: "POST",
body: JSON.stringify(settings)
});
}
/**
* Stores a setting for the current user
* @param { string } id The id of the setting to update
* @param { unknown } value The value of the setting
* @returns { Promise<void> }
*/
async storeSetting(id, value) {
return this.fetchApi(`/settings/${encodeURIComponent(id)}`, {
method: "POST",
body: JSON.stringify(value)
});
}
/**
* Gets a user data file for the current user
* @param { string } file The name of the userdata file to load
* @param { RequestInit } [options]
* @returns { Promise<unknown> } The fetch response object
*/
async getUserData(file, options) {
return this.fetchApi(`/userdata/${encodeURIComponent(file)}`, options);
}
/**
* Stores a user data file for the current user
* @param { string } file The name of the userdata file to save
* @param { unknown } data The data to save to the file
* @param { RequestInit & { stringify?: boolean, throwOnError?: boolean } } [options]
* @returns { Promise<void> }
*/
async storeUserData(file, data, options = { stringify: true, throwOnError: true }) {
const resp = await this.fetchApi(`/userdata/${encodeURIComponent(file)}`, {
method: "POST",
body: options?.stringify ? JSON.stringify(data) : data,
...options,
});
if (resp.status !== 200) {
throw new Error(`Error storing user data file '${file}': ${resp.status} ${(await resp).statusText}`);
}
}
}
export const api = new ComfyApi();

View File

@ -1,5 +1,5 @@
import { ComfyLogging } from "./logging.js";
import { ComfyWidgets } from "./widgets.js";
import { ComfyWidgets, initWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
@ -269,6 +269,71 @@ export class ComfyApp {
* @param {*} node The node to add the menu handler
*/
#addNodeContextMenuHandler(node) {
function getCopyImageOption(img) {
if (typeof window.ClipboardItem === "undefined") return [];
return [
{
content: "Copy Image",
callback: async () => {
const url = new URL(img.src);
url.searchParams.delete("preview");
const writeImage = async (blob) => {
await navigator.clipboard.write([
new ClipboardItem({
[blob.type]: blob,
}),
]);
};
try {
const data = await fetch(url);
const blob = await data.blob();
try {
await writeImage(blob);
} catch (error) {
// Chrome seems to only support PNG on write, convert and try again
if (blob.type !== "image/png") {
const canvas = $el("canvas", {
width: img.naturalWidth,
height: img.naturalHeight,
});
const ctx = canvas.getContext("2d");
let image;
if (typeof window.createImageBitmap === "undefined") {
image = new Image();
const p = new Promise((resolve, reject) => {
image.onload = resolve;
image.onerror = reject;
}).finally(() => {
URL.revokeObjectURL(image.src);
});
image.src = URL.createObjectURL(blob);
await p;
} else {
image = await createImageBitmap(blob);
}
try {
ctx.drawImage(image, 0, 0);
canvas.toBlob(writeImage, "image/png");
} finally {
if (typeof image.close === "function") {
image.close();
}
}
return;
}
throw error;
}
} catch (error) {
alert("Error copying image: " + (error.message ?? error));
}
},
},
];
}
node.prototype.getExtraMenuOptions = function (_, options) {
if (this.imgs) {
// If this node has images then we add an open in new tab item
@ -286,16 +351,17 @@ export class ComfyApp {
content: "Open Image",
callback: () => {
let url = new URL(img.src);
url.searchParams.delete('preview');
window.open(url, "_blank")
url.searchParams.delete("preview");
window.open(url, "_blank");
},
},
...getCopyImageOption(img),
{
content: "Save Image",
callback: () => {
const a = document.createElement("a");
let url = new URL(img.src);
url.searchParams.delete('preview');
url.searchParams.delete("preview");
a.href = url;
a.setAttribute("download", new URLSearchParams(url.search).get("filename"));
document.body.append(a);
@ -308,33 +374,41 @@ export class ComfyApp {
}
options.push({
content: "Bypass",
callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); }
});
content: "Bypass",
callback: (obj) => {
if (this.mode === 4) this.mode = 0;
else this.mode = 4;
this.graph.change();
},
});
// prevent conflict of clipspace content
if(!ComfyApp.clipspace_return_node) {
if (!ComfyApp.clipspace_return_node) {
options.push({
content: "Copy (Clipspace)",
callback: (obj) => { ComfyApp.copyToClipspace(this); }
});
content: "Copy (Clipspace)",
callback: (obj) => {
ComfyApp.copyToClipspace(this);
},
});
if(ComfyApp.clipspace != null) {
if (ComfyApp.clipspace != null) {
options.push({
content: "Paste (Clipspace)",
callback: () => { ComfyApp.pasteFromClipspace(this); }
});
content: "Paste (Clipspace)",
callback: () => {
ComfyApp.pasteFromClipspace(this);
},
});
}
if(ComfyApp.isImageNode(this)) {
if (ComfyApp.isImageNode(this)) {
options.push({
content: "Open in MaskEditor",
callback: (obj) => {
ComfyApp.copyToClipspace(this);
ComfyApp.clipspace_return_node = this;
ComfyApp.open_maskeditor();
}
});
content: "Open in MaskEditor",
callback: (obj) => {
ComfyApp.copyToClipspace(this);
ComfyApp.clipspace_return_node = this;
ComfyApp.open_maskeditor();
},
});
}
}
};
@ -1291,10 +1365,92 @@ export class ComfyApp {
await Promise.all(extensionPromises);
}
async #migrateSettings() {
this.isNewUserSession = true;
// Store all current settings
const settings = Object.keys(this.ui.settings).reduce((p, n) => {
const v = localStorage[`Comfy.Settings.${n}`];
if (v) {
try {
p[n] = JSON.parse(v);
} catch (error) {}
}
return p;
}, {});
await api.storeSettings(settings);
}
async #setUser() {
const userConfig = await api.getUserConfig();
this.storageLocation = userConfig.storage;
if (typeof userConfig.migrated == "boolean") {
// Single user mode migrated true/false for if the default user is created
if (!userConfig.migrated && this.storageLocation === "server") {
// Default user not created yet
await this.#migrateSettings();
}
return;
}
this.multiUserServer = true;
let user = localStorage["Comfy.userId"];
const users = userConfig.users ?? {};
if (!user || !users[user]) {
// This will rarely be hit so move the loading to on demand
const { UserSelectionScreen } = await import("./ui/userSelection.js");
this.ui.menuContainer.style.display = "none";
const { userId, username, created } = await new UserSelectionScreen().show(users, user);
this.ui.menuContainer.style.display = "";
user = userId;
localStorage["Comfy.userName"] = username;
localStorage["Comfy.userId"] = user;
if (created) {
api.user = user;
await this.#migrateSettings();
}
}
api.user = user;
this.ui.settings.addSetting({
id: "Comfy.SwitchUser",
name: "Switch User",
type: (name) => {
let currentUser = localStorage["Comfy.userName"];
if (currentUser) {
currentUser = ` (${currentUser})`;
}
return $el("tr", [
$el("td", [
$el("label", {
textContent: name,
}),
]),
$el("td", [
$el("button", {
textContent: name + (currentUser ?? ""),
onclick: () => {
delete localStorage["Comfy.userId"];
delete localStorage["Comfy.userName"];
window.location.reload();
},
}),
]),
]);
},
});
}
/**
* Set up the app on the page
*/
async setup() {
await this.#setUser();
await this.ui.settings.load();
await this.#loadExtensions();
// Create and mount the LiteGraph in the DOM
@ -1338,6 +1494,7 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("init");
await this.registerNodes();
initWidgets(this);
// Load previous workflow
let restored = false;
@ -1692,6 +1849,14 @@ export class ComfyApp {
*/
async graphToPrompt() {
for (const outerNode of this.graph.computeExecutionOrder(false)) {
if (outerNode.widgets) {
for (const widget of outerNode.widgets) {
// Allow widgets to run callbacks before a prompt has been queued
// e.g. random seed before every gen
widget.beforeQueued?.();
}
}
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
for (const node of innerNodes) {
if (node.isVirtualNode) {
@ -1903,6 +2068,7 @@ export class ComfyApp {
} finally {
this.#processingQueue = false;
}
api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } }));
}
/**
@ -2046,6 +2212,8 @@ export class ComfyApp {
}
}
}
await this.#invokeExtensionsAsync("refreshComboInNodes", defs);
}
/**

View File

@ -239,7 +239,8 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
node.flags?.collapsed ||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
widget.computedHeight <= 0 ||
widget.type === "converted-widget";
widget.type === "converted-widget"||
widget.type === "hidden";
element.hidden = hidden;
element.style.display = hidden ? "none" : null;
if (hidden) {

View File

@ -269,6 +269,9 @@ export class ComfyLogging {
id: settingId,
name: settingId,
defaultValue: true,
onChange: (value) => {
this.enabled = value;
},
type: (name, setter, value) => {
return $el("tr", [
$el("td", [
@ -283,7 +286,7 @@ export class ComfyLogging {
type: "checkbox",
checked: value,
onchange: (event) => {
setter((this.enabled = event.target.checked));
setter(event.target.checked);
},
}),
$el("button", {

View File

@ -1,5 +1,23 @@
import {api} from "./api.js";
import { api } from "./api.js";
import { ComfyDialog as _ComfyDialog } from "./ui/dialog.js";
import { toggleSwitch } from "./ui/toggleSwitch.js";
import { ComfySettingsDialog } from "./ui/settings.js";
export const ComfyDialog = _ComfyDialog;
/**
*
* @param { string } tag HTML Element Tag and optional classes e.g. div.class1.class2
* @param { string | Element | Element[] | {
* parent?: Element,
* $?: (el: Element) => void,
* dataset?: DOMStringMap,
* style?: CSSStyleDeclaration,
* for?: string
* } | undefined } propsOrChildren
* @param { Element[] | undefined } [children]
* @returns
*/
export function $el(tag, propsOrChildren, children) {
const split = tag.split(".");
const element = document.createElement(split.shift());
@ -8,6 +26,11 @@ export function $el(tag, propsOrChildren, children) {
}
if (propsOrChildren) {
if (typeof propsOrChildren === "string") {
propsOrChildren = { textContent: propsOrChildren };
} else if (propsOrChildren instanceof Element) {
propsOrChildren = [propsOrChildren];
}
if (Array.isArray(propsOrChildren)) {
element.append(...propsOrChildren);
} else {
@ -31,7 +54,7 @@ export function $el(tag, propsOrChildren, children) {
Object.assign(element, propsOrChildren);
if (children) {
element.append(...children);
element.append(...(children instanceof Array ? children : [children]));
}
if (parent) {
@ -167,267 +190,6 @@ function dragElement(dragEl, settings) {
}
}
export class ComfyDialog {
constructor() {
this.element = $el("div.comfy-modal", {parent: document.body}, [
$el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]),
]);
}
createButtons() {
return [
$el("button", {
type: "button",
textContent: "Close",
onclick: () => this.close(),
}),
];
}
close() {
this.element.style.display = "none";
}
show(html) {
if (typeof html === "string") {
this.textElement.innerHTML = html;
} else {
this.textElement.replaceChildren(html);
}
this.element.style.display = "flex";
}
}
class ComfySettingsDialog extends ComfyDialog {
constructor() {
super();
this.element = $el("dialog", {
id: "comfy-settings-dialog",
parent: document.body,
}, [
$el("table.comfy-modal-content.comfy-table", [
$el("caption", {textContent: "Settings"}),
$el("tbody", {$: (tbody) => (this.textElement = tbody)}),
$el("button", {
type: "button",
textContent: "Close",
style: {
cursor: "pointer",
},
onclick: () => {
this.element.close();
},
}),
]),
]);
this.settings = [];
}
getSettingValue(id, defaultValue) {
const settingId = "Comfy.Settings." + id;
const v = localStorage[settingId];
return v == null ? defaultValue : JSON.parse(v);
}
setSettingValue(id, value) {
const settingId = "Comfy.Settings." + id;
localStorage[settingId] = JSON.stringify(value);
}
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
if (!id) {
throw new Error("Settings must have an ID");
}
if (this.settings.find((s) => s.id === id)) {
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
}
const settingId = `Comfy.Settings.${id}`;
const v = localStorage[settingId];
let value = v == null ? defaultValue : JSON.parse(v);
// Trigger initial setting of value
if (onChange) {
onChange(value, undefined);
}
this.settings.push({
render: () => {
const setter = (v) => {
if (onChange) {
onChange(v, value);
}
localStorage[settingId] = JSON.stringify(v);
value = v;
};
value = this.getSettingValue(id, defaultValue);
let element;
const htmlID = id.replaceAll(".", "-");
const labelCell = $el("td", [
$el("label", {
for: htmlID,
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
textContent: name,
})
]);
if (typeof type === "function") {
element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
id: htmlID,
type: "checkbox",
checked: value,
onchange: (event) => {
const isChecked = event.target.checked;
if (onChange !== undefined) {
onChange(isChecked)
}
this.setSettingValue(id, isChecked);
},
}),
]),
])
break;
case "number":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
type,
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
case "slider":
element = $el("tr", [
labelCell,
$el("td", [
$el("div", {
style: {
display: "grid",
gridAutoFlow: "column",
},
}, [
$el("input", {
...attrs,
value,
type: "range",
oninput: (e) => {
setter(e.target.value);
e.target.nextElementSibling.value = e.target.value;
},
}),
$el("input", {
...attrs,
value,
id: htmlID,
type: "number",
style: {maxWidth: "4rem"},
oninput: (e) => {
setter(e.target.value);
e.target.previousElementSibling.value = e.target.value;
},
}),
]),
]),
]);
break;
case "combo":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"select",
{
oninput: (e) => {
setter(e.target.value);
},
},
(typeof options === "function" ? options(value) : options || []).map((opt) => {
if (typeof opt === "string") {
opt = { text: opt };
}
const v = opt.value ?? opt.text;
return $el("option", {
value: v,
textContent: opt.text,
selected: value + "" === v + "",
});
})
),
]),
]);
break;
case "text":
default:
if (type !== "text") {
console.warn(`Unsupported setting type '${type}, defaulting to text`);
}
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
}
}
if (tooltip) {
element.title = tooltip;
}
return element;
},
});
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
}
show() {
this.textElement.replaceChildren(
$el("tr", {
style: {display: "none"},
}, [
$el("th"),
$el("th", {style: {width: "33%"}})
]),
...this.settings.map((s) => s.render()),
)
this.element.showModal();
}
}
class ComfyList {
#type;
#text;
@ -526,7 +288,7 @@ export class ComfyUI {
constructor(app) {
this.app = app;
this.dialog = new ComfyDialog();
this.settings = new ComfySettingsDialog();
this.settings = new ComfySettingsDialog(app);
this.batchCount = 1;
this.lastQueueSize = 0;
@ -607,6 +369,31 @@ export class ComfyUI {
},
});
const autoQueueModeEl = toggleSwitch(
"autoQueueMode",
[
{ text: "instant", tooltip: "A new prompt will be queued as soon as the queue reaches 0" },
{ text: "change", tooltip: "A new prompt will be queued when the queue is at 0 and the graph is/has changed" },
],
{
onChange: (value) => {
this.autoQueueMode = value.item.value;
},
}
);
autoQueueModeEl.style.display = "none";
api.addEventListener("graphChanged", () => {
if (this.autoQueueMode === "change" && this.autoQueueEnabled === true) {
if (this.lastQueueSize === 0) {
this.graphHasChanged = false;
app.queuePrompt(0, this.batchCount);
} else {
this.graphHasChanged = true;
}
}
});
this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [
$el("div.drag-handle", {
style: {
@ -633,6 +420,7 @@ export class ComfyUI {
document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none";
this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1;
document.getElementById("autoQueueCheckbox").checked = false;
this.autoQueueEnabled = false;
},
}),
]),
@ -664,20 +452,22 @@ export class ComfyUI {
},
}),
]),
$el("div",[
$el("label",{
for:"autoQueueCheckbox",
innerHTML: "Auto Queue"
// textContent: "Auto Queue"
}),
$el("input", {
id: "autoQueueCheckbox",
type: "checkbox",
checked: false,
title: "Automatically queue prompt when the queue size hits 0",
onchange: (e) => {
this.autoQueueEnabled = e.target.checked;
autoQueueModeEl.style.display = this.autoQueueEnabled ? "" : "none";
}
}),
autoQueueModeEl
])
]),
$el("div.comfy-menu-btns", [
@ -811,10 +601,13 @@ export class ComfyUI {
if (
this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked &&
! app.lastExecutionError
this.autoQueueEnabled &&
(this.autoQueueMode === "instant" || this.graphHasChanged) &&
!app.lastExecutionError
) {
app.queuePrompt(0, this.batchCount);
status.exec_info.queue_remaining += this.batchCount;
this.graphHasChanged = false;
}
this.lastQueueSize = status.exec_info.queue_remaining;
}

View File

@ -0,0 +1,32 @@
import { $el } from "../ui.js";
export class ComfyDialog {
constructor() {
this.element = $el("div.comfy-modal", { parent: document.body }, [
$el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]),
]);
}
createButtons() {
return [
$el("button", {
type: "button",
textContent: "Close",
onclick: () => this.close(),
}),
];
}
close() {
this.element.style.display = "none";
}
show(html) {
if (typeof html === "string") {
this.textElement.innerHTML = html;
} else {
this.textElement.replaceChildren(html);
}
this.element.style.display = "flex";
}
}

View File

@ -0,0 +1,287 @@
// @ts-check
/*
Original implementation:
https://github.com/TahaSh/drag-to-reorder
MIT License
Copyright (c) 2023 Taha Shashtari
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
import { $el } from "../ui.js";
$el("style", {
parent: document.head,
textContent: `
.draggable-item {
position: relative;
will-change: transform;
user-select: none;
}
.draggable-item.is-idle {
transition: 0.25s ease transform;
}
.draggable-item.is-draggable {
z-index: 10;
}
`
});
export class DraggableList extends EventTarget {
listContainer;
draggableItem;
pointerStartX;
pointerStartY;
scrollYMax;
itemsGap = 0;
items = [];
itemSelector;
handleClass = "drag-handle";
off = [];
offDrag = [];
constructor(element, itemSelector) {
super();
this.listContainer = element;
this.itemSelector = itemSelector;
if (!this.listContainer) return;
this.off.push(this.on(this.listContainer, "mousedown", this.dragStart));
this.off.push(this.on(this.listContainer, "touchstart", this.dragStart));
this.off.push(this.on(document, "mouseup", this.dragEnd));
this.off.push(this.on(document, "touchend", this.dragEnd));
}
getAllItems() {
if (!this.items?.length) {
this.items = Array.from(this.listContainer.querySelectorAll(this.itemSelector));
this.items.forEach((element) => {
element.classList.add("is-idle");
});
}
return this.items;
}
getIdleItems() {
return this.getAllItems().filter((item) => item.classList.contains("is-idle"));
}
isItemAbove(item) {
return item.hasAttribute("data-is-above");
}
isItemToggled(item) {
return item.hasAttribute("data-is-toggled");
}
on(source, event, listener, options) {
listener = listener.bind(this);
source.addEventListener(event, listener, options);
return () => source.removeEventListener(event, listener);
}
dragStart(e) {
if (e.target.classList.contains(this.handleClass)) {
this.draggableItem = e.target.closest(this.itemSelector);
}
if (!this.draggableItem) return;
this.pointerStartX = e.clientX || e.touches[0].clientX;
this.pointerStartY = e.clientY || e.touches[0].clientY;
this.scrollYMax = this.listContainer.scrollHeight - this.listContainer.clientHeight;
this.setItemsGap();
this.initDraggableItem();
this.initItemsState();
this.offDrag.push(this.on(document, "mousemove", this.drag));
this.offDrag.push(this.on(document, "touchmove", this.drag, { passive: false }));
this.dispatchEvent(
new CustomEvent("dragstart", {
detail: { element: this.draggableItem, position: this.getAllItems().indexOf(this.draggableItem) },
})
);
}
setItemsGap() {
if (this.getIdleItems().length <= 1) {
this.itemsGap = 0;
return;
}
const item1 = this.getIdleItems()[0];
const item2 = this.getIdleItems()[1];
const item1Rect = item1.getBoundingClientRect();
const item2Rect = item2.getBoundingClientRect();
this.itemsGap = Math.abs(item1Rect.bottom - item2Rect.top);
}
initItemsState() {
this.getIdleItems().forEach((item, i) => {
if (this.getAllItems().indexOf(this.draggableItem) > i) {
item.dataset.isAbove = "";
}
});
}
initDraggableItem() {
this.draggableItem.classList.remove("is-idle");
this.draggableItem.classList.add("is-draggable");
}
drag(e) {
if (!this.draggableItem) return;
e.preventDefault();
const clientX = e.clientX || e.touches[0].clientX;
const clientY = e.clientY || e.touches[0].clientY;
const listRect = this.listContainer.getBoundingClientRect();
if (clientY > listRect.bottom) {
if (this.listContainer.scrollTop < this.scrollYMax) {
this.listContainer.scrollBy(0, 10);
this.pointerStartY -= 10;
}
} else if (clientY < listRect.top && this.listContainer.scrollTop > 0) {
this.pointerStartY += 10;
this.listContainer.scrollBy(0, -10);
}
const pointerOffsetX = clientX - this.pointerStartX;
const pointerOffsetY = clientY - this.pointerStartY;
this.updateIdleItemsStateAndPosition();
this.draggableItem.style.transform = `translate(${pointerOffsetX}px, ${pointerOffsetY}px)`;
}
updateIdleItemsStateAndPosition() {
const draggableItemRect = this.draggableItem.getBoundingClientRect();
const draggableItemY = draggableItemRect.top + draggableItemRect.height / 2;
// Update state
this.getIdleItems().forEach((item) => {
const itemRect = item.getBoundingClientRect();
const itemY = itemRect.top + itemRect.height / 2;
if (this.isItemAbove(item)) {
if (draggableItemY <= itemY) {
item.dataset.isToggled = "";
} else {
delete item.dataset.isToggled;
}
} else {
if (draggableItemY >= itemY) {
item.dataset.isToggled = "";
} else {
delete item.dataset.isToggled;
}
}
});
// Update position
this.getIdleItems().forEach((item) => {
if (this.isItemToggled(item)) {
const direction = this.isItemAbove(item) ? 1 : -1;
item.style.transform = `translateY(${direction * (draggableItemRect.height + this.itemsGap)}px)`;
} else {
item.style.transform = "";
}
});
}
dragEnd() {
if (!this.draggableItem) return;
this.applyNewItemsOrder();
this.cleanup();
}
applyNewItemsOrder() {
const reorderedItems = [];
let oldPosition = -1;
this.getAllItems().forEach((item, index) => {
if (item === this.draggableItem) {
oldPosition = index;
return;
}
if (!this.isItemToggled(item)) {
reorderedItems[index] = item;
return;
}
const newIndex = this.isItemAbove(item) ? index + 1 : index - 1;
reorderedItems[newIndex] = item;
});
for (let index = 0; index < this.getAllItems().length; index++) {
const item = reorderedItems[index];
if (typeof item === "undefined") {
reorderedItems[index] = this.draggableItem;
}
}
reorderedItems.forEach((item) => {
this.listContainer.appendChild(item);
});
this.items = reorderedItems;
this.dispatchEvent(
new CustomEvent("dragend", {
detail: { element: this.draggableItem, oldPosition, newPosition: reorderedItems.indexOf(this.draggableItem) },
})
);
}
cleanup() {
this.itemsGap = 0;
this.items = [];
this.unsetDraggableItem();
this.unsetItemState();
this.offDrag.forEach((f) => f());
this.offDrag = [];
}
unsetDraggableItem() {
this.draggableItem.style = null;
this.draggableItem.classList.remove("is-draggable");
this.draggableItem.classList.add("is-idle");
this.draggableItem = null;
}
unsetItemState() {
this.getIdleItems().forEach((item, i) => {
delete item.dataset.isAbove;
delete item.dataset.isToggled;
item.style.transform = "";
});
}
dispose() {
this.off.forEach((f) => f());
}
}

View File

@ -0,0 +1,307 @@
import { $el } from "../ui.js";
import { api } from "../api.js";
import { ComfyDialog } from "./dialog.js";
export class ComfySettingsDialog extends ComfyDialog {
constructor(app) {
super();
this.app = app;
this.settingsValues = {};
this.settingsLookup = {};
this.element = $el(
"dialog",
{
id: "comfy-settings-dialog",
parent: document.body,
},
[
$el("table.comfy-modal-content.comfy-table", [
$el("caption", { textContent: "Settings" }),
$el("tbody", { $: (tbody) => (this.textElement = tbody) }),
$el("button", {
type: "button",
textContent: "Close",
style: {
cursor: "pointer",
},
onclick: () => {
this.element.close();
},
}),
]),
]
);
}
get settings() {
return Object.values(this.settingsLookup);
}
async load() {
if (this.app.storageLocation === "browser") {
this.settingsValues = localStorage;
} else {
this.settingsValues = await api.getSettings();
}
// Trigger onChange for any settings added before load
for (const id in this.settingsLookup) {
this.settingsLookup[id].onChange?.(this.settingsValues[this.getId(id)]);
}
}
getId(id) {
if (this.app.storageLocation === "browser") {
id = "Comfy.Settings." + id;
}
return id;
}
getSettingValue(id, defaultValue) {
let value = this.settingsValues[this.getId(id)];
if(value != null) {
if(this.app.storageLocation === "browser") {
try {
value = JSON.parse(value);
} catch (error) {
}
}
}
return value ?? defaultValue;
}
async setSettingValueAsync(id, value) {
const json = JSON.stringify(value);
localStorage["Comfy.Settings." + id] = json; // backwards compatibility for extensions keep setting in storage
let oldValue = this.getSettingValue(id, undefined);
this.settingsValues[this.getId(id)] = value;
if (id in this.settingsLookup) {
this.settingsLookup[id].onChange?.(value, oldValue);
}
await api.storeSetting(id, value);
}
setSettingValue(id, value) {
this.setSettingValueAsync(id, value).catch((err) => {
alert(`Error saving setting '${id}'`);
console.error(err);
});
}
addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined }) {
if (!id) {
throw new Error("Settings must have an ID");
}
if (id in this.settingsLookup) {
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
}
let skipOnChange = false;
let value = this.getSettingValue(id);
if (value == null) {
if (this.app.isNewUserSession) {
// Check if we have a localStorage value but not a setting value and we are a new user
const localValue = localStorage["Comfy.Settings." + id];
if (localValue) {
value = JSON.parse(localValue);
this.setSettingValue(id, value); // Store on the server
}
}
if (value == null) {
value = defaultValue;
}
}
// Trigger initial setting of value
if (!skipOnChange) {
onChange?.(value, undefined);
}
this.settingsLookup[id] = {
id,
onChange,
name,
render: () => {
const setter = (v) => {
if (onChange) {
onChange(v, value);
}
this.setSettingValue(id, v);
value = v;
};
value = this.getSettingValue(id, defaultValue);
let element;
const htmlID = id.replaceAll(".", "-");
const labelCell = $el("td", [
$el("label", {
for: htmlID,
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
textContent: name,
}),
]);
if (typeof type === "function") {
element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
id: htmlID,
type: "checkbox",
checked: value,
onchange: (event) => {
const isChecked = event.target.checked;
if (onChange !== undefined) {
onChange(isChecked);
}
this.setSettingValue(id, isChecked);
},
}),
]),
]);
break;
case "number":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
type,
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
case "slider":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"div",
{
style: {
display: "grid",
gridAutoFlow: "column",
},
},
[
$el("input", {
...attrs,
value,
type: "range",
oninput: (e) => {
setter(e.target.value);
e.target.nextElementSibling.value = e.target.value;
},
}),
$el("input", {
...attrs,
value,
id: htmlID,
type: "number",
style: { maxWidth: "4rem" },
oninput: (e) => {
setter(e.target.value);
e.target.previousElementSibling.value = e.target.value;
},
}),
]
),
]),
]);
break;
case "combo":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"select",
{
oninput: (e) => {
setter(e.target.value);
},
},
(typeof options === "function" ? options(value) : options || []).map((opt) => {
if (typeof opt === "string") {
opt = { text: opt };
}
const v = opt.value ?? opt.text;
return $el("option", {
value: v,
textContent: opt.text,
selected: value + "" === v + "",
});
})
),
]),
]);
break;
case "text":
default:
if (type !== "text") {
console.warn(`Unsupported setting type '${type}, defaulting to text`);
}
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
}
}
if (tooltip) {
element.title = tooltip;
}
return element;
},
};
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
}
show() {
this.textElement.replaceChildren(
$el(
"tr",
{
style: { display: "none" },
},
[$el("th"), $el("th", { style: { width: "33%" } })]
),
...this.settings.sort((a, b) => a.name.localeCompare(b.name)).map((s) => s.render())
);
this.element.showModal();
}
}

View File

@ -0,0 +1,34 @@
.lds-ring {
display: inline-block;
position: relative;
width: 1em;
height: 1em;
}
.lds-ring div {
box-sizing: border-box;
display: block;
position: absolute;
width: 100%;
height: 100%;
border: 0.15em solid #fff;
border-radius: 50%;
animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite;
border-color: #fff transparent transparent transparent;
}
.lds-ring div:nth-child(1) {
animation-delay: -0.45s;
}
.lds-ring div:nth-child(2) {
animation-delay: -0.3s;
}
.lds-ring div:nth-child(3) {
animation-delay: -0.15s;
}
@keyframes lds-ring {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}

View File

@ -0,0 +1,9 @@
import { addStylesheet } from "../utils.js";
addStylesheet(import.meta.url);
export function createSpinner() {
const div = document.createElement("div");
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
return div.firstElementChild;
}

View File

@ -0,0 +1,60 @@
import { $el } from "../ui.js";
/**
* @typedef { { text: string, value?: string, tooltip?: string } } ToggleSwitchItem
*/
/**
* Creates a toggle switch element
* @param { string } name
* @param { Array<string | ToggleSwitchItem } items
* @param { Object } [opts]
* @param { (e: { item: ToggleSwitchItem, prev?: ToggleSwitchItem }) => void } [opts.onChange]
*/
export function toggleSwitch(name, items, { onChange } = {}) {
let selectedIndex;
let elements;
function updateSelected(index) {
if (selectedIndex != null) {
elements[selectedIndex].classList.remove("comfy-toggle-selected");
}
onChange?.({ item: items[index], prev: selectedIndex == null ? undefined : items[selectedIndex] });
selectedIndex = index;
elements[selectedIndex].classList.add("comfy-toggle-selected");
}
elements = items.map((item, i) => {
if (typeof item === "string") item = { text: item };
if (!item.value) item.value = item.text;
const toggle = $el(
"label",
{
textContent: item.text,
title: item.tooltip ?? "",
},
$el("input", {
name,
type: "radio",
value: item.value ?? item.text,
checked: item.selected,
onchange: () => {
updateSelected(i);
},
})
);
if (item.selected) {
updateSelected(i);
}
return toggle;
});
const container = $el("div.comfy-toggle-switch", elements);
if (selectedIndex == null) {
elements[0].children[0].checked = true;
updateSelected(0);
}
return container;
}

View File

@ -0,0 +1,135 @@
.comfy-user-selection {
width: 100vw;
height: 100vh;
position: absolute;
top: 0;
left: 0;
z-index: 999;
display: flex;
align-items: center;
justify-content: center;
font-family: sans-serif;
background: linear-gradient(var(--tr-even-bg-color), var(--tr-odd-bg-color));
}
.comfy-user-selection-inner {
background: var(--comfy-menu-bg);
margin-top: -30vh;
padding: 20px 40px;
border-radius: 10px;
min-width: 365px;
position: relative;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.3);
}
.comfy-user-selection-inner form {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
}
.comfy-user-selection-inner h1 {
margin: 10px 0 30px 0;
font-weight: normal;
}
.comfy-user-selection-inner label {
display: flex;
flex-direction: column;
width: 100%;
}
.comfy-user-selection input,
.comfy-user-selection select {
background-color: var(--comfy-input-bg);
color: var(--input-text);
border: 0;
border-radius: 5px;
padding: 5px;
margin-top: 10px;
}
.comfy-user-selection input::placeholder {
color: var(--descrip-text);
opacity: 1;
}
.comfy-user-existing {
width: 100%;
}
.no-users .comfy-user-existing {
display: none;
}
.comfy-user-selection-inner .or-separator {
margin: 10px 0;
padding: 10px;
display: block;
text-align: center;
width: 100%;
color: var(--descrip-text);
}
.comfy-user-selection-inner .or-separator {
overflow: hidden;
text-align: center;
margin-left: -10px;
}
.comfy-user-selection-inner .or-separator::before,
.comfy-user-selection-inner .or-separator::after {
content: "";
background-color: var(--border-color);
position: relative;
height: 1px;
vertical-align: middle;
display: inline-block;
width: calc(50% - 20px);
top: -1px;
}
.comfy-user-selection-inner .or-separator::before {
right: 10px;
margin-left: -50%;
}
.comfy-user-selection-inner .or-separator::after {
left: 10px;
margin-right: -50%;
}
.comfy-user-selection-inner section {
width: 100%;
padding: 10px;
margin: -10px;
transition: background-color 0.2s;
}
.comfy-user-selection-inner section.selected {
background: var(--border-color);
border-radius: 5px;
}
.comfy-user-selection-inner footer {
display: flex;
flex-direction: column;
align-items: center;
margin-top: 20px;
}
.comfy-user-selection-inner .comfy-user-error {
color: var(--error-text);
margin-bottom: 10px;
}
.comfy-user-button-next {
font-size: 16px;
padding: 6px 10px;
width: 100px;
display: flex;
gap: 5px;
align-items: center;
justify-content: center;
}

View File

@ -0,0 +1,114 @@
import { api } from "../api.js";
import { $el } from "../ui.js";
import { addStylesheet } from "../utils.js";
import { createSpinner } from "./spinner.js";
export class UserSelectionScreen {
async show(users, user) {
// This will rarely be hit so move the loading to on demand
await addStylesheet(import.meta.url);
const userSelection = document.getElementById("comfy-user-selection");
userSelection.style.display = "";
return new Promise((resolve) => {
const input = userSelection.getElementsByTagName("input")[0];
const select = userSelection.getElementsByTagName("select")[0];
const inputSection = input.closest("section");
const selectSection = select.closest("section");
const form = userSelection.getElementsByTagName("form")[0];
const error = userSelection.getElementsByClassName("comfy-user-error")[0];
const button = userSelection.getElementsByClassName("comfy-user-button-next")[0];
let inputActive = null;
input.addEventListener("focus", () => {
inputSection.classList.add("selected");
selectSection.classList.remove("selected");
inputActive = true;
});
select.addEventListener("focus", () => {
inputSection.classList.remove("selected");
selectSection.classList.add("selected");
inputActive = false;
select.style.color = "";
});
select.addEventListener("blur", () => {
if (!select.value) {
select.style.color = "var(--descrip-text)";
}
});
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (inputActive == null) {
error.textContent = "Please enter a username or select an existing user.";
} else if (inputActive) {
const username = input.value.trim();
if (!username) {
error.textContent = "Please enter a username.";
return;
}
// Create new user
input.disabled = select.disabled = input.readonly = select.readonly = true;
const spinner = createSpinner();
button.prepend(spinner);
try {
const resp = await api.createUser(username);
if (resp.status >= 300) {
let message = "Error creating user: " + resp.status + " " + resp.statusText;
try {
const res = await resp.json();
if(res.error) {
message = res.error;
}
} catch (error) {
}
throw new Error(message);
}
resolve({ username, userId: await resp.json(), created: true });
} catch (err) {
spinner.remove();
error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred.";
input.disabled = select.disabled = input.readonly = select.readonly = false;
return;
}
} else if (!select.value) {
error.textContent = "Please select an existing user.";
return;
} else {
resolve({ username: users[select.value], userId: select.value, created: false });
}
});
if (user) {
const name = localStorage["Comfy.userName"];
if (name) {
input.value = name;
}
}
if (input.value) {
// Focus the input, do this separately as sometimes browsers like to fill in the value
input.focus();
}
const userIds = Object.keys(users ?? {});
if (userIds.length) {
for (const u of userIds) {
$el("option", { textContent: users[u], value: u, parent: select });
}
select.style.color = "var(--descrip-text)";
if (select.value) {
// Focus the select, do this separately as sometimes browsers like to fill in the value
select.focus();
}
} else {
userSelection.classList.add("no-users");
input.focus();
}
}).then((r) => {
userSelection.remove();
return r;
});
}
}

View File

@ -1,3 +1,5 @@
import { $el } from "./ui.js";
// Simple date formatter
const parts = {
d: (d) => d.getDate(),
@ -65,3 +67,22 @@ export function applyTextReplacements(app, value) {
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
});
}
export async function addStylesheet(urlOrFile, relativeTo) {
return new Promise((res, rej) => {
let url;
if (urlOrFile.endsWith(".js")) {
url = urlOrFile.substr(0, urlOrFile.length - 2) + "css";
} else {
url = new URL(urlOrFile, relativeTo ?? `${window.location.protocol}//${window.location.host}`).toString();
}
$el("link", {
parent: document.head,
rel: "stylesheet",
type: "text/css",
href: url,
onload: res,
onerror: rej,
});
});
}

View File

@ -1,6 +1,19 @@
import { api } from "./api.js"
import "./domWidget.js";
let controlValueRunBefore = false;
export function updateControlWidgetLabel(widget) {
let replacement = "after";
let find = "before";
if (controlValueRunBefore) {
[find, replacement] = [replacement, find]
}
widget.label = (widget.label ?? widget.name).replace(find, replacement);
}
const IS_CONTROL_WIDGET = Symbol();
const HAS_EXECUTED = Symbol();
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
let defaultVal = inputData[1]["default"];
let { min, max, step, round} = inputData[1];
@ -62,6 +75,8 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
serialize: false, // Don't include this in prompt.
}
);
valueControl[IS_CONTROL_WIDGET] = true;
updateControlWidgetLabel(valueControl);
widgets.push(valueControl);
const isCombo = targetWidget.type === "combo";
@ -76,10 +91,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
serialize: false, // Don't include this in prompt.
}
);
updateControlWidgetLabel(comboFilter);
widgets.push(comboFilter);
}
valueControl.afterQueued = () => {
const applyWidgetControl = () => {
var v = valueControl.value;
if (isCombo && v !== "fixed") {
@ -159,6 +176,23 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
targetWidget.callback(targetWidget.value);
}
};
valueControl.beforeQueued = () => {
if (controlValueRunBefore) {
// Don't run on first execution
if (valueControl[HAS_EXECUTED]) {
applyWidgetControl();
}
}
valueControl[HAS_EXECUTED] = true;
};
valueControl.afterQueued = () => {
if (!controlValueRunBefore) {
applyWidgetControl();
}
};
return widgets;
};
@ -224,6 +258,34 @@ function isSlider(display, app) {
return (display==="slider") ? "slider" : "number"
}
export function initWidgets(app) {
app.ui.settings.addSetting({
id: "Comfy.WidgetControlMode",
name: "Widget Value Control Mode",
type: "combo",
defaultValue: "after",
options: ["before", "after"],
tooltip: "Controls when widget values are updated (randomize/increment/decrement), either before the prompt is queued or after.",
onChange(value) {
controlValueRunBefore = value === "before";
for (const n of app.graph._nodes) {
if (!n.widgets) continue;
for (const w of n.widgets) {
if (w[IS_CONTROL_WIDGET]) {
updateControlWidgetLabel(w);
if (w.linkedWidgets) {
for (const l of w.linkedWidgets) {
updateControlWidgetLabel(l);
}
}
}
}
}
app.graph.setDirtyCanvas(true);
},
});
}
export const ComfyWidgets = {
"INT:seed": seedWidget,
"INT:noise_seed": seedWidget,

View File

@ -121,6 +121,8 @@ body {
width: 100%;
}
.comfy-toggle-switch,
.comfy-btn,
.comfy-menu > button,
.comfy-menu-btns button,
.comfy-menu .comfy-list button,
@ -133,6 +135,7 @@ body {
margin-top: 2px;
}
.comfy-btn:hover:not(:disabled),
.comfy-menu > button:hover,
.comfy-menu-btns button:hover,
.comfy-menu .comfy-list button:hover,
@ -143,6 +146,12 @@ body {
}
.comfy-menu span.drag-handle {
position: absolute;
top: 0;
left: 0;
}
span.drag-handle {
width: 10px;
height: 20px;
display: inline-block;
@ -158,12 +167,9 @@ body {
letter-spacing: 2px;
color: var(--drag-text);
text-shadow: 1px 0 1px black;
position: absolute;
top: 0;
left: 0;
}
.comfy-menu span.drag-handle::after {
span.drag-handle::after {
content: '.. .. ..';
}
@ -429,6 +435,43 @@ dialog::backdrop {
margin-left: 5px;
}
.comfy-toggle-switch {
border-width: 2px;
display: flex;
background-color: var(--comfy-input-bg);
margin: 2px 0;
white-space: nowrap;
}
.comfy-toggle-switch label {
padding: 2px 0px 3px 6px;
flex: auto;
border-radius: 8px;
align-items: center;
display: flex;
justify-content: center;
}
.comfy-toggle-switch label:first-child {
border-top-left-radius: 8px;
border-bottom-left-radius: 8px;
}
.comfy-toggle-switch label:last-child {
border-top-right-radius: 8px;
border-bottom-right-radius: 8px;
}
.comfy-toggle-switch .comfy-toggle-selected {
background-color: var(--comfy-menu-bg);
}
#extraOptions {
padding: 4px;
background-color: var(--bg-color);
margin-bottom: 4px;
border-radius: 4px;
}
/* Search box */
.litegraph.litesearchbox {

View File

@ -1,4 +1,5 @@
from comfy import samplers
from comfy import model_management
from comfy import sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.cmd import latent_preview
@ -26,9 +27,8 @@ class BasicScheduler:
if denoise < 1.0:
total_steps = int(steps/denoise)
inner_model = model.patch_model(patch_weights=False)
sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
model.unpatch_model()
model_management.load_models_gpu([model])
sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
sigmas = sigmas[-(steps + 1):]
return (sigmas, )
@ -106,9 +106,8 @@ class SDTurboScheduler:
def get_sigmas(self, model, steps, denoise):
start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
inner_model = model.patch_model(patch_weights=False)
sigmas = inner_model.model_sampling.sigma(timesteps)
model.unpatch_model()
model_management.load_models_gpu([model])
sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )

View File

@ -34,7 +34,7 @@ class FreeU:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
@ -73,7 +73,7 @@ class FreeU_V2:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]

View File

@ -32,29 +32,29 @@ class HyperTile:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"]
apply_to = set()
temp = model_channels
for x in range(max_depth + 1):
apply_to.add(temp)
temp *= 2
latent_tile_size = max(32, tile_size) // 8
self.temp = None
def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to:
model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
for i in range(max_depth + 1):
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
if model_chans in apply_to:
shape = extra_options["original_shape"]
aspect_ratio = shape[-1] / shape[-2]
hw = q.size(1)
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size)

View File

@ -122,10 +122,34 @@ class LatentBatch:
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
class LatentBatchSeedBehavior:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
samples_out = samples.copy()
latent = samples["samples"]
if seed_behavior == "random":
if 'batch_index' in samples_out:
samples_out.pop('batch_index')
elif seed_behavior == "fixed":
batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
}

View File

@ -1,7 +1,7 @@
import numpy as np
import scipy.ndimage
import torch
import comfy.utils
from comfy import utils
from comfy.nodes.common import MAX_RESOLUTION
@ -11,7 +11,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
source = utils.repeat_to_batch_size(source, destination.shape[0])
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
@ -24,7 +24,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
else:
mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
mask = utils.repeat_to_batch_size(mask, source.shape[0])
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds

View File

@ -1,5 +1,5 @@
import torch
import comfy.utils
from comfy import utils
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@ -27,12 +27,12 @@ class PatchModelAddDownscale:
if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
h = utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
h = utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
return h, hsp
m = model.clone()

View File

@ -1,6 +1,6 @@
from comfy import sd
from comfy import sd, utils
from comfy import model_base
import comfy.model_management
from comfy import model_management
from comfy.cmd import folder_paths
import json
@ -118,6 +118,48 @@ class ModelMergeBlocks:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
class CheckpointSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -136,46 +178,7 @@ class CheckpointSave:
CATEGORY = "advanced/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
class CLIPSave:
@ -205,7 +208,7 @@ class CLIPSave:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()])
model_management.load_models_gpu([clip.load_model()])
clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]:
@ -229,9 +232,9 @@ class CLIPSave:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
current_clip_sd = utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
class VAESave:
@ -265,7 +268,7 @@ class VAESave:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
NODE_CLASS_MAPPINGS = {

View File

@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from comfy.cmd import folder_paths
from comfy import clip_model
from comfy import clip_vision
from comfy import ops
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
VISION_CONFIG_DICT = {
"hidden_size": 1024,
"image_size": 224,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"hidden_act": "quick_gelu",
}
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=ops):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = operations.LayerNorm(in_dim)
self.fc1 = operations.Linear(in_dim, hidden_dim)
self.fc2 = operations.Linear(hidden_dim, out_dim)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
residual = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
if self.use_residual:
x = x + residual
return x
class FuseModule(nn.Module):
def __init__(self, embed_dim, operations):
super().__init__()
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
self.layer_norm = operations.LayerNorm(embed_dim)
def fuse_fn(self, prompt_embeds, id_embeds):
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
stacked_id_embeds = self.mlp2(stacked_id_embeds)
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
return stacked_id_embeds
def forward(
self,
prompt_embeds,
id_embeds,
class_tokens_mask,
) -> torch.Tensor:
# id_embeds shape: [b, max_num_inputs, 1, 2048]
id_embeds = id_embeds.to(prompt_embeds.dtype)
num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
batch_size, max_num_inputs = id_embeds.shape[:2]
# seq_length: 77
seq_length = prompt_embeds.shape[1]
# flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
flat_id_embeds = id_embeds.view(
-1, id_embeds.shape[-2], id_embeds.shape[-1]
)
# valid_id_mask [b*max_num_inputs]
valid_id_mask = (
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
< num_inputs[:, None]
)
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
class_tokens_mask = class_tokens_mask.view(-1)
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
# slice out the image token embeddings
image_token_embeds = prompt_embeds[class_tokens_mask]
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
return updated_prompt_embeds
class PhotoMakerIDEncoder(clip_model.CLIPVisionModelProjection):
def __init__(self):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
dtype = comfy.model_management.text_encoder_dtype(self.load_device)
super().__init__(VISION_CONFIG_DICT, dtype, offload_device, ops.manual_cast)
self.visual_projection_2 = ops.manual_cast.Linear(1024, 1280, bias=False)
self.fuse_module = FuseModule(2048, ops.manual_cast)
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
b, num_inputs, c, h, w = id_pixel_values.shape
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
shared_id_embeds = self.vision_model(id_pixel_values)[2]
id_embeds = self.visual_projection(shared_id_embeds)
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
return updated_prompt_embeds
class PhotoMakerLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
RETURN_TYPES = ("PHOTOMAKER",)
FUNCTION = "load_photomaker_model"
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:
data = data["id_encoder"]
photomaker_model.load_state_dict(data)
return (photomaker_model,)
class PhotoMakerEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "photomaker": ("PHOTOMAKER",),
"image": ("IMAGE",),
"clip": ("CLIP", ),
"text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_photomaker"
CATEGORY = "_for_testing/photomaker"
def apply_photomaker(self, photomaker, image, clip, text):
special_token = "photomaker"
pixel_values = clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
try:
index = text.split(" ").index(special_token) + 1
except ValueError:
index = -1
tokens = clip.tokenize(text, return_word_ids=True)
out_tokens = {}
for k in tokens:
out_tokens[k] = []
for t in tokens[k]:
f = list(filter(lambda x: x[2] != index, t))
while len(f) < len(t):
f.append(t[-1])
out_tokens[k].append(f)
cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
if index > 0:
token_index = index - 1
num_id_images = 1
class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)]
out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
else:
out = cond
return ([[out, {"pooled_output": pooled}]], )
NODE_CLASS_MAPPINGS = {
"PhotoMakerLoader": PhotoMakerLoader,
"PhotoMakerEncode": PhotoMakerEncode,
}

View File

@ -33,6 +33,7 @@ class Blend:
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image2 = image2.to(image1.device)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')

View File

@ -58,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = math.ceil(math.sqrt(lh * lw / hw1))
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
@ -143,6 +143,8 @@ class SelfAttentionGuidance:
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
return cfg_result
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)

View File

@ -1,5 +1,4 @@
import torch
import comfy.utils
from comfy.nodes.common import MAX_RESOLUTION
from comfy import utils
@ -49,13 +48,57 @@ class StableZero123_Conditioning:
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = camera_embeddings(elevation, azimuth)
cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = []
for i in range(batch_size):
cam_embeds.append(camera_embeddings(elevation, azimuth))
elevation += elevation_batch_increment
azimuth += azimuth_batch_increment
cam_embeds = torch.cat(cam_embeds, dim=0)
cond = torch.cat([utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning,
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
}

View File

@ -3,6 +3,7 @@ import torch
import comfy.utils
import comfy.sd
from comfy.cmd import folder_paths
from . import nodes_model_merging
class ImageOnlyCheckpointLoader:
@ -78,10 +79,26 @@ class VideoLinearCFGGuidance:
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
class ImageOnlyCheckpointSave(nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
}
NODE_DISPLAY_NAME_MAPPINGS = {

View File

@ -1,3 +1,4 @@
{
"presets": ["@babel/preset-env"]
"presets": ["@babel/preset-env"],
"plugins": ["babel-plugin-transform-import-meta"]
}

View File

@ -11,6 +11,7 @@
"devDependencies": {
"@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5",
"babel-plugin-transform-import-meta": "^2.2.1",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0"
}
@ -2591,6 +2592,19 @@
"@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0"
}
},
"node_modules/babel-plugin-transform-import-meta": {
"version": "2.2.1",
"resolved": "https://registry.npmjs.org/babel-plugin-transform-import-meta/-/babel-plugin-transform-import-meta-2.2.1.tgz",
"integrity": "sha512-AxNh27Pcg8Kt112RGa3Vod2QS2YXKKJ6+nSvRtv7qQTJAdx0MZa4UHZ4lnxHUWA2MNbLuZQv5FVab4P1CoLOWw==",
"dev": true,
"dependencies": {
"@babel/template": "^7.4.4",
"tslib": "^2.4.0"
},
"peerDependencies": {
"@babel/core": "^7.10.0"
}
},
"node_modules/babel-preset-current-node-syntax": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz",
@ -5233,6 +5247,12 @@
"node": ">=12"
}
},
"node_modules/tslib": {
"version": "2.6.2",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==",
"dev": true
},
"node_modules/type-detect": {
"version": "4.0.8",
"resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz",

View File

@ -24,6 +24,7 @@
"devDependencies": {
"@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5",
"babel-plugin-transform-import-meta": "^2.2.1",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0"
}

View File

@ -0,0 +1,295 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("users", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
function expectNoUserScreen() {
// Ensure login isnt visible
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
expect(selection["style"].display).toBe("none");
const menu = document.querySelectorAll(".comfy-menu")?.[0];
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
}
describe("multi-user", () => {
function mockAddStylesheet() {
const utils = require("../../comfy/web/scripts/utils");
utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve());
}
async function waitForUserScreenShow() {
mockAddStylesheet();
// Wait for "show" to be called
const { UserSelectionScreen } = require("../../comfy/web/scripts/ui/userSelection");
let resolve, reject;
const fn = UserSelectionScreen.prototype.show;
const p = new Promise((res, rej) => {
resolve = res;
reject = rej;
});
jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => {
const res = fn(...args);
await new Promise(process.nextTick); // wait for promises to resolve
resolve();
return res;
});
// @ts-ignore
setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500);
await p;
await new Promise(process.nextTick); // wait for promises to resolve
}
async function testUserScreen(onShown, users) {
if (!users) {
users = {};
}
const starting = start({
resetEnv: true,
userConfig: { storage: "server", users },
});
// Ensure no current user
expect(localStorage["Comfy.userId"]).toBeFalsy();
expect(localStorage["Comfy.userName"]).toBeFalsy();
await waitForUserScreenShow();
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
expect(selection).toBeTruthy();
// Ensure login is visible
expect(window.getComputedStyle(selection)?.display).not.toBe("none");
// Ensure menu is hidden
const menu = document.querySelectorAll(".comfy-menu")?.[0];
expect(window.getComputedStyle(menu)?.display).toBe("none");
const isCreate = await onShown(selection);
// Submit form
selection.querySelectorAll("form")[0].submit();
await new Promise(process.nextTick); // wait for promises to resolve
// Wait for start
const s = await starting;
// Ensure login is removed
expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0);
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
// Ensure settings + templates are saved
const { api } = require("../../comfy/web/scripts/api");
expect(api.createUser).toHaveBeenCalledTimes(+isCreate);
expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate);
expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate);
if (isCreate) {
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
expect(s.app.isNewUserSession).toBeTruthy();
} else {
expect(s.app.isNewUserSession).toBeFalsy();
}
return { users, selection, ...s };
}
it("allows user creation if no users", async () => {
const { users } = await testUserScreen((selection) => {
// Ensure we have no users flag added
expect(selection.classList.contains("no-users")).toBeTruthy();
// Enter a username
const input = selection.getElementsByTagName("input")[0];
input.focus();
input.value = "Test User";
return true;
});
expect(users).toStrictEqual({
"Test User!": "Test User",
});
expect(localStorage["Comfy.userId"]).toBe("Test User!");
expect(localStorage["Comfy.userName"]).toBe("Test User");
});
it("allows user creation if no current user but other users", async () => {
const users = {
"Test User 2!": "Test User 2",
};
await testUserScreen((selection) => {
expect(selection.classList.contains("no-users")).toBeFalsy();
// Enter a username
const input = selection.getElementsByTagName("input")[0];
input.focus();
input.value = "Test User 3";
return true;
}, users);
expect(users).toStrictEqual({
"Test User 2!": "Test User 2",
"Test User 3!": "Test User 3",
});
expect(localStorage["Comfy.userId"]).toBe("Test User 3!");
expect(localStorage["Comfy.userName"]).toBe("Test User 3");
});
it("allows user selection if no current user but other users", async () => {
const users = {
"A!": "A",
"B!": "B",
"C!": "C",
};
await testUserScreen((selection) => {
expect(selection.classList.contains("no-users")).toBeFalsy();
// Check user list
const select = selection.getElementsByTagName("select")[0];
const options = select.getElementsByTagName("option");
expect(
[...options]
.filter((o) => !o.disabled)
.reduce((p, n) => {
p[n.getAttribute("value")] = n.textContent;
return p;
}, {})
).toStrictEqual(users);
// Select an option
select.focus();
select.value = options[2].value;
return false;
}, users);
expect(users).toStrictEqual(users);
expect(localStorage["Comfy.userId"]).toBe("B!");
expect(localStorage["Comfy.userName"]).toBe("B");
});
it("doesnt show user screen if current user", async () => {
const starting = start({
resetEnv: true,
userConfig: {
storage: "server",
users: {
"User!": "User",
},
},
localStorage: {
"Comfy.userId": "User!",
"Comfy.userName": "User",
},
});
await new Promise(process.nextTick); // wait for promises to resolve
expectNoUserScreen();
await starting;
});
it("allows user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: {
storage: "server",
users: {
"User!": "User",
},
},
localStorage: {
"Comfy.userId": "User!",
"Comfy.userName": "User",
},
});
// cant actually test switching user easily but can check the setting is present
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy();
});
});
describe("single-user", () => {
it("doesnt show user creation if no default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: false, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../comfy/web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(1);
expect(api.storeUserData).toHaveBeenCalledTimes(1);
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
expect(app.isNewUserSession).toBeTruthy();
});
it("doesnt show user creation if default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../comfy/web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt allow user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
});
});
describe("browser-user", () => {
it("doesnt show user creation if no default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: false, storage: "browser" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../comfy/web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt show user creation if default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../comfy/web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt allow user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "browser" },
});
expectNoUserScreen();
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
});
});
});

View File

@ -1,10 +1,18 @@
const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph");
const lg = require("./litegraph");
const fs = require("fs");
const path = require("path");
const html = fs.readFileSync(path.resolve("../comfy/web/index.html"))
/**
*
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config
* @param { Parameters<typeof mockApi>[0] & {
* resetEnv?: boolean,
* preSetup?(app): Promise<void>,
* localStorage?: Record<string, string>
* } } config
* @returns
*/
export async function start(config = {}) {
@ -12,12 +20,18 @@ export async function start(config = {}) {
jest.resetModules();
jest.resetAllMocks();
lg.setup(global);
localStorage.clear();
sessionStorage.clear();
}
Object.assign(localStorage, config.localStorage ?? {});
document.body.innerHTML = html;
mockApi(config);
const { app } = require("../../comfy/web/scripts/app");
config.preSetup?.(app);
await app.setup();
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
}

View File

@ -18,9 +18,21 @@ function* walkSync(dir) {
*/
/**
* @param { { mockExtensions?: string[], mockNodeDefs?: Record<string, ComfyObjectInfo> } } config
* @param {{
* mockExtensions?: string[],
* mockNodeDefs?: Record<string, ComfyObjectInfo>,
* settings?: Record<string, string>
* userConfig?: {storage: "server" | "browser", users?: Record<string, any>, migrated?: boolean },
* userData?: Record<string, any>
* }} config
*/
export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
export function mockApi(config = {}) {
let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = {
userConfig,
settings: {},
userData: {},
...config,
};
if (!mockExtensions) {
mockExtensions = Array.from(walkSync(path.resolve("../comfy/web/extensions/core")))
.filter((x) => x.endsWith(".js"))
@ -40,6 +52,26 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
getNodeDefs: jest.fn(() => mockNodeDefs),
init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x),
createUser: jest.fn((username) => {
if(username in userConfig.users) {
return { status: 400, json: () => "Duplicate" }
}
userConfig.users[username + "!"] = username;
return { status: 200, json: () => username + "!" }
}),
getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }),
getSettings: jest.fn(() => settings),
storeSettings: jest.fn((v) => Object.assign(settings, v)),
getUserData: jest.fn((f) => {
if (f in userData) {
return { status: 200, json: () => userData[f] };
} else {
return { status: 404 };
}
}),
storeUserData: jest.fn((file, data) => {
userData[file] = data;
}),
};
jest.mock("../../comfy/web/scripts/api", () => ({
get api() {