From 6d281b4ff4ad3918a4f3b4ca4a8b547a2ba3bf80 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 Jan 2024 14:28:11 -0500 Subject: [PATCH 01/54] Add a /free route to unload models or free all memory. A POST request to /free with: {"unload_models":true} will unload models from vram. A POST request to /free with: {"free_memory":true} will unload models and free all cached data from the last run workflow. --- execution.py | 20 +++++++++++++++++++- main.py | 15 ++++++++++++++- server.py | 11 +++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 53ba2e0f8..260a08970 100644 --- a/execution.py +++ b/execution.py @@ -268,11 +268,14 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): + self.server = server + self.reset() + + def reset(self): self.outputs = {} self.object_storage = {} self.outputs_ui = {} self.old_prompt = {} - self.server = server def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] @@ -706,6 +709,7 @@ class PromptQueue: self.queue = [] self.currently_running = {} self.history = {} + self.flags = {} server.prompt_queue = self def put(self, item): @@ -792,3 +796,17 @@ class PromptQueue: def delete_history_item(self, id_to_delete): with self.mutex: self.history.pop(id_to_delete, None) + + def set_flag(self, name, data): + with self.mutex: + self.flags[name] = data + self.not_empty.notify() + + def get_flags(self, reset=True): + with self.mutex: + if reset: + ret = self.flags + self.flags = {} + return ret + else: + return self.flags.copy() diff --git a/main.py b/main.py index bcf738729..45d5d41b3 100644 --- a/main.py +++ b/main.py @@ -97,7 +97,7 @@ def prompt_worker(q, server): gc_collect_interval = 10.0 while True: - timeout = None + timeout = 1000.0 if need_gc: timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) @@ -118,6 +118,19 @@ def prompt_worker(q, server): execution_time = current_time - execution_start_time print("Prompt executed in {:.2f} seconds".format(execution_time)) + flags = q.get_flags() + free_memory = flags.get("free_memory", False) + + if flags.get("unload_models", free_memory): + comfy.model_management.unload_all_models() + need_gc = True + last_gc_collect = 0 + + if free_memory: + e.reset() + need_gc = True + last_gc_collect = 0 + if need_gc: current_time = time.perf_counter() if (current_time - last_gc_collect) > gc_collect_interval: diff --git a/server.py b/server.py index bd6f026b2..acb8875d2 100644 --- a/server.py +++ b/server.py @@ -507,6 +507,17 @@ class PromptServer(): nodes.interrupt_processing() return web.Response(status=200) + @routes.post("/free") + async def post_interrupt(request): + json_data = await request.json() + unload_models = json_data.get("unload_models", False) + free_memory = json_data.get("free_memory", False) + if unload_models: + self.prompt_queue.set_flag("unload_models", unload_models) + if free_memory: + self.prompt_queue.set_flag("free_memory", free_memory) + return web.Response(status=200) + @routes.post("/history") async def post_history(request): json_data = await request.json() From 35322a376630e087d65f3108b0a6ec7a4b39279b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 Jan 2024 04:20:03 -0500 Subject: [PATCH 02/54] StableZero123_Conditioning_Batched node. This node lets you generate a batch of images with different elevations or azimuths by setting the elevation_batch_increment and/or azimuth_batch_increment. It also sets the batch index for the latents so that the same init noise is used on each frame. --- comfy_extras/nodes_stable3d.py | 44 ++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index c6791d8de..e02a9875a 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -53,6 +53,50 @@ class StableZero123_Conditioning: latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (positive, negative, {"samples":latent}) +class StableZero123_Conditioning_Batched: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + cam_embeds = [] + for i in range(batch_size): + cam_embeds.append(camera_embeddings(elevation, azimuth)) + elevation += elevation_batch_increment + azimuth += azimuth_batch_increment + + cam_embeds = torch.cat(cam_embeds, dim=0) + cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1) + + positive = [[cond, {"concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) + + NODE_CLASS_MAPPINGS = { "StableZero123_Conditioning": StableZero123_Conditioning, + "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, } From 7c9a0f7e0ad212719cd2c6179ae6266d949ea401 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 Jan 2024 12:31:13 -0500 Subject: [PATCH 03/54] Fix BasicScheduler issue with Loras. --- comfy_extras/nodes_custom_sampler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index bb0ed57b2..f02cb5ef7 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -26,9 +26,8 @@ class BasicScheduler: if denoise < 1.0: total_steps = int(steps/denoise) - inner_model = model.patch_model(patch_weights=False) - sigmas = comfy.samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu() - model.unpatch_model() + comfy.model_management.load_models_gpu([model]) + sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] return (sigmas, ) From af94eb14e3b44f6e6c48a7762a5e229cf3a006e4 Mon Sep 17 00:00:00 2001 From: ramyma Date: Sat, 6 Jan 2024 04:27:09 +0200 Subject: [PATCH 04/54] fix: `/free` handler function name --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index acb8875d2..5b821dcbf 100644 --- a/server.py +++ b/server.py @@ -508,7 +508,7 @@ class PromptServer(): return web.Response(status=200) @routes.post("/free") - async def post_interrupt(request): + async def post_free(request): json_data = await request.json() unload_models = json_data.get("unload_models", False) free_memory = json_data.get("free_memory", False) From 3ad0191bfb7674486734db98769ab466f27e9362 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 Jan 2024 04:33:03 -0500 Subject: [PATCH 05/54] Implement attention mask on xformers. --- comfy/ldm/modules/attention.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 3e12886b0..14d41a8cd 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -294,11 +294,14 @@ def attention_xformers(q, k, v, heads, mask=None): (q, k, v), ) - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + if mask is not None: + pad = 8 - q.shape[1] % 8 + mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device) + mask_out[:, :, :mask.shape[-1]] = mask + mask = mask_out[:, :, :mask.shape[-1]] + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) - if exists(mask): - raise NotImplementedError out = ( out.unsqueeze(0) .reshape(b, heads, -1, dim_head) From 0c2c9fbdfa53c2ad3b7658a7f2300da831830388 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 Jan 2024 13:16:48 -0500 Subject: [PATCH 06/54] Support attention mask in split attention. --- comfy/ldm/modules/attention.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 14d41a8cd..a18a69294 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -239,6 +239,12 @@ def attention_split(q, k, v, heads, mask=None): else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale + if mask is not None: + if len(mask.shape) == 2: + s1 += mask[i:end] + else: + s1 += mask[:, i:end] + s2 = s1.softmax(dim=-1).to(v.dtype) del s1 first_op_done = True From aaa9017302b75cf4453b5f8a58788e121f8e0a39 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 7 Jan 2024 04:13:58 -0500 Subject: [PATCH 07/54] Add attention mask support to sub quad attention. --- comfy/ldm/modules/attention.py | 1 + comfy/ldm/modules/sub_quadratic_attention.py | 30 +++++++++++++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a18a69294..8015a307a 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -177,6 +177,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): kv_chunk_size_min=kv_chunk_size_min, use_checkpoint=False, upcast_attention=upcast_attention, + mask=mask, ) hidden_states = hidden_states.to(dtype) diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 8e8e8054d..cb0896b0d 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -61,6 +61,7 @@ def _summarize_chunk( value: Tensor, scale: float, upcast_attention: bool, + mask, ) -> AttnChunk: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -84,6 +85,8 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() attn_weights -= max_score + if mask is not None: + attn_weights += mask torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) @@ -96,11 +99,12 @@ def _query_chunk_attention( value: Tensor, summarize_chunk: SummarizeChunk, kv_chunk_size: int, + mask, ) -> Tensor: batch_x_heads, k_channels_per_head, k_tokens = key_t.shape _, _, v_channels_per_head = value.shape - def chunk_scanner(chunk_idx: int) -> AttnChunk: + def chunk_scanner(chunk_idx: int, mask) -> AttnChunk: key_chunk = dynamic_slice( key_t, (0, 0, chunk_idx), @@ -111,10 +115,13 @@ def _query_chunk_attention( (0, chunk_idx, 0), (batch_x_heads, kv_chunk_size, v_channels_per_head) ) - return summarize_chunk(query, key_chunk, value_chunk) + if mask is not None: + mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size] + + return summarize_chunk(query, key_chunk, value_chunk, mask=mask) chunks: List[AttnChunk] = [ - chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size) ] acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) chunk_values, chunk_weights, chunk_max = acc_chunk @@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking( value: Tensor, scale: float, upcast_attention: bool, + mask, ) -> Tensor: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking( beta=0, ) + if mask is not None: + attn_scores += mask try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores @@ -183,6 +193,7 @@ def efficient_dot_product_attention( kv_chunk_size_min: Optional[int] = None, use_checkpoint=True, upcast_attention=False, + mask = None, ): """Computes efficient dot-product attention given query, transposed key, and value. This is efficient version of attention presented in @@ -209,13 +220,22 @@ def efficient_dot_product_attention( if kv_chunk_size_min is not None: kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + if mask is not None and len(mask.shape) == 2: + mask = mask.unsqueeze(0) + def get_query_chunk(chunk_idx: int) -> Tensor: return dynamic_slice( query, (0, chunk_idx, 0), (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) ) - + + def get_mask_chunk(chunk_idx: int) -> Tensor: + if mask is None: + return None + chunk = min(query_chunk_size, q_tokens) + return mask[:,chunk_idx:chunk_idx + chunk] + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention) summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk compute_query_chunk_attn: ComputeQueryChunkAttn = partial( @@ -237,6 +257,7 @@ def efficient_dot_product_attention( query=query, key_t=key_t, value=value, + mask=mask, ) # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, @@ -246,6 +267,7 @@ def efficient_dot_product_attention( query=get_query_chunk(i * query_chunk_size), key_t=key_t, value=value, + mask=get_mask_chunk(i * query_chunk_size) ) for i in range(math.ceil(q_tokens / query_chunk_size)) ], dim=1) return res From c6951548cfec64c28082e6560c69c59e32729c9c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 7 Jan 2024 13:52:08 -0500 Subject: [PATCH 08/54] Update optimized_attention_for_device function for new functions that support masked attention. --- comfy/clip_model.py | 2 +- comfy/ldm/modules/attention.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7397b7a26..09e7bbca1 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module): self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) def forward(self, x, mask=None, intermediate_output=None): - optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) if intermediate_output is not None: if intermediate_output < 0: diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 8015a307a..309240d5c 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -333,7 +333,6 @@ def attention_pytorch(q, k, v, heads, mask=None): optimized_attention = attention_basic -optimized_attention_masked = attention_basic if model_management.xformers_enabled(): print("Using xformers cross attention") @@ -349,15 +348,15 @@ else: print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -if model_management.pytorch_attention_enabled(): - optimized_attention_masked = attention_pytorch +optimized_attention_masked = optimized_attention + +def optimized_attention_for_device(device, mask=False, small_input=False): + if small_input and model_management.pytorch_attention_enabled(): + return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases + + if device == torch.device("cpu"): + return attention_sub_quad -def optimized_attention_for_device(device, mask=False): - if device == torch.device("cpu"): #TODO - if model_management.pytorch_attention_enabled(): - return attention_pytorch - else: - return attention_basic if mask: return optimized_attention_masked From 6a10640f0dd019dd7c74006909f38d0056c317bd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 Jan 2024 03:46:36 -0500 Subject: [PATCH 09/54] Support properly loading images with mode I. --- nodes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nodes.py b/nodes.py index 82244cf76..1ea0669b1 100644 --- a/nodes.py +++ b/nodes.py @@ -1415,6 +1415,8 @@ class LoadImage: output_masks = [] for i in ImageSequence.Iterator(img): i = ImageOps.exif_transpose(i) + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] From 235727fed79880ac2053a1db1b0d13c0f75714e8 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 8 Jan 2024 22:06:44 +0000 Subject: [PATCH 10/54] Store user settings/data on the server and multi user support (#2160) * wip per user data * Rename, hide menu * better error rework default user * store pretty * Add userdata endpoints Change nodetemplates to userdata * add multi user message * make normal arg * Fix tests * Ignore user dir * user tests * Changed to default to browser storage and add server-storage arg * fix crash on empty templates * fix settings added before load * ignore parse errors --- .gitignore | 3 +- app/app_settings.py | 54 +++++ app/user_manager.py | 141 ++++++++++++ comfy/cli_args.py | 6 + folder_paths.py | 1 + server.py | 3 + tests-ui/babel.config.json | 3 +- tests-ui/package-lock.json | 20 ++ tests-ui/package.json | 1 + tests-ui/tests/users.test.js | 295 +++++++++++++++++++++++++ tests-ui/utils/index.js | 16 +- tests-ui/utils/setup.js | 36 +++- web/extensions/core/nodeTemplates.js | 64 ++++-- web/index.html | 30 ++- web/scripts/api.js | 100 +++++++++ web/scripts/app.js | 82 +++++++ web/scripts/ui.js | 269 +---------------------- web/scripts/ui/dialog.js | 32 +++ web/scripts/ui/settings.js | 307 +++++++++++++++++++++++++++ web/scripts/ui/spinner.css | 34 +++ web/scripts/ui/spinner.js | 9 + web/scripts/ui/userSelection.css | 135 ++++++++++++ web/scripts/ui/userSelection.js | 114 ++++++++++ web/scripts/utils.js | 21 ++ web/style.css | 2 + 25 files changed, 1496 insertions(+), 282 deletions(-) create mode 100644 app/app_settings.py create mode 100644 app/user_manager.py create mode 100644 tests-ui/tests/users.test.js create mode 100644 web/scripts/ui/dialog.js create mode 100644 web/scripts/ui/settings.js create mode 100644 web/scripts/ui/spinner.css create mode 100644 web/scripts/ui/spinner.js create mode 100644 web/scripts/ui/userSelection.css create mode 100644 web/scripts/ui/userSelection.js diff --git a/.gitignore b/.gitignore index 43c038e41..9f0389241 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ venv/ /web/extensions/* !/web/extensions/logging.js.example !/web/extensions/core/ -/tests-ui/data/object_info.json \ No newline at end of file +/tests-ui/data/object_info.json +/user/ \ No newline at end of file diff --git a/app/app_settings.py b/app/app_settings.py new file mode 100644 index 000000000..8c6edc56c --- /dev/null +++ b/app/app_settings.py @@ -0,0 +1,54 @@ +import os +import json +from aiohttp import web + + +class AppSettings(): + def __init__(self, user_manager): + self.user_manager = user_manager + + def get_settings(self, request): + file = self.user_manager.get_request_user_filepath( + request, "comfy.settings.json") + if os.path.isfile(file): + with open(file) as f: + return json.load(f) + else: + return {} + + def save_settings(self, request, settings): + file = self.user_manager.get_request_user_filepath( + request, "comfy.settings.json") + with open(file, "w") as f: + f.write(json.dumps(settings, indent=4)) + + def add_routes(self, routes): + @routes.get("/settings") + async def get_settings(request): + return web.json_response(self.get_settings(request)) + + @routes.get("/settings/{id}") + async def get_setting(request): + value = None + settings = self.get_settings(request) + setting_id = request.match_info.get("id", None) + if setting_id and setting_id in settings: + value = settings[setting_id] + return web.json_response(value) + + @routes.post("/settings") + async def post_settings(request): + settings = self.get_settings(request) + new_settings = await request.json() + self.save_settings(request, {**settings, **new_settings}) + return web.Response(status=200) + + @routes.post("/settings/{id}") + async def post_setting(request): + setting_id = request.match_info.get("id", None) + if not setting_id: + return web.Response(status=400) + settings = self.get_settings(request) + settings[setting_id] = await request.json() + self.save_settings(request, settings) + return web.Response(status=200) \ No newline at end of file diff --git a/app/user_manager.py b/app/user_manager.py new file mode 100644 index 000000000..5a1031744 --- /dev/null +++ b/app/user_manager.py @@ -0,0 +1,141 @@ +import json +import os +import re +import uuid +from aiohttp import web +from comfy.cli_args import args +from folder_paths import user_directory +from .app_settings import AppSettings + +default_user = "default" +users_file = os.path.join(user_directory, "users.json") + + +class UserManager(): + def __init__(self): + global user_directory + + self.settings = AppSettings(self) + if not os.path.exists(user_directory): + os.mkdir(user_directory) + if not args.multi_user: + print("****** User settings have been changed to be stored on the server instead of browser storage. ******") + print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") + + if args.multi_user: + if os.path.isfile(users_file): + with open(users_file) as f: + self.users = json.load(f) + else: + self.users = {} + else: + self.users = {"default": "default"} + + def get_request_user_id(self, request): + user = "default" + if args.multi_user and "comfy-user" in request.headers: + user = request.headers["comfy-user"] + + if user not in self.users: + raise KeyError("Unknown user: " + user) + + return user + + def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): + global user_directory + + if type == "userdata": + root_dir = user_directory + else: + raise KeyError("Unknown filepath type:" + type) + + user = self.get_request_user_id(request) + path = user_root = os.path.abspath(os.path.join(root_dir, user)) + + # prevent leaving /{type} + if os.path.commonpath((root_dir, user_root)) != root_dir: + return None + + parent = user_root + + if file is not None: + # prevent leaving /{type}/{user} + path = os.path.abspath(os.path.join(user_root, file)) + if os.path.commonpath((user_root, path)) != user_root: + return None + parent = os.path.join(path, os.pardir) + + if create_dir and not os.path.exists(parent): + os.mkdir(parent) + + return path + + def add_user(self, name): + name = name.strip() + if not name: + raise ValueError("username not provided") + user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + user_id = user_id + "_" + str(uuid.uuid4()) + + self.users[user_id] = name + + global users_file + with open(users_file, "w") as f: + json.dump(self.users, f) + + return user_id + + def add_routes(self, routes): + self.settings.add_routes(routes) + + @routes.get("/users") + async def get_users(request): + if args.multi_user: + return web.json_response({"storage": "server", "users": self.users}) + else: + user_dir = self.get_request_user_filepath(request, None, create_dir=False) + return web.json_response({ + "storage": "server" if args.server_storage else "browser", + "migrated": os.path.exists(user_dir) + }) + + @routes.post("/users") + async def post_users(request): + body = await request.json() + username = body["username"] + if username in self.users.values(): + return web.json_response({"error": "Duplicate username."}, status=400) + + user_id = self.add_user(username) + return web.json_response(user_id) + + @routes.get("/userdata/{file}") + async def getuserdata(request): + file = request.match_info.get("file", None) + if not file: + return web.Response(status=400) + + path = self.get_request_user_filepath(request, file) + if not path: + return web.Response(status=403) + + if not os.path.exists(path): + return web.Response(status=404) + + return web.FileResponse(path) + + @routes.post("/userdata/{file}") + async def post_userdata(request): + file = request.match_info.get("file", None) + if not file: + return web.Response(status=400) + + path = self.get_request_user_filepath(request, file) + if not path: + return web.Response(status=403) + + body = await request.read() + with open(path, "wb") as f: + f.write(body) + + return web.Response(status=200) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 50d7b62fa..c02bbf2ba 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -112,6 +112,9 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") +parser.add_argument("--server-storage", action="store_true", help="Saves settings and other user configuration on the server instead of in browser storage.") +parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage. If enabled, server-storage will be unconditionally enabled.") + if comfy.options.args_parsing: args = parser.parse_args() else: @@ -122,3 +125,6 @@ if args.windows_standalone_build: if args.disable_auto_launch: args.auto_launch = False + +if args.multi_user: + args.server_storage = True \ No newline at end of file diff --git a/folder_paths.py b/folder_paths.py index a8726d8dd..641e5f0b2 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -34,6 +34,7 @@ folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers" output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") filename_list_cache = {} diff --git a/server.py b/server.py index 5b821dcbf..1135fff88 100644 --- a/server.py +++ b/server.py @@ -30,6 +30,7 @@ from comfy.cli_args import args import comfy.utils import comfy.model_management +from app.user_manager import UserManager class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -72,6 +73,7 @@ class PromptServer(): mimetypes.init() mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' + self.user_manager = UserManager() self.supports = ["custom_nodes_from_web"] self.prompt_queue = None self.loop = loop @@ -532,6 +534,7 @@ class PromptServer(): return web.Response(status=200) def add_routes(self): + self.user_manager.add_routes(self.routes) self.app.add_routes(self.routes) for name, dir in nodes.EXTENSION_WEB_DIRS.items(): diff --git a/tests-ui/babel.config.json b/tests-ui/babel.config.json index 526ddfd8d..f27d6c397 100644 --- a/tests-ui/babel.config.json +++ b/tests-ui/babel.config.json @@ -1,3 +1,4 @@ { - "presets": ["@babel/preset-env"] + "presets": ["@babel/preset-env"], + "plugins": ["babel-plugin-transform-import-meta"] } diff --git a/tests-ui/package-lock.json b/tests-ui/package-lock.json index 35911cd7f..0f409ca24 100644 --- a/tests-ui/package-lock.json +++ b/tests-ui/package-lock.json @@ -11,6 +11,7 @@ "devDependencies": { "@babel/preset-env": "^7.22.20", "@types/jest": "^29.5.5", + "babel-plugin-transform-import-meta": "^2.2.1", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0" } @@ -2591,6 +2592,19 @@ "@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0" } }, + "node_modules/babel-plugin-transform-import-meta": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/babel-plugin-transform-import-meta/-/babel-plugin-transform-import-meta-2.2.1.tgz", + "integrity": "sha512-AxNh27Pcg8Kt112RGa3Vod2QS2YXKKJ6+nSvRtv7qQTJAdx0MZa4UHZ4lnxHUWA2MNbLuZQv5FVab4P1CoLOWw==", + "dev": true, + "dependencies": { + "@babel/template": "^7.4.4", + "tslib": "^2.4.0" + }, + "peerDependencies": { + "@babel/core": "^7.10.0" + } + }, "node_modules/babel-preset-current-node-syntax": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz", @@ -5233,6 +5247,12 @@ "node": ">=12" } }, + "node_modules/tslib": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==", + "dev": true + }, "node_modules/type-detect": { "version": "4.0.8", "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", diff --git a/tests-ui/package.json b/tests-ui/package.json index e7b60ad8e..ae7e49084 100644 --- a/tests-ui/package.json +++ b/tests-ui/package.json @@ -24,6 +24,7 @@ "devDependencies": { "@babel/preset-env": "^7.22.20", "@types/jest": "^29.5.5", + "babel-plugin-transform-import-meta": "^2.2.1", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0" } diff --git a/tests-ui/tests/users.test.js b/tests-ui/tests/users.test.js new file mode 100644 index 000000000..5e0730730 --- /dev/null +++ b/tests-ui/tests/users.test.js @@ -0,0 +1,295 @@ +// @ts-check +/// +const { start } = require("../utils"); +const lg = require("../utils/litegraph"); + +describe("users", () => { + beforeEach(() => { + lg.setup(global); + }); + + afterEach(() => { + lg.teardown(global); + }); + + function expectNoUserScreen() { + // Ensure login isnt visible + const selection = document.querySelectorAll("#comfy-user-selection")?.[0]; + expect(selection["style"].display).toBe("none"); + const menu = document.querySelectorAll(".comfy-menu")?.[0]; + expect(window.getComputedStyle(menu)?.display).not.toBe("none"); + } + + describe("multi-user", () => { + function mockAddStylesheet() { + const utils = require("../../web/scripts/utils"); + utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve()); + } + + async function waitForUserScreenShow() { + mockAddStylesheet(); + + // Wait for "show" to be called + const { UserSelectionScreen } = require("../../web/scripts/ui/userSelection"); + let resolve, reject; + const fn = UserSelectionScreen.prototype.show; + const p = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => { + const res = fn(...args); + await new Promise(process.nextTick); // wait for promises to resolve + resolve(); + return res; + }); + // @ts-ignore + setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500); + await p; + await new Promise(process.nextTick); // wait for promises to resolve + } + + async function testUserScreen(onShown, users) { + if (!users) { + users = {}; + } + const starting = start({ + resetEnv: true, + userConfig: { storage: "server", users }, + }); + + // Ensure no current user + expect(localStorage["Comfy.userId"]).toBeFalsy(); + expect(localStorage["Comfy.userName"]).toBeFalsy(); + + await waitForUserScreenShow(); + + const selection = document.querySelectorAll("#comfy-user-selection")?.[0]; + expect(selection).toBeTruthy(); + + // Ensure login is visible + expect(window.getComputedStyle(selection)?.display).not.toBe("none"); + // Ensure menu is hidden + const menu = document.querySelectorAll(".comfy-menu")?.[0]; + expect(window.getComputedStyle(menu)?.display).toBe("none"); + + const isCreate = await onShown(selection); + + // Submit form + selection.querySelectorAll("form")[0].submit(); + await new Promise(process.nextTick); // wait for promises to resolve + + // Wait for start + const s = await starting; + + // Ensure login is removed + expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0); + expect(window.getComputedStyle(menu)?.display).not.toBe("none"); + + // Ensure settings + templates are saved + const { api } = require("../../web/scripts/api"); + expect(api.createUser).toHaveBeenCalledTimes(+isCreate); + expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate); + expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate); + if (isCreate) { + expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false }); + expect(s.app.isNewUserSession).toBeTruthy(); + } else { + expect(s.app.isNewUserSession).toBeFalsy(); + } + + return { users, selection, ...s }; + } + + it("allows user creation if no users", async () => { + const { users } = await testUserScreen((selection) => { + // Ensure we have no users flag added + expect(selection.classList.contains("no-users")).toBeTruthy(); + + // Enter a username + const input = selection.getElementsByTagName("input")[0]; + input.focus(); + input.value = "Test User"; + + return true; + }); + + expect(users).toStrictEqual({ + "Test User!": "Test User", + }); + + expect(localStorage["Comfy.userId"]).toBe("Test User!"); + expect(localStorage["Comfy.userName"]).toBe("Test User"); + }); + it("allows user creation if no current user but other users", async () => { + const users = { + "Test User 2!": "Test User 2", + }; + + await testUserScreen((selection) => { + expect(selection.classList.contains("no-users")).toBeFalsy(); + + // Enter a username + const input = selection.getElementsByTagName("input")[0]; + input.focus(); + input.value = "Test User 3"; + return true; + }, users); + + expect(users).toStrictEqual({ + "Test User 2!": "Test User 2", + "Test User 3!": "Test User 3", + }); + + expect(localStorage["Comfy.userId"]).toBe("Test User 3!"); + expect(localStorage["Comfy.userName"]).toBe("Test User 3"); + }); + it("allows user selection if no current user but other users", async () => { + const users = { + "A!": "A", + "B!": "B", + "C!": "C", + }; + + await testUserScreen((selection) => { + expect(selection.classList.contains("no-users")).toBeFalsy(); + + // Check user list + const select = selection.getElementsByTagName("select")[0]; + const options = select.getElementsByTagName("option"); + expect( + [...options] + .filter((o) => !o.disabled) + .reduce((p, n) => { + p[n.getAttribute("value")] = n.textContent; + return p; + }, {}) + ).toStrictEqual(users); + + // Select an option + select.focus(); + select.value = options[2].value; + + return false; + }, users); + + expect(users).toStrictEqual(users); + + expect(localStorage["Comfy.userId"]).toBe("B!"); + expect(localStorage["Comfy.userName"]).toBe("B"); + }); + it("doesnt show user screen if current user", async () => { + const starting = start({ + resetEnv: true, + userConfig: { + storage: "server", + users: { + "User!": "User", + }, + }, + localStorage: { + "Comfy.userId": "User!", + "Comfy.userName": "User", + }, + }); + await new Promise(process.nextTick); // wait for promises to resolve + + expectNoUserScreen(); + + await starting; + }); + it("allows user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { + storage: "server", + users: { + "User!": "User", + }, + }, + localStorage: { + "Comfy.userId": "User!", + "Comfy.userName": "User", + }, + }); + + // cant actually test switching user easily but can check the setting is present + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy(); + }); + }); + describe("single-user", () => { + it("doesnt show user creation if no default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: false, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(1); + expect(api.storeUserData).toHaveBeenCalledTimes(1); + expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false }); + expect(app.isNewUserSession).toBeTruthy(); + }); + it("doesnt show user creation if default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt allow user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy(); + }); + }); + describe("browser-user", () => { + it("doesnt show user creation if no default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: false, storage: "browser" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt show user creation if default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt allow user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "browser" }, + }); + expectNoUserScreen(); + + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy(); + }); + }); +}); diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index 6a08e8594..74b6cf93d 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -1,10 +1,18 @@ const { mockApi } = require("./setup"); const { Ez } = require("./ezgraph"); const lg = require("./litegraph"); +const fs = require("fs"); +const path = require("path"); + +const html = fs.readFileSync(path.resolve(__dirname, "../../web/index.html")) /** * - * @param { Parameters[0] & { resetEnv?: boolean, preSetup?(app): Promise } } config + * @param { Parameters[0] & { + * resetEnv?: boolean, + * preSetup?(app): Promise, + * localStorage?: Record + * } } config * @returns */ export async function start(config = {}) { @@ -12,12 +20,18 @@ export async function start(config = {}) { jest.resetModules(); jest.resetAllMocks(); lg.setup(global); + localStorage.clear(); + sessionStorage.clear(); } + Object.assign(localStorage, config.localStorage ?? {}); + document.body.innerHTML = html; + mockApi(config); const { app } = require("../../web/scripts/app"); config.preSetup?.(app); await app.setup(); + return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; } diff --git a/tests-ui/utils/setup.js b/tests-ui/utils/setup.js index dd150214a..e46258943 100644 --- a/tests-ui/utils/setup.js +++ b/tests-ui/utils/setup.js @@ -18,9 +18,21 @@ function* walkSync(dir) { */ /** - * @param { { mockExtensions?: string[], mockNodeDefs?: Record } } config + * @param {{ + * mockExtensions?: string[], + * mockNodeDefs?: Record, +* settings?: Record +* userConfig?: {storage: "server" | "browser", users?: Record, migrated?: boolean }, +* userData?: Record + * }} config */ -export function mockApi({ mockExtensions, mockNodeDefs } = {}) { +export function mockApi(config = {}) { + let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = { + userConfig, + settings: {}, + userData: {}, + ...config, + }; if (!mockExtensions) { mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core"))) .filter((x) => x.endsWith(".js")) @@ -40,6 +52,26 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) { getNodeDefs: jest.fn(() => mockNodeDefs), init: jest.fn(), apiURL: jest.fn((x) => "../../web/" + x), + createUser: jest.fn((username) => { + if(username in userConfig.users) { + return { status: 400, json: () => "Duplicate" } + } + userConfig.users[username + "!"] = username; + return { status: 200, json: () => username + "!" } + }), + getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }), + getSettings: jest.fn(() => settings), + storeSettings: jest.fn((v) => Object.assign(settings, v)), + getUserData: jest.fn((f) => { + if (f in userData) { + return { status: 200, json: () => userData[f] }; + } else { + return { status: 404 }; + } + }), + storeUserData: jest.fn((file, data) => { + userData[file] = data; + }), }; jest.mock("../../web/scripts/api", () => ({ get api() { diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js index bc9a10864..9350ba654 100644 --- a/web/extensions/core/nodeTemplates.js +++ b/web/extensions/core/nodeTemplates.js @@ -1,4 +1,5 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; import { ComfyDialog, $el } from "../../scripts/ui.js"; import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; @@ -20,16 +21,20 @@ import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; // Open the manage dialog and Drag and drop elements using the "Name:" label as handle const id = "Comfy.NodeTemplates"; +const file = "comfy.templates.json"; class ManageTemplates extends ComfyDialog { constructor() { super(); + this.load().then((v) => { + this.templates = v; + }); + this.element.classList.add("comfy-manage-templates"); - this.templates = this.load(); this.draggedEl = null; this.saveVisualCue = null; this.emptyImg = new Image(); - this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs='; + this.emptyImg.src = "data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs="; this.importInput = $el("input", { type: "file", @@ -67,17 +72,50 @@ class ManageTemplates extends ComfyDialog { return btns; } - load() { - const templates = localStorage.getItem(id); - if (templates) { - return JSON.parse(templates); + async load() { + let templates = []; + if (app.storageLocation === "server") { + if (app.isNewUserSession) { + // New user so migrate existing templates + const json = localStorage.getItem(id); + if (json) { + templates = JSON.parse(json); + } + await api.storeUserData(file, json, { stringify: false }); + } else { + const res = await api.getUserData(file); + if (res.status === 200) { + try { + templates = await res.json(); + } catch (error) { + } + } else if (res.status !== 404) { + console.error(res.status + " " + res.statusText); + } + } } else { - return []; + const json = localStorage.getItem(id); + if (json) { + templates = JSON.parse(json); + } } + + return templates ?? []; } - store() { - localStorage.setItem(id, JSON.stringify(this.templates)); + async store() { + if(app.storageLocation === "server") { + const templates = JSON.stringify(this.templates, undefined, 4); + localStorage.setItem(id, templates); // Backwards compatibility + try { + await api.storeUserData(file, templates, { stringify: false }); + } catch (error) { + console.error(error); + alert(error.message); + } + } else { + localStorage.setItem(id, JSON.stringify(this.templates)); + } } async importAll() { @@ -85,14 +123,14 @@ class ManageTemplates extends ComfyDialog { if (file.type === "application/json" || file.name.endsWith(".json")) { const reader = new FileReader(); reader.onload = async () => { - var importFile = JSON.parse(reader.result); - if (importFile && importFile?.templates) { + const importFile = JSON.parse(reader.result); + if (importFile?.templates) { for (const template of importFile.templates) { if (template?.name && template?.data) { this.templates.push(template); } } - this.store(); + await this.store(); } }; await reader.readAsText(file); @@ -159,7 +197,7 @@ class ManageTemplates extends ComfyDialog { e.currentTarget.style.border = "1px dashed transparent"; e.currentTarget.removeAttribute("draggable"); - // rearrange the elements in the localStorage + // rearrange the elements this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { var prev_i = el.dataset.id; diff --git a/web/index.html b/web/index.html index 41bc246c0..094db9d15 100644 --- a/web/index.html +++ b/web/index.html @@ -16,5 +16,33 @@ window.graph = app.graph; - + + + diff --git a/web/scripts/api.js b/web/scripts/api.js index 9aa7528af..3a9bcc87a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -12,6 +12,13 @@ class ComfyApi extends EventTarget { } fetchApi(route, options) { + if (!options) { + options = {}; + } + if (!options.headers) { + options.headers = {}; + } + options.headers["Comfy-User"] = this.user; return fetch(this.apiURL(route), options); } @@ -315,6 +322,99 @@ class ComfyApi extends EventTarget { async interrupt() { await this.#postItem("interrupt", null); } + + /** + * Gets user configuration data and where data should be stored + * @returns { Promise<{ storage: "server" | "browser", users?: Promise, migrated?: boolean }> } + */ + async getUserConfig() { + return (await this.fetchApi("/users")).json(); + } + + /** + * Creates a new user + * @param { string } username + * @returns The fetch response + */ + createUser(username) { + return this.fetchApi("/users", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ username }), + }); + } + + /** + * Gets all setting values for the current user + * @returns { Promise } A dictionary of id -> value + */ + async getSettings() { + return (await this.fetchApi("/settings")).json(); + } + + /** + * Gets a setting for the current user + * @param { string } id The id of the setting to fetch + * @returns { Promise } The setting value + */ + async getSetting(id) { + return (await this.fetchApi(`/settings/${encodeURIComponent(id)}`)).json(); + } + + /** + * Stores a dictionary of settings for the current user + * @param { Record } settings Dictionary of setting id -> value to save + * @returns { Promise } + */ + async storeSettings(settings) { + return this.fetchApi(`/settings`, { + method: "POST", + body: JSON.stringify(settings) + }); + } + + /** + * Stores a setting for the current user + * @param { string } id The id of the setting to update + * @param { unknown } value The value of the setting + * @returns { Promise } + */ + async storeSetting(id, value) { + return this.fetchApi(`/settings/${encodeURIComponent(id)}`, { + method: "POST", + body: JSON.stringify(value) + }); + } + + /** + * Gets a user data file for the current user + * @param { string } file The name of the userdata file to load + * @param { RequestInit } [options] + * @returns { Promise } The fetch response object + */ + async getUserData(file, options) { + return this.fetchApi(`/userdata/${encodeURIComponent(file)}`, options); + } + + /** + * Stores a user data file for the current user + * @param { string } file The name of the userdata file to save + * @param { unknown } data The data to save to the file + * @param { RequestInit & { stringify?: boolean, throwOnError?: boolean } } [options] + * @returns { Promise } + */ + async storeUserData(file, data, options = { stringify: true, throwOnError: true }) { + const resp = await this.fetchApi(`/userdata/${encodeURIComponent(file)}`, { + method: "POST", + body: options?.stringify ? JSON.stringify(data) : data, + ...options, + }); + if (resp.status !== 200) { + throw new Error(`Error storing user data file '${file}': ${resp.status} ${(await resp).statusText}`); + } + } } export const api = new ComfyApi(); diff --git a/web/scripts/app.js b/web/scripts/app.js index 7353f5a3b..72f9e8603 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1291,10 +1291,92 @@ export class ComfyApp { await Promise.all(extensionPromises); } + async #migrateSettings() { + this.isNewUserSession = true; + // Store all current settings + const settings = Object.keys(this.ui.settings).reduce((p, n) => { + const v = localStorage[`Comfy.Settings.${n}`]; + if (v) { + try { + p[n] = JSON.parse(v); + } catch (error) {} + } + return p; + }, {}); + + await api.storeSettings(settings); + } + + async #setUser() { + const userConfig = await api.getUserConfig(); + this.storageLocation = userConfig.storage; + if (typeof userConfig.migrated == "boolean") { + // Single user mode migrated true/false for if the default user is created + if (!userConfig.migrated && this.storageLocation === "server") { + // Default user not created yet + await this.#migrateSettings(); + } + return; + } + + this.multiUserServer = true; + let user = localStorage["Comfy.userId"]; + const users = userConfig.users ?? {}; + if (!user || !users[user]) { + // This will rarely be hit so move the loading to on demand + const { UserSelectionScreen } = await import("./ui/userSelection.js"); + + this.ui.menuContainer.style.display = "none"; + const { userId, username, created } = await new UserSelectionScreen().show(users, user); + this.ui.menuContainer.style.display = ""; + + user = userId; + localStorage["Comfy.userName"] = username; + localStorage["Comfy.userId"] = user; + + if (created) { + api.user = user; + await this.#migrateSettings(); + } + } + + api.user = user; + + this.ui.settings.addSetting({ + id: "Comfy.SwitchUser", + name: "Switch User", + type: (name) => { + let currentUser = localStorage["Comfy.userName"]; + if (currentUser) { + currentUser = ` (${currentUser})`; + } + return $el("tr", [ + $el("td", [ + $el("label", { + textContent: name, + }), + ]), + $el("td", [ + $el("button", { + textContent: name + (currentUser ?? ""), + onclick: () => { + delete localStorage["Comfy.userId"]; + delete localStorage["Comfy.userName"]; + window.location.reload(); + }, + }), + ]), + ]); + }, + }); + } + /** * Set up the app on the page */ async setup() { + await this.#setUser(); + await this.ui.settings.load(); await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM diff --git a/web/scripts/ui.js b/web/scripts/ui.js index ebaf86fe4..6887f7028 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -1,4 +1,8 @@ -import {api} from "./api.js"; +import { api } from "./api.js"; +import { ComfyDialog as _ComfyDialog } from "./ui/dialog.js"; +import { ComfySettingsDialog } from "./ui/settings.js"; + +export const ComfyDialog = _ComfyDialog; export function $el(tag, propsOrChildren, children) { const split = tag.split("."); @@ -167,267 +171,6 @@ function dragElement(dragEl, settings) { } } -export class ComfyDialog { - constructor() { - this.element = $el("div.comfy-modal", {parent: document.body}, [ - $el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]), - ]); - } - - createButtons() { - return [ - $el("button", { - type: "button", - textContent: "Close", - onclick: () => this.close(), - }), - ]; - } - - close() { - this.element.style.display = "none"; - } - - show(html) { - if (typeof html === "string") { - this.textElement.innerHTML = html; - } else { - this.textElement.replaceChildren(html); - } - this.element.style.display = "flex"; - } -} - -class ComfySettingsDialog extends ComfyDialog { - constructor() { - super(); - this.element = $el("dialog", { - id: "comfy-settings-dialog", - parent: document.body, - }, [ - $el("table.comfy-modal-content.comfy-table", [ - $el("caption", {textContent: "Settings"}), - $el("tbody", {$: (tbody) => (this.textElement = tbody)}), - $el("button", { - type: "button", - textContent: "Close", - style: { - cursor: "pointer", - }, - onclick: () => { - this.element.close(); - }, - }), - ]), - ]); - this.settings = []; - } - - getSettingValue(id, defaultValue) { - const settingId = "Comfy.Settings." + id; - const v = localStorage[settingId]; - return v == null ? defaultValue : JSON.parse(v); - } - - setSettingValue(id, value) { - const settingId = "Comfy.Settings." + id; - localStorage[settingId] = JSON.stringify(value); - } - - addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) { - if (!id) { - throw new Error("Settings must have an ID"); - } - - if (this.settings.find((s) => s.id === id)) { - throw new Error(`Setting ${id} of type ${type} must have a unique ID.`); - } - - const settingId = `Comfy.Settings.${id}`; - const v = localStorage[settingId]; - let value = v == null ? defaultValue : JSON.parse(v); - - // Trigger initial setting of value - if (onChange) { - onChange(value, undefined); - } - - this.settings.push({ - render: () => { - const setter = (v) => { - if (onChange) { - onChange(v, value); - } - localStorage[settingId] = JSON.stringify(v); - value = v; - }; - value = this.getSettingValue(id, defaultValue); - - let element; - const htmlID = id.replaceAll(".", "-"); - - const labelCell = $el("td", [ - $el("label", { - for: htmlID, - classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""], - textContent: name, - }) - ]); - - if (typeof type === "function") { - element = type(name, setter, value, attrs); - } else { - switch (type) { - case "boolean": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - id: htmlID, - type: "checkbox", - checked: value, - onchange: (event) => { - const isChecked = event.target.checked; - if (onChange !== undefined) { - onChange(isChecked) - } - this.setSettingValue(id, isChecked); - }, - }), - ]), - ]) - break; - case "number": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - type, - value, - id: htmlID, - oninput: (e) => { - setter(e.target.value); - }, - ...attrs - }), - ]), - ]); - break; - case "slider": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("div", { - style: { - display: "grid", - gridAutoFlow: "column", - }, - }, [ - $el("input", { - ...attrs, - value, - type: "range", - oninput: (e) => { - setter(e.target.value); - e.target.nextElementSibling.value = e.target.value; - }, - }), - $el("input", { - ...attrs, - value, - id: htmlID, - type: "number", - style: {maxWidth: "4rem"}, - oninput: (e) => { - setter(e.target.value); - e.target.previousElementSibling.value = e.target.value; - }, - }), - ]), - ]), - ]); - break; - case "combo": - element = $el("tr", [ - labelCell, - $el("td", [ - $el( - "select", - { - oninput: (e) => { - setter(e.target.value); - }, - }, - (typeof options === "function" ? options(value) : options || []).map((opt) => { - if (typeof opt === "string") { - opt = { text: opt }; - } - const v = opt.value ?? opt.text; - return $el("option", { - value: v, - textContent: opt.text, - selected: value + "" === v + "", - }); - }) - ), - ]), - ]); - break; - case "text": - default: - if (type !== "text") { - console.warn(`Unsupported setting type '${type}, defaulting to text`); - } - - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - value, - id: htmlID, - oninput: (e) => { - setter(e.target.value); - }, - ...attrs, - }), - ]), - ]); - break; - } - } - if (tooltip) { - element.title = tooltip; - } - - return element; - }, - }); - - const self = this; - return { - get value() { - return self.getSettingValue(id, defaultValue); - }, - set value(v) { - self.setSettingValue(id, v); - }, - }; - } - - show() { - this.textElement.replaceChildren( - $el("tr", { - style: {display: "none"}, - }, [ - $el("th"), - $el("th", {style: {width: "33%"}}) - ]), - ...this.settings.map((s) => s.render()), - ) - this.element.showModal(); - } -} - class ComfyList { #type; #text; @@ -526,7 +269,7 @@ export class ComfyUI { constructor(app) { this.app = app; this.dialog = new ComfyDialog(); - this.settings = new ComfySettingsDialog(); + this.settings = new ComfySettingsDialog(app); this.batchCount = 1; this.lastQueueSize = 0; diff --git a/web/scripts/ui/dialog.js b/web/scripts/ui/dialog.js new file mode 100644 index 000000000..aee93b3c8 --- /dev/null +++ b/web/scripts/ui/dialog.js @@ -0,0 +1,32 @@ +import { $el } from "../ui.js"; + +export class ComfyDialog { + constructor() { + this.element = $el("div.comfy-modal", { parent: document.body }, [ + $el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]), + ]); + } + + createButtons() { + return [ + $el("button", { + type: "button", + textContent: "Close", + onclick: () => this.close(), + }), + ]; + } + + close() { + this.element.style.display = "none"; + } + + show(html) { + if (typeof html === "string") { + this.textElement.innerHTML = html; + } else { + this.textElement.replaceChildren(html); + } + this.element.style.display = "flex"; + } +} diff --git a/web/scripts/ui/settings.js b/web/scripts/ui/settings.js new file mode 100644 index 000000000..1cdba5cfe --- /dev/null +++ b/web/scripts/ui/settings.js @@ -0,0 +1,307 @@ +import { $el } from "../ui.js"; +import { api } from "../api.js"; +import { ComfyDialog } from "./dialog.js"; + +export class ComfySettingsDialog extends ComfyDialog { + constructor(app) { + super(); + this.app = app; + this.settingsValues = {}; + this.settingsLookup = {}; + this.element = $el( + "dialog", + { + id: "comfy-settings-dialog", + parent: document.body, + }, + [ + $el("table.comfy-modal-content.comfy-table", [ + $el("caption", { textContent: "Settings" }), + $el("tbody", { $: (tbody) => (this.textElement = tbody) }), + $el("button", { + type: "button", + textContent: "Close", + style: { + cursor: "pointer", + }, + onclick: () => { + this.element.close(); + }, + }), + ]), + ] + ); + } + + get settings() { + return Object.values(this.settingsLookup); + } + + async load() { + if (this.app.storageLocation === "browser") { + this.settingsValues = localStorage; + } else { + this.settingsValues = await api.getSettings(); + } + + // Trigger onChange for any settings added before load + for (const id in this.settingsLookup) { + this.settingsLookup[id].onChange?.(this.settingsValues[this.getId(id)]); + } + } + + getId(id) { + if (this.app.storageLocation === "browser") { + id = "Comfy.Settings." + id; + } + return id; + } + + getSettingValue(id, defaultValue) { + let value = this.settingsValues[this.getId(id)]; + if(value != null) { + if(this.app.storageLocation === "browser") { + try { + value = JSON.parse(value); + } catch (error) { + } + } + } + return value ?? defaultValue; + } + + async setSettingValueAsync(id, value) { + const json = JSON.stringify(value); + localStorage["Comfy.Settings." + id] = json; // backwards compatibility for extensions keep setting in storage + + let oldValue = this.getSettingValue(id, undefined); + this.settingsValues[this.getId(id)] = value; + + if (id in this.settingsLookup) { + this.settingsLookup[id].onChange?.(value, oldValue); + } + + await api.storeSetting(id, value); + } + + setSettingValue(id, value) { + this.setSettingValueAsync(id, value).catch((err) => { + alert(`Error saving setting '${id}'`); + console.error(err); + }); + } + + addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined }) { + if (!id) { + throw new Error("Settings must have an ID"); + } + + if (id in this.settingsLookup) { + throw new Error(`Setting ${id} of type ${type} must have a unique ID.`); + } + + let skipOnChange = false; + let value = this.getSettingValue(id); + if (value == null) { + if (this.app.isNewUserSession) { + // Check if we have a localStorage value but not a setting value and we are a new user + const localValue = localStorage["Comfy.Settings." + id]; + if (localValue) { + value = JSON.parse(localValue); + this.setSettingValue(id, value); // Store on the server + } + } + if (value == null) { + value = defaultValue; + } + } + + // Trigger initial setting of value + if (!skipOnChange) { + onChange?.(value, undefined); + } + + this.settingsLookup[id] = { + id, + onChange, + name, + render: () => { + const setter = (v) => { + if (onChange) { + onChange(v, value); + } + + this.setSettingValue(id, v); + value = v; + }; + value = this.getSettingValue(id, defaultValue); + + let element; + const htmlID = id.replaceAll(".", "-"); + + const labelCell = $el("td", [ + $el("label", { + for: htmlID, + classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""], + textContent: name, + }), + ]); + + if (typeof type === "function") { + element = type(name, setter, value, attrs); + } else { + switch (type) { + case "boolean": + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + id: htmlID, + type: "checkbox", + checked: value, + onchange: (event) => { + const isChecked = event.target.checked; + if (onChange !== undefined) { + onChange(isChecked); + } + this.setSettingValue(id, isChecked); + }, + }), + ]), + ]); + break; + case "number": + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + type, + value, + id: htmlID, + oninput: (e) => { + setter(e.target.value); + }, + ...attrs, + }), + ]), + ]); + break; + case "slider": + element = $el("tr", [ + labelCell, + $el("td", [ + $el( + "div", + { + style: { + display: "grid", + gridAutoFlow: "column", + }, + }, + [ + $el("input", { + ...attrs, + value, + type: "range", + oninput: (e) => { + setter(e.target.value); + e.target.nextElementSibling.value = e.target.value; + }, + }), + $el("input", { + ...attrs, + value, + id: htmlID, + type: "number", + style: { maxWidth: "4rem" }, + oninput: (e) => { + setter(e.target.value); + e.target.previousElementSibling.value = e.target.value; + }, + }), + ] + ), + ]), + ]); + break; + case "combo": + element = $el("tr", [ + labelCell, + $el("td", [ + $el( + "select", + { + oninput: (e) => { + setter(e.target.value); + }, + }, + (typeof options === "function" ? options(value) : options || []).map((opt) => { + if (typeof opt === "string") { + opt = { text: opt }; + } + const v = opt.value ?? opt.text; + return $el("option", { + value: v, + textContent: opt.text, + selected: value + "" === v + "", + }); + }) + ), + ]), + ]); + break; + case "text": + default: + if (type !== "text") { + console.warn(`Unsupported setting type '${type}, defaulting to text`); + } + + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + value, + id: htmlID, + oninput: (e) => { + setter(e.target.value); + }, + ...attrs, + }), + ]), + ]); + break; + } + } + if (tooltip) { + element.title = tooltip; + } + + return element; + }, + }; + + const self = this; + return { + get value() { + return self.getSettingValue(id, defaultValue); + }, + set value(v) { + self.setSettingValue(id, v); + }, + }; + } + + show() { + this.textElement.replaceChildren( + $el( + "tr", + { + style: { display: "none" }, + }, + [$el("th"), $el("th", { style: { width: "33%" } })] + ), + ...this.settings.sort((a, b) => a.name.localeCompare(b.name)).map((s) => s.render()) + ); + this.element.showModal(); + } +} diff --git a/web/scripts/ui/spinner.css b/web/scripts/ui/spinner.css new file mode 100644 index 000000000..56da6072e --- /dev/null +++ b/web/scripts/ui/spinner.css @@ -0,0 +1,34 @@ +.lds-ring { + display: inline-block; + position: relative; + width: 1em; + height: 1em; +} +.lds-ring div { + box-sizing: border-box; + display: block; + position: absolute; + width: 100%; + height: 100%; + border: 0.15em solid #fff; + border-radius: 50%; + animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite; + border-color: #fff transparent transparent transparent; +} +.lds-ring div:nth-child(1) { + animation-delay: -0.45s; +} +.lds-ring div:nth-child(2) { + animation-delay: -0.3s; +} +.lds-ring div:nth-child(3) { + animation-delay: -0.15s; +} +@keyframes lds-ring { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} diff --git a/web/scripts/ui/spinner.js b/web/scripts/ui/spinner.js new file mode 100644 index 000000000..d049786f6 --- /dev/null +++ b/web/scripts/ui/spinner.js @@ -0,0 +1,9 @@ +import { addStylesheet } from "../utils.js"; + +addStylesheet(import.meta.url); + +export function createSpinner() { + const div = document.createElement("div"); + div.innerHTML = `
`; + return div.firstElementChild; +} diff --git a/web/scripts/ui/userSelection.css b/web/scripts/ui/userSelection.css new file mode 100644 index 000000000..35c9d6614 --- /dev/null +++ b/web/scripts/ui/userSelection.css @@ -0,0 +1,135 @@ +.comfy-user-selection { + width: 100vw; + height: 100vh; + position: absolute; + top: 0; + left: 0; + z-index: 999; + display: flex; + align-items: center; + justify-content: center; + font-family: sans-serif; + background: linear-gradient(var(--tr-even-bg-color), var(--tr-odd-bg-color)); +} + +.comfy-user-selection-inner { + background: var(--comfy-menu-bg); + margin-top: -30vh; + padding: 20px 40px; + border-radius: 10px; + min-width: 365px; + position: relative; + box-shadow: 0 0 20px rgba(0, 0, 0, 0.3); +} + +.comfy-user-selection-inner form { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; +} + +.comfy-user-selection-inner h1 { + margin: 10px 0 30px 0; + font-weight: normal; +} + +.comfy-user-selection-inner label { + display: flex; + flex-direction: column; + width: 100%; +} + +.comfy-user-selection input, +.comfy-user-selection select { + background-color: var(--comfy-input-bg); + color: var(--input-text); + border: 0; + border-radius: 5px; + padding: 5px; + margin-top: 10px; +} + +.comfy-user-selection input::placeholder { + color: var(--descrip-text); + opacity: 1; +} + +.comfy-user-existing { + width: 100%; +} + +.no-users .comfy-user-existing { + display: none; +} + +.comfy-user-selection-inner .or-separator { + margin: 10px 0; + padding: 10px; + display: block; + text-align: center; + width: 100%; + color: var(--descrip-text); +} + +.comfy-user-selection-inner .or-separator { + overflow: hidden; + text-align: center; + margin-left: -10px; +} + +.comfy-user-selection-inner .or-separator::before, +.comfy-user-selection-inner .or-separator::after { + content: ""; + background-color: var(--border-color); + position: relative; + height: 1px; + vertical-align: middle; + display: inline-block; + width: calc(50% - 20px); + top: -1px; +} + +.comfy-user-selection-inner .or-separator::before { + right: 10px; + margin-left: -50%; +} + +.comfy-user-selection-inner .or-separator::after { + left: 10px; + margin-right: -50%; +} + +.comfy-user-selection-inner section { + width: 100%; + padding: 10px; + margin: -10px; + transition: background-color 0.2s; +} + +.comfy-user-selection-inner section.selected { + background: var(--border-color); + border-radius: 5px; +} + +.comfy-user-selection-inner footer { + display: flex; + flex-direction: column; + align-items: center; + margin-top: 20px; +} + +.comfy-user-selection-inner .comfy-user-error { + color: var(--error-text); + margin-bottom: 10px; +} + +.comfy-user-button-next { + font-size: 16px; + padding: 6px 10px; + width: 100px; + display: flex; + gap: 5px; + align-items: center; + justify-content: center; +} \ No newline at end of file diff --git a/web/scripts/ui/userSelection.js b/web/scripts/ui/userSelection.js new file mode 100644 index 000000000..f9f1ca807 --- /dev/null +++ b/web/scripts/ui/userSelection.js @@ -0,0 +1,114 @@ +import { api } from "../api.js"; +import { $el } from "../ui.js"; +import { addStylesheet } from "../utils.js"; +import { createSpinner } from "./spinner.js"; + +export class UserSelectionScreen { + async show(users, user) { + // This will rarely be hit so move the loading to on demand + await addStylesheet(import.meta.url); + const userSelection = document.getElementById("comfy-user-selection"); + userSelection.style.display = ""; + return new Promise((resolve) => { + const input = userSelection.getElementsByTagName("input")[0]; + const select = userSelection.getElementsByTagName("select")[0]; + const inputSection = input.closest("section"); + const selectSection = select.closest("section"); + const form = userSelection.getElementsByTagName("form")[0]; + const error = userSelection.getElementsByClassName("comfy-user-error")[0]; + const button = userSelection.getElementsByClassName("comfy-user-button-next")[0]; + + let inputActive = null; + input.addEventListener("focus", () => { + inputSection.classList.add("selected"); + selectSection.classList.remove("selected"); + inputActive = true; + }); + select.addEventListener("focus", () => { + inputSection.classList.remove("selected"); + selectSection.classList.add("selected"); + inputActive = false; + select.style.color = ""; + }); + select.addEventListener("blur", () => { + if (!select.value) { + select.style.color = "var(--descrip-text)"; + } + }); + + form.addEventListener("submit", async (e) => { + e.preventDefault(); + if (inputActive == null) { + error.textContent = "Please enter a username or select an existing user."; + } else if (inputActive) { + const username = input.value.trim(); + if (!username) { + error.textContent = "Please enter a username."; + return; + } + + // Create new user + input.disabled = select.disabled = input.readonly = select.readonly = true; + const spinner = createSpinner(); + button.prepend(spinner); + try { + const resp = await api.createUser(username); + if (resp.status >= 300) { + let message = "Error creating user: " + resp.status + " " + resp.statusText; + try { + const res = await resp.json(); + if(res.error) { + message = res.error; + } + } catch (error) { + } + throw new Error(message); + } + + resolve({ username, userId: await resp.json(), created: true }); + } catch (err) { + spinner.remove(); + error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred."; + input.disabled = select.disabled = input.readonly = select.readonly = false; + return; + } + } else if (!select.value) { + error.textContent = "Please select an existing user."; + return; + } else { + resolve({ username: users[select.value], userId: select.value, created: false }); + } + }); + + if (user) { + const name = localStorage["Comfy.userName"]; + if (name) { + input.value = name; + } + } + if (input.value) { + // Focus the input, do this separately as sometimes browsers like to fill in the value + input.focus(); + } + + const userIds = Object.keys(users ?? {}); + if (userIds.length) { + for (const u of userIds) { + $el("option", { textContent: users[u], value: u, parent: select }); + } + select.style.color = "var(--descrip-text)"; + + if (select.value) { + // Focus the select, do this separately as sometimes browsers like to fill in the value + select.focus(); + } + } else { + userSelection.classList.add("no-users"); + input.focus(); + } + }).then((r) => { + userSelection.remove(); + return r; + }); + } +} diff --git a/web/scripts/utils.js b/web/scripts/utils.js index 401aca9e4..01b988462 100644 --- a/web/scripts/utils.js +++ b/web/scripts/utils.js @@ -1,3 +1,5 @@ +import { $el } from "./ui.js"; + // Simple date formatter const parts = { d: (d) => d.getDate(), @@ -65,3 +67,22 @@ export function applyTextReplacements(app, value) { return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_"); }); } + +export async function addStylesheet(urlOrFile, relativeTo) { + return new Promise((res, rej) => { + let url; + if (urlOrFile.endsWith(".js")) { + url = urlOrFile.substr(0, urlOrFile.length - 2) + "css"; + } else { + url = new URL(urlOrFile, relativeTo ?? `${window.location.protocol}//${window.location.host}`).toString(); + } + $el("link", { + parent: document.head, + rel: "stylesheet", + type: "text/css", + href: url, + onload: res, + onerror: rej, + }); + }); +} diff --git a/web/style.css b/web/style.css index 630eea12e..c57fb7f73 100644 --- a/web/style.css +++ b/web/style.css @@ -121,6 +121,7 @@ body { width: 100%; } +.comfy-btn, .comfy-menu > button, .comfy-menu-btns button, .comfy-menu .comfy-list button, @@ -133,6 +134,7 @@ body { margin-top: 2px; } +.comfy-btn:hover:not(:disabled), .comfy-menu > button:hover, .comfy-menu-btns button:hover, .comfy-menu .comfy-list button:hover, From 2d74fc436058c98e5182d2123700ae244ccbe037 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 Jan 2024 17:08:00 -0500 Subject: [PATCH 11/54] Fix issue with user manager parent dir not being created. --- app/user_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/user_manager.py b/app/user_manager.py index 5a1031744..75485ee42 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -63,7 +63,6 @@ class UserManager(): path = os.path.abspath(os.path.join(user_root, file)) if os.path.commonpath((user_root, path)) != user_root: return None - parent = os.path.join(path, os.pardir) if create_dir and not os.path.exists(parent): os.mkdir(parent) From b3b5ddb07a23b3d070df292c7a7fd6f83dc8fd50 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 Jan 2024 17:08:17 -0500 Subject: [PATCH 12/54] Support I mode images in LoadImageMask. --- nodes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nodes.py b/nodes.py index 1ea0669b1..96d9e51a7 100644 --- a/nodes.py +++ b/nodes.py @@ -1472,6 +1472,8 @@ class LoadImageMask: i = Image.open(image_path) i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) i = i.convert("RGBA") mask = None c = channel[0].upper() From 6a7bc35db845179a26e62534f3d4b789151e52fe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 Jan 2024 13:46:52 -0500 Subject: [PATCH 13/54] Use basic attention implementation for small inputs on old pytorch. --- comfy/ldm/modules/attention.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 309240d5c..fd8888d0e 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -351,8 +351,11 @@ else: optimized_attention_masked = optimized_attention def optimized_attention_for_device(device, mask=False, small_input=False): - if small_input and model_management.pytorch_attention_enabled(): - return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases + if small_input: + if model_management.pytorch_attention_enabled(): + return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases + else: + return attention_basic if device == torch.device("cpu"): return attention_sub_quad From 2c80d9acb9a542fef843954370390369093b6ac4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 Jan 2024 15:12:12 -0500 Subject: [PATCH 14/54] Round up to nearest power of 2 in SAG node to fix some resolution issues. --- comfy_extras/nodes_sag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 450ac3eea..b307b4f56 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -58,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - ratio = math.ceil(math.sqrt(lh * lw / hw1)) + ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length() mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] # Reshape From 1a57423d30da06283dd20a7aee2aadcb58c37e59 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 Jan 2024 04:00:49 -0500 Subject: [PATCH 15/54] Fix issue when using multiple t2i adapters with batched images. --- comfy/controlnet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 8404054f3..df474201c 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -126,7 +126,10 @@ class ControlBase: if o[i] is None: o[i] = prev_val else: - o[i] += prev_val + if o[i].shape[0] < prev_val.shape[0]: + o[i] = prev_val + o[i] + else: + o[i] += prev_val return out class ControlNet(ControlBase): From b4e915e74560bd2c090f9b4ed6b73b0781b7050e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 Jan 2024 04:08:43 -0500 Subject: [PATCH 16/54] Skip SAG when latent is too small. --- comfy_extras/nodes_sag.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index b307b4f56..bbd380807 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -143,6 +143,8 @@ class SelfAttentionGuidance: sigma = args["sigma"] model_options = args["model_options"] x = args["input"] + if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding + return cfg_result # create the adversarially blurred image degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) From 10f2609fdd2231faf40346df11372601385b6cbb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 Jan 2024 03:15:27 -0500 Subject: [PATCH 17/54] Add InpaintModelConditioning node. This is an alternative to VAE Encode for inpaint that should work with lower denoise. This is a different take on #2501 --- comfy/model_base.py | 26 ++++++++++++++++--- nodes.py | 62 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index f59526204..52c87ede6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -100,11 +100,29 @@ class BaseModel(torch.nn.Module): if self.inpaint_model: concat_keys = ("mask", "masked_image") cond_concat = [] - denoise_mask = kwargs.get("denoise_mask", None) - latent_image = kwargs.get("latent_image", None) + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + concat_latent_image = kwargs.get("concat_latent_image", None) + if concat_latent_image is None: + concat_latent_image = kwargs.get("latent_image", None) + else: + concat_latent_image = self.process_latent_in(concat_latent_image) + noise = kwargs.get("noise", None) device = kwargs["device"] + if concat_latent_image.shape[1:] != noise.shape[1:]: + concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) + + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] + + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space @@ -117,9 +135,9 @@ class BaseModel(torch.nn.Module): for ck in concat_keys: if denoise_mask is not None: if ck == "mask": - cond_concat.append(denoise_mask[:,:1].to(device)) + cond_concat.append(denoise_mask.to(device)) elif ck == "masked_image": - cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space + cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space else: if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) diff --git a/nodes.py b/nodes.py index 96d9e51a7..6c7317b69 100644 --- a/nodes.py +++ b/nodes.py @@ -359,6 +359,62 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class InpaintModelConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + "mask": ("MASK", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, positive, negative, pixels, vae, mask): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") + + orig_pixels = pixels + pixels = orig_pixels.clone() + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] + + m = (1.0 - mask.round()).squeeze(1) + for i in range(3): + pixels[:,:,:,i] -= 0.5 + pixels[:,:,:,i] *= m + pixels[:,:,:,i] += 0.5 + concat_latent = vae.encode(pixels) + orig_latent = vae.encode(orig_pixels) + + out_latent = {} + + out_latent["samples"] = orig_latent + out_latent["noise_mask"] = mask + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + d["concat_mask"] = mask + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + + class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -1628,10 +1684,11 @@ class ImagePadForOutpaint: def expand_image(self, image, left, top, right, bottom, feathering): d1, d2, d3, d4 = image.size() - new_image = torch.zeros( + new_image = torch.ones( (d1, d2 + top + bottom, d3 + left + right, d4), dtype=torch.float32, - ) + ) * 0.5 + new_image[:, top:top + d2, left:left + d3, :] = image mask = torch.ones( @@ -1723,6 +1780,7 @@ NODE_CLASS_MAPPINGS = { "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "GLIGENLoader": GLIGENLoader, "GLIGENTextBoxApply": GLIGENTextBoxApply, + "InpaintModelConditioning": InpaintModelConditioning, "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, From 977eda19a6471fbff253dc92c3c2f1a4a67b1793 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 Jan 2024 03:29:58 -0500 Subject: [PATCH 18/54] Don't round noise mask. --- comfy/sample.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/sample.py b/comfy/sample.py index 4b0d15c49..5c8a7d130 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None): def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") - noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = noise_mask.to(device) From 4ab0392f707b26767eebb6e7f00e55a474732d63 Mon Sep 17 00:00:00 2001 From: TFWol <9045213+TFWol@users.noreply.github.com> Date: Thu, 11 Jan 2024 06:34:33 -0800 Subject: [PATCH 19/54] Resolved crashing nodes caused by `FileNotFoundError` during directory traversal - Implemented a `try-except` block in the `recursive_search` function to handle `FileNotFoundError` gracefully. - When encountering a file or directory path that cannot be accessed (causing `FileNotFoundError`), the code now logs a warning and skips processing for that specific path instead of crashing the node (CheckpointLoaderSimple was usually the first to break). This allows the rest of the directory traversal to proceed without interruption. --- folder_paths.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 641e5f0b2..4abed783b 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -138,15 +138,20 @@ def recursive_search(directory, excluded_dir_names=None): excluded_dir_names = [] result = [] - dirs = {directory: os.path.getmtime(directory)} + dirs = {} for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) result.append(relative_path) + for d in subdirs: path = os.path.join(dirpath, d) - dirs[path] = os.path.getmtime(path) + try: + dirs[path] = os.path.getmtime(path) + except FileNotFoundError: + print(f"Warning: Unable to access {path}. Skipping this path.") + continue return result, dirs def filter_files_extensions(files, extensions): From 1b3d65bd84c8026dea234643861491279886218c Mon Sep 17 00:00:00 2001 From: realazthat Date: Thu, 11 Jan 2024 08:38:18 -0500 Subject: [PATCH 20/54] Add error, status to /history endpoint --- execution.py | 70 +++++++++++++++++++++++++++++++++------------------- main.py | 7 +++++- 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/execution.py b/execution.py index 260a08970..554e8dc12 100644 --- a/execution.py +++ b/execution.py @@ -8,6 +8,7 @@ import heapq import traceback import gc import inspect +from typing import List, Literal, NamedTuple, Optional import torch import nodes @@ -275,8 +276,15 @@ class PromptExecutor: self.outputs = {} self.object_storage = {} self.outputs_ui = {} + self.status_notes = [] + self.success = True self.old_prompt = {} + def add_note(self, event, data, broadcast: bool): + self.status_notes.append((event, data)) + if self.server.client_id is not None or broadcast: + self.server.send_sync(event, data, self.server.client_id) + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] class_type = prompt[node_id]["class_type"] @@ -290,23 +298,22 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), } - self.server.send_sync("execution_interrupted", mes, self.server.client_id) + self.add_note("execution_interrupted", mes, broadcast=True) else: - if self.server.client_id is not None: - mes = { - "prompt_id": prompt_id, - "node_id": node_id, - "node_type": class_type, - "executed": list(executed), - - "exception_message": error["exception_message"], - "exception_type": error["exception_type"], - "traceback": error["traceback"], - "current_inputs": error["current_inputs"], - "current_outputs": error["current_outputs"], - } - self.server.send_sync("execution_error", mes, self.server.client_id) + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + "exception_message": error["exception_message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.add_note("execution_error", mes, broadcast=False) + # Next, remove the subsequent outputs since they will not be executed to_delete = [] for o in self.outputs: @@ -327,8 +334,7 @@ class PromptExecutor: else: self.server.client_id = None - if self.server.client_id is not None: - self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + self.add_note("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): #delete cached outputs if nodes don't exist for them @@ -361,8 +367,9 @@ class PromptExecutor: del d comfy.model_management.cleanup_models() - if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) + self.add_note("execution_cached", + { "nodes": list(current_outputs) , "prompt_id": prompt_id}, + broadcast=False) executed = set() output_node_id = None to_execute = [] @@ -378,8 +385,8 @@ class PromptExecutor: # This call shouldn't raise anything if there's an error deep in # the actual SD code, instead it will report the node where the # error was raised - success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if success is not True: + self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) + if self.success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) break @@ -731,14 +738,27 @@ class PromptQueue: self.server.queue_updated() return (item, i) - def task_done(self, item_id, outputs): + class ExecutionStatus(NamedTuple): + status_str: Literal['success', 'error'] + completed: bool + notes: List[str] + + def task_done(self, item_id, outputs, + status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: self.history.pop(next(iter(self.history))) - self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } - for o in outputs: - self.history[prompt[1]]["outputs"][o] = outputs[o] + + status_dict: dict|None = None + if status is not None: + status_dict = copy.deepcopy(status._asdict()) + + self.history[prompt[1]] = { + "prompt": prompt, + "outputs": copy.deepcopy(outputs), + 'status': status_dict, + } self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 45d5d41b3..a40ad2a4c 100644 --- a/main.py +++ b/main.py @@ -110,7 +110,12 @@ def prompt_worker(q, server): e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True - q.task_done(item_id, e.outputs_ui) + q.task_done(item_id, + e.outputs_ui, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + notes=e.status_notes)) if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) From d4edd9bfa8c227175c668bb83f2c80cef43fc53e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 Jan 2024 15:13:38 -0500 Subject: [PATCH 21/54] Fix hypertile issue with high depths. --- comfy_extras/nodes_hypertile.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index e7446b2e5..544ae1b22 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -37,24 +37,24 @@ class HyperTile: def patch(self, model, tile_size, swap_size, max_depth, scale_depth): model_channels = model.model.model_config.unet_config["model_channels"] - apply_to = set() - temp = model_channels - for x in range(max_depth + 1): - apply_to.add(temp) - temp *= 2 - latent_tile_size = max(32, tile_size) // 8 self.temp = None def hypertile_in(q, k, v, extra_options): - if q.shape[-1] in apply_to: + model_chans = q.shape[-2] + orig_shape = extra_options['original_shape'] + apply_to = [] + for i in range(max_depth + 1): + apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))) + + if model_chans in apply_to: shape = extra_options["original_shape"] aspect_ratio = shape[-1] / shape[-2] hw = q.size(1) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) - factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 + factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1 nh = random_divisor(h, latent_tile_size * factor, swap_size) nw = random_divisor(w, latent_tile_size * factor, swap_size) From 53c8a99e6c00b5e20425100f6680cd9ea2652218 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 Jan 2024 17:21:40 -0500 Subject: [PATCH 22/54] Make server storage the default. Remove --server-storage argument. --- app/user_manager.py | 2 +- comfy/cli_args.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/app/user_manager.py b/app/user_manager.py index 75485ee42..209094af1 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -94,7 +94,7 @@ class UserManager(): else: user_dir = self.get_request_user_filepath(request, None, create_dir=False) return web.json_response({ - "storage": "server" if args.server_storage else "browser", + "storage": "server", "migrated": os.path.exists(user_dir) }) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index c02bbf2ba..b4bbfbfab 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -112,8 +112,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") -parser.add_argument("--server-storage", action="store_true", help="Saves settings and other user configuration on the server instead of in browser storage.") -parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage. If enabled, server-storage will be unconditionally enabled.") +parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") if comfy.options.args_parsing: args = parser.parse_args() @@ -125,6 +124,3 @@ if args.windows_standalone_build: if args.disable_auto_launch: args.auto_launch = False - -if args.multi_user: - args.server_storage = True \ No newline at end of file From bcc0bde2afa7028939af737b46533b465baeb22d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 12 Jan 2024 17:21:22 -0500 Subject: [PATCH 23/54] Clear status notes on execution start. --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 554e8dc12..752e1d5aa 100644 --- a/execution.py +++ b/execution.py @@ -334,6 +334,7 @@ class PromptExecutor: else: self.server.client_id = None + self.status_notes = [] self.add_note("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): From 56d9496b18baa5946834d1982908df0091e1c925 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 12 Jan 2024 18:17:06 -0500 Subject: [PATCH 24/54] Rename status notes to status messages. I think message describes them better. --- execution.py | 18 +++++++++--------- main.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/execution.py b/execution.py index 752e1d5aa..e91e9a410 100644 --- a/execution.py +++ b/execution.py @@ -276,12 +276,12 @@ class PromptExecutor: self.outputs = {} self.object_storage = {} self.outputs_ui = {} - self.status_notes = [] + self.status_messages = [] self.success = True self.old_prompt = {} - def add_note(self, event, data, broadcast: bool): - self.status_notes.append((event, data)) + def add_message(self, event, data, broadcast: bool): + self.status_messages.append((event, data)) if self.server.client_id is not None or broadcast: self.server.send_sync(event, data, self.server.client_id) @@ -298,7 +298,7 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), } - self.add_note("execution_interrupted", mes, broadcast=True) + self.add_message("execution_interrupted", mes, broadcast=True) else: mes = { "prompt_id": prompt_id, @@ -312,7 +312,7 @@ class PromptExecutor: "current_inputs": error["current_inputs"], "current_outputs": error["current_outputs"], } - self.add_note("execution_error", mes, broadcast=False) + self.add_message("execution_error", mes, broadcast=False) # Next, remove the subsequent outputs since they will not be executed to_delete = [] @@ -334,8 +334,8 @@ class PromptExecutor: else: self.server.client_id = None - self.status_notes = [] - self.add_note("execution_start", { "prompt_id": prompt_id}, broadcast=False) + self.status_messages = [] + self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): #delete cached outputs if nodes don't exist for them @@ -368,7 +368,7 @@ class PromptExecutor: del d comfy.model_management.cleanup_models() - self.add_note("execution_cached", + self.add_message("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False) executed = set() @@ -742,7 +742,7 @@ class PromptQueue: class ExecutionStatus(NamedTuple): status_str: Literal['success', 'error'] completed: bool - notes: List[str] + messages: List[str] def task_done(self, item_id, outputs, status: Optional['PromptQueue.ExecutionStatus']): diff --git a/main.py b/main.py index a40ad2a4c..69d9bce6c 100644 --- a/main.py +++ b/main.py @@ -115,7 +115,7 @@ def prompt_worker(q, server): status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, - notes=e.status_notes)) + messages=e.status_messages)) if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) From df49a727ff0c786adf316b17632124076e847e16 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 13 Jan 2024 17:00:30 +0000 Subject: [PATCH 25/54] Fix modifiers triggering key down checks --- web/extensions/core/undoRedo.js | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js index 3cb137520..ff976c74c 100644 --- a/web/extensions/core/undoRedo.js +++ b/web/extensions/core/undoRedo.js @@ -106,6 +106,7 @@ const bindInput = (activeEl) => { } }; +let keyIgnored = false; window.addEventListener( "keydown", (e) => { @@ -116,6 +117,9 @@ window.addEventListener( return; } + keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta"; + if (keyIgnored) return; + // Check if this is a ctrl+z ctrl+y if (await undoRedo(e)) return; @@ -127,6 +131,13 @@ window.addEventListener( true ); +window.addEventListener("keyup", (e) => { + if (keyIgnored) { + keyIgnored = false; + checkState(); + } +}); + // Handle clicking DOM elements (e.g. widgets) window.addEventListener("mouseup", () => { checkState(); From 32034217ae14bb0dac3108b1e4e8454f036a9b2b Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 13 Jan 2024 18:57:47 +0000 Subject: [PATCH 26/54] add setting to change control after generate to run before --- web/scripts/app.js | 11 +++++++- web/scripts/widgets.js | 64 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 72f9e8603..d131045d7 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1,5 +1,5 @@ import { ComfyLogging } from "./logging.js"; -import { ComfyWidgets } from "./widgets.js"; +import { ComfyWidgets, initWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; @@ -1420,6 +1420,7 @@ export class ComfyApp { await this.#invokeExtensionsAsync("init"); await this.registerNodes(); + initWidgets(this); // Load previous workflow let restored = false; @@ -1774,6 +1775,14 @@ export class ComfyApp { */ async graphToPrompt() { for (const outerNode of this.graph.computeExecutionOrder(false)) { + if (outerNode.widgets) { + for (const widget of outerNode.widgets) { + // Allow widgets to run callbacks before a prompt has been queued + // e.g. random seed before every gen + widget.beforeQueued?.(); + } + } + const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode]; for (const node of innerNodes) { if (node.isVirtualNode) { diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index e2e21164d..36429266a 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,6 +1,19 @@ import { api } from "./api.js" import "./domWidget.js"; +let controlValueRunBefore = false; +function updateControlWidgetLabel(widget) { + let replacement = "after"; + let find = "before"; + if (controlValueRunBefore) { + [find, replacement] = [replacement, find] + } + widget.label = (widget.label ?? widget.name).replace(find, replacement); +} + +const IS_CONTROL_WIDGET = Symbol(); +const HAS_EXECUTED = Symbol(); + function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; let { min, max, step, round} = inputData[1]; @@ -62,6 +75,8 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando serialize: false, // Don't include this in prompt. } ); + valueControl[IS_CONTROL_WIDGET] = true; + updateControlWidgetLabel(valueControl); widgets.push(valueControl); const isCombo = targetWidget.type === "combo"; @@ -76,10 +91,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando serialize: false, // Don't include this in prompt. } ); + updateControlWidgetLabel(comboFilter); + widgets.push(comboFilter); } - valueControl.afterQueued = () => { + const applyWidgetControl = () => { var v = valueControl.value; if (isCombo && v !== "fixed") { @@ -159,6 +176,23 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando targetWidget.callback(targetWidget.value); } }; + + valueControl.beforeQueued = () => { + if (controlValueRunBefore) { + // Don't run on first execution + if (valueControl[HAS_EXECUTED]) { + applyWidgetControl(); + } + } + valueControl[HAS_EXECUTED] = true; + }; + + valueControl.afterQueued = () => { + if (!controlValueRunBefore) { + applyWidgetControl(); + } + }; + return widgets; }; @@ -224,6 +258,34 @@ function isSlider(display, app) { return (display==="slider") ? "slider" : "number" } +export function initWidgets(app) { + app.ui.settings.addSetting({ + id: "Comfy.WidgetControlMode", + name: "Widget Value Control Mode", + type: "combo", + defaultValue: "after", + options: ["before", "after"], + tooltip: "Controls when widget values are updated (randomize/increment/decrement), either before the prompt is queued or after.", + onChange(value) { + controlValueRunBefore = value === "before"; + for (const n of app.graph._nodes) { + if (!n.widgets) continue; + for (const w of n.widgets) { + if (w[IS_CONTROL_WIDGET]) { + updateControlWidgetLabel(w); + if (w.linkedWidgets) { + for (const l of w.linkedWidgets) { + updateControlWidgetLabel(l); + } + } + } + } + } + app.graph.setDirtyCanvas(true); + }, + }); +} + export const ComfyWidgets = { "INT:seed": seedWidget, "INT:noise_seed": seedWidget, From 8e916735c01a238251eb1dc35e9eef59284f0ca8 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 13 Jan 2024 18:57:59 +0000 Subject: [PATCH 27/54] export function --- web/scripts/widgets.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 36429266a..0529b1d80 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -2,7 +2,7 @@ import { api } from "./api.js" import "./domWidget.js"; let controlValueRunBefore = false; -function updateControlWidgetLabel(widget) { +export function updateControlWidgetLabel(widget) { let replacement = "after"; let find = "before"; if (controlValueRunBefore) { From 18511dd581cf8da701cc60af07bb947fa4a79656 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 13 Jan 2024 20:43:20 +0000 Subject: [PATCH 28/54] Manage group nodes (#2455) * wip group manage * prototyping ui * tweaks * wip * wip * more wip * fixes add deletion * Fix tests * fixes * Remove test code * typo * fix crash when link is invalid --- web/extensions/core/groupNode.js | 161 ++++++--- web/extensions/core/groupNodeManage.css | 149 +++++++++ web/extensions/core/groupNodeManage.js | 422 ++++++++++++++++++++++++ web/jsconfig.json | 3 +- web/scripts/domWidget.js | 3 +- web/scripts/ui.js | 20 +- web/scripts/ui/draggableList.js | 287 ++++++++++++++++ web/style.css | 11 +- 8 files changed, 1007 insertions(+), 49 deletions(-) create mode 100644 web/extensions/core/groupNodeManage.css create mode 100644 web/extensions/core/groupNodeManage.js create mode 100644 web/scripts/ui/draggableList.js diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 4cf1f7621..7feb2becc 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -1,6 +1,7 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; import { mergeIfValid } from "./widgetInputs.js"; +import { ManageGroupDialog } from "./groupNodeManage.js"; const GROUP = Symbol(); @@ -61,11 +62,7 @@ class GroupNodeBuilder { ); return; case Workflow.InUse.Registered: - if ( - !confirm( - "An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?" - ) - ) { + if (!confirm("A group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?")) { return; } break; @@ -151,6 +148,8 @@ export class GroupNodeConfig { this.primitiveDefs = {}; this.widgetToPrimitive = {}; this.primitiveToWidget = {}; + this.nodeInputs = {}; + this.outputVisibility = []; } async registerType(source = "workflow") { @@ -158,6 +157,7 @@ export class GroupNodeConfig { output: [], output_name: [], output_is_list: [], + output_is_hidden: [], name: source + "/" + this.name, display_name: this.name, category: "group nodes" + ("/" + source), @@ -277,8 +277,7 @@ export class GroupNodeConfig { } if (input.widget) { const targetDef = globalDefs[node.type]; - const targetWidget = - targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; + const targetWidget = targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; const widget = [targetWidget[0], config]; const res = mergeIfValid( @@ -330,7 +329,8 @@ export class GroupNodeConfig { } getInputConfig(node, inputName, seenInputs, config, extra) { - let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; + const customConfig = this.nodeData.config?.[node.index]?.input?.[inputName]; + let name = customConfig?.name ?? node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; let key = name; let prefix = ""; // Special handling for primitive to include the title if it is set rather than just "value" @@ -356,7 +356,7 @@ export class GroupNodeConfig { config = [config[0], { ...config[1], ...extra }]; } - return { name, config }; + return { name, config, customConfig }; } processWidgetInputs(inputs, node, inputNames, seenInputs) { @@ -366,9 +366,7 @@ export class GroupNodeConfig { for (const inputName of inputNames) { let widgetType = app.getWidgetType(inputs[inputName], inputName); if (widgetType) { - const convertedIndex = node.inputs?.findIndex( - (inp) => inp.name === inputName && inp.widget?.name === inputName - ); + const convertedIndex = node.inputs?.findIndex((inp) => inp.name === inputName && inp.widget?.name === inputName); if (convertedIndex > -1) { // This widget has been converted to a widget // We need to store this in the correct position so link ids line up @@ -424,6 +422,7 @@ export class GroupNodeConfig { } processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) { + this.nodeInputs[node.index] = {}; for (let i = 0; i < slots.length; i++) { const inputName = slots[i]; if (linksTo[i]) { @@ -432,7 +431,11 @@ export class GroupNodeConfig { continue; } - const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + const { name, config, customConfig } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + + this.nodeInputs[node.index][inputName] = name; + if(customConfig?.visible === false) continue; + this.nodeDef.input.required[name] = config; inputMap[i] = this.inputCount++; } @@ -452,6 +455,7 @@ export class GroupNodeConfig { const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], { defaultInput: true, }); + this.nodeDef.input.required[name] = config; this.newToOldWidgetMap[name] = { node, inputName }; @@ -477,9 +481,7 @@ export class GroupNodeConfig { this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); // Converted inputs have to be processed after all other nodes as they'll be at the end of the list - this.#convertedToProcess.push(() => - this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs) - ); + this.#convertedToProcess.push(() => this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)); return inputMapping; } @@ -490,8 +492,12 @@ export class GroupNodeConfig { // Add outputs for (let outputId = 0; outputId < def.output.length; outputId++) { const linksFrom = this.linksFrom[node.index]; - if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) { - // This output is linked internally so we can skip it + // If this output is linked internally we flag it to hide + const hasLink = linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]; + const customConfig = this.nodeData.config?.[node.index]?.output?.[outputId]; + const visible = customConfig?.visible ?? !hasLink; + this.outputVisibility.push(visible); + if (!visible) { continue; } @@ -500,11 +506,15 @@ export class GroupNodeConfig { this.nodeDef.output.push(def.output[outputId]); this.nodeDef.output_is_list.push(def.output_is_list[outputId]); - let label = def.output_name?.[outputId] ?? def.output[outputId]; - const output = node.outputs.find((o) => o.name === label); - if (output?.label) { - label = output.label; + let label = customConfig?.name; + if (!label) { + label = def.output_name?.[outputId] ?? def.output[outputId]; + const output = node.outputs.find((o) => o.name === label); + if (output?.label) { + label = output.label; + } } + let name = label; if (name in seenOutputs) { const prefix = `${node.title ?? node.type} `; @@ -677,6 +687,25 @@ export class GroupNodeHandler { return this.innerNodes; }; + this.node.recreate = async () => { + const id = this.node.id; + const sz = this.node.size; + const nodes = this.node.convertToNodes(); + + const groupNode = LiteGraph.createNode(this.node.type); + groupNode.id = id; + + // Reuse the existing nodes for this instance + groupNode.setInnerNodes(nodes); + groupNode[GROUP].populateWidgets(); + app.graph.add(groupNode); + groupNode.size = [Math.max(groupNode.size[0], sz[0]), Math.max(groupNode.size[1], sz[1])]; + + // Remove all converted nodes and relink them + groupNode[GROUP].replaceNodes(nodes); + return groupNode; + }; + this.node.convertToNodes = () => { const addInnerNodes = () => { const backup = localStorage.getItem("litegrapheditor_clipboard"); @@ -769,6 +798,7 @@ export class GroupNodeHandler { const slot = node.inputs[groupSlotId]; if (slot.link == null) continue; const link = app.graph.links[slot.link]; + if (!link) continue; // connect this node output to the input of another node const originNode = app.graph.getNodeById(link.origin_id); originNode.connect(link.origin_slot, newNode, +innerInputId); @@ -806,12 +836,23 @@ export class GroupNodeHandler { let optionIndex = options.findIndex((o) => o.content === "Outputs"); if (optionIndex === -1) optionIndex = options.length; else optionIndex++; - options.splice(optionIndex, 0, null, { - content: "Convert to nodes", - callback: () => { - return this.convertToNodes(); + options.splice( + optionIndex, + 0, + null, + { + content: "Convert to nodes", + callback: () => { + return this.convertToNodes(); + }, }, - }); + { + content: "Manage Group Node", + callback: () => { + new ManageGroupDialog(app).show(this.type); + }, + } + ); }; // Draw custom collapse icon to identity this as a group @@ -865,6 +906,28 @@ export class GroupNodeHandler { return onExecutionStart?.apply(this, arguments); }; + const self = this; + const onNodeCreated = this.node.onNodeCreated; + this.node.onNodeCreated = function () { + const config = self.groupData.nodeData.config; + if (config) { + for (const n in config) { + const inputs = config[n]?.input; + for (const w in inputs) { + if (inputs[w].visible !== false) continue; + const widgetName = self.groupData.oldToNewWidgetMap[n][w]; + const widget = this.widgets.find((w) => w.name === widgetName); + if (widget) { + widget.type = "hidden"; + widget.computeSize = () => [0, -4]; + } + } + } + } + + return onNodeCreated?.apply(this, arguments); + }; + function handleEvent(type, getId, getEvent) { const handler = ({ detail }) => { const id = getId(detail); @@ -927,13 +990,15 @@ export class GroupNodeHandler { continue; } else if (innerNode.type === "Reroute") { const rerouteLinks = this.groupData.linksFrom[old.node.index]; - for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { - const node = this.innerNodes[targetNodeId]; - const input = node.inputs[targetSlot]; - if (input.widget) { - const widget = node.widgets?.find((w) => w.name === input.widget.name); - if (widget) { - widget.value = newValue; + if (rerouteLinks) { + for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { + const node = this.innerNodes[targetNodeId]; + const input = node.inputs[targetSlot]; + if (input.widget) { + const widget = node.widgets?.find((w) => w.name === input.widget.name); + if (widget) { + widget.value = newValue; + } } } } @@ -975,7 +1040,7 @@ export class GroupNodeHandler { const [, , targetNodeId, targetNodeSlot] = link; const targetNode = this.groupData.nodeData.nodes[targetNodeId]; const inputs = targetNode.inputs; - const targetWidget = inputs?.[targetNodeSlot].widget; + const targetWidget = inputs?.[targetNodeSlot]?.widget; if (!targetWidget) return; const offset = inputs.length - (targetNode.widgets_values?.length ?? 0); @@ -983,13 +1048,12 @@ export class GroupNodeHandler { if (v == null) return; const widgetName = Object.values(map)[0]; - const widget = this.node.widgets.find(w => w.name === widgetName); - if(widget) { + const widget = this.node.widgets.find((w) => w.name === widgetName); + if (widget) { widget.value = v; } } - populateWidgets() { if (!this.node.widgets) return; @@ -1080,7 +1144,7 @@ export class GroupNodeHandler { } static getGroupData(node) { - return node.constructor?.nodeData?.[GROUP]; + return (node.nodeData ?? node.constructor?.nodeData)?.[GROUP]; } static isGroupNode(node) { @@ -1112,7 +1176,7 @@ export class GroupNodeHandler { } function addConvertToGroupOptions() { - function addOption(options, index) { + function addConvertOption(options, index) { const selected = Object.values(app.canvas.selected_nodes ?? {}); const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n)); options.splice(index + 1, null, { @@ -1124,12 +1188,25 @@ function addConvertToGroupOptions() { }); } + function addManageOption(options, index) { + const groups = app.graph.extra?.groupNodes; + const disabled = !groups || !Object.keys(groups).length; + options.splice(index + 1, null, { + content: `Manage Group Nodes`, + disabled, + callback: () => { + new ManageGroupDialog(app).show(); + }, + }); + } + // Add to canvas const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; LGraphCanvas.prototype.getCanvasMenuOptions = function () { const options = getCanvasMenuOptions.apply(this, arguments); const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length; - addOption(options, index); + addConvertOption(options, index); + addManageOption(options, index + 1); return options; }; @@ -1139,7 +1216,7 @@ function addConvertToGroupOptions() { const options = getNodeMenuOptions.apply(this, arguments); if (!GroupNodeHandler.isGroupNode(node)) { const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1; - addOption(options, index); + addConvertOption(options, index); } return options; }; diff --git a/web/extensions/core/groupNodeManage.css b/web/extensions/core/groupNodeManage.css new file mode 100644 index 000000000..5ac89aee3 --- /dev/null +++ b/web/extensions/core/groupNodeManage.css @@ -0,0 +1,149 @@ +.comfy-group-manage { + background: var(--bg-color); + color: var(--fg-color); + padding: 0; + font-family: Arial, Helvetica, sans-serif; + border-color: black; + margin: 20vh auto; + max-height: 60vh; +} +.comfy-group-manage-outer { + max-height: 60vh; + min-width: 500px; + display: flex; + flex-direction: column; +} +.comfy-group-manage-outer > header { + display: flex; + align-items: center; + gap: 10px; + justify-content: space-between; + background: var(--comfy-menu-bg); + padding: 15px 20px; +} +.comfy-group-manage-outer > header select { + background: var(--comfy-input-bg); + border: 1px solid var(--border-color); + color: var(--input-text); + padding: 5px 10px; + border-radius: 5px; +} +.comfy-group-manage h2 { + margin: 0; + font-weight: normal; +} +.comfy-group-manage main { + display: flex; + overflow: hidden; +} +.comfy-group-manage .drag-handle { + font-weight: bold; +} +.comfy-group-manage-list { + border-right: 1px solid var(--comfy-menu-bg); +} +.comfy-group-manage-list ul { + margin: 40px 0 0; + padding: 0; + list-style: none; +} +.comfy-group-manage-list-items { + max-height: 70vh; + overflow-y: scroll; + overflow-x: hidden; +} +.comfy-group-manage-list li { + display: flex; + padding: 10px 20px 10px 10px; + cursor: pointer; + align-items: center; + gap: 5px; +} +.comfy-group-manage-list div { + display: flex; + flex-direction: column; +} +.comfy-group-manage-list li:not(.selected):hover div { + text-decoration: underline; +} +.comfy-group-manage-list li.selected { + background: var(--border-color); +} +.comfy-group-manage-list li span { + opacity: 0.7; + font-size: smaller; +} +.comfy-group-manage-node { + flex: auto; + background: var(--border-color); + display: flex; + flex-direction: column; +} +.comfy-group-manage-node > div { + overflow: auto; +} +.comfy-group-manage-node header { + display: flex; + background: var(--bg-color); + height: 40px; +} +.comfy-group-manage-node header a { + text-align: center; + flex: auto; + border-right: 1px solid var(--comfy-menu-bg); + border-bottom: 1px solid var(--comfy-menu-bg); + padding: 10px; + cursor: pointer; + font-size: 15px; +} +.comfy-group-manage-node header a:last-child { + border-right: none; +} +.comfy-group-manage-node header a:not(.active):hover { + text-decoration: underline; +} +.comfy-group-manage-node header a.active { + background: var(--border-color); + border-bottom: none; +} +.comfy-group-manage-node-page { + display: none; + overflow: auto; +} +.comfy-group-manage-node-page.active { + display: block; +} +.comfy-group-manage-node-page div { + padding: 10px; + display: flex; + align-items: center; + gap: 10px; +} +.comfy-group-manage-node-page input { + border: none; + color: var(--input-text); + background: var(--comfy-input-bg); + padding: 5px 10px; +} +.comfy-group-manage-node-page input[type="text"] { + flex: auto; +} +.comfy-group-manage-node-page label { + display: flex; + gap: 5px; + align-items: center; +} +.comfy-group-manage footer { + border-top: 1px solid var(--comfy-menu-bg); + padding: 10px; + display: flex; + gap: 10px; +} +.comfy-group-manage footer button { + font-size: 14px; + padding: 5px 10px; + border-radius: 0; +} +.comfy-group-manage footer button:first-child { + margin-right: auto; +} diff --git a/web/extensions/core/groupNodeManage.js b/web/extensions/core/groupNodeManage.js new file mode 100644 index 000000000..1ab338386 --- /dev/null +++ b/web/extensions/core/groupNodeManage.js @@ -0,0 +1,422 @@ +import { $el, ComfyDialog } from "../../scripts/ui.js"; +import { DraggableList } from "../../scripts/ui/draggableList.js"; +import { addStylesheet } from "../../scripts/utils.js"; +import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; + +addStylesheet(import.meta.url); + +const ORDER = Symbol(); + +function merge(target, source) { + if (typeof target === "object" && typeof source === "object") { + for (const key in source) { + const sv = source[key]; + if (typeof sv === "object") { + let tv = target[key]; + if (!tv) tv = target[key] = {}; + merge(tv, source[key]); + } else { + target[key] = sv; + } + } + } + + return target; +} + +export class ManageGroupDialog extends ComfyDialog { + /** @type { Record<"Inputs" | "Outputs" | "Widgets", {tab: HTMLAnchorElement, page: HTMLElement}> } */ + tabs = {}; + /** @type { number | null | undefined } */ + selectedNodeIndex; + /** @type { keyof ManageGroupDialog["tabs"] } */ + selectedTab = "Inputs"; + /** @type { string | undefined } */ + selectedGroup; + + /** @type { Record>> } */ + modifications = {}; + + get selectedNodeInnerIndex() { + return +this.nodeItems[this.selectedNodeIndex].dataset.nodeindex; + } + + constructor(app) { + super(); + this.app = app; + this.element = $el("dialog.comfy-group-manage", { + parent: document.body, + }); + } + + changeTab(tab) { + this.tabs[this.selectedTab].tab.classList.remove("active"); + this.tabs[this.selectedTab].page.classList.remove("active"); + this.tabs[tab].tab.classList.add("active"); + this.tabs[tab].page.classList.add("active"); + this.selectedTab = tab; + } + + changeNode(index, force) { + if (!force && this.selectedNodeIndex === index) return; + + if (this.selectedNodeIndex != null) { + this.nodeItems[this.selectedNodeIndex].classList.remove("selected"); + } + this.nodeItems[index].classList.add("selected"); + this.selectedNodeIndex = index; + + if (!this.buildInputsPage() && this.selectedTab === "Inputs") { + this.changeTab("Widgets"); + } + if (!this.buildWidgetsPage() && this.selectedTab === "Widgets") { + this.changeTab("Outputs"); + } + if (!this.buildOutputsPage() && this.selectedTab === "Outputs") { + this.changeTab("Inputs"); + } + + this.changeTab(this.selectedTab); + } + + getGroupData() { + this.groupNodeType = LiteGraph.registered_node_types["workflow/" + this.selectedGroup]; + this.groupNodeDef = this.groupNodeType.nodeData; + this.groupData = GroupNodeHandler.getGroupData(this.groupNodeType); + } + + changeGroup(group, reset = true) { + this.selectedGroup = group; + this.getGroupData(); + + const nodes = this.groupData.nodeData.nodes; + this.nodeItems = nodes.map((n, i) => + $el( + "li.draggable-item", + { + dataset: { + nodeindex: n.index + "", + }, + onclick: () => { + this.changeNode(i); + }, + }, + [ + $el("span.drag-handle"), + $el( + "div", + { + textContent: n.title ?? n.type, + }, + n.title + ? $el("span", { + textContent: n.type, + }) + : [] + ), + ] + ) + ); + + this.innerNodesList.replaceChildren(...this.nodeItems); + + if (reset) { + this.selectedNodeIndex = null; + this.changeNode(0); + } else { + const items = this.draggable.getAllItems(); + let index = items.findIndex(item => item.classList.contains("selected")); + if(index === -1) index = this.selectedNodeIndex; + this.changeNode(index, true); + } + + const ordered = [...nodes]; + this.draggable?.dispose(); + this.draggable = new DraggableList(this.innerNodesList, "li"); + this.draggable.addEventListener("dragend", ({ detail: { oldPosition, newPosition } }) => { + if (oldPosition === newPosition) return; + ordered.splice(newPosition, 0, ordered.splice(oldPosition, 1)[0]); + for (let i = 0; i < ordered.length; i++) { + this.storeModification({ nodeIndex: ordered[i].index, section: ORDER, prop: "order", value: i }); + } + }); + } + + storeModification({ nodeIndex, section, prop, value }) { + const groupMod = (this.modifications[this.selectedGroup] ??= {}); + const nodesMod = (groupMod.nodes ??= {}); + const nodeMod = (nodesMod[nodeIndex ?? this.selectedNodeInnerIndex] ??= {}); + const typeMod = (nodeMod[section] ??= {}); + if (typeof value === "object") { + const objMod = (typeMod[prop] ??= {}); + Object.assign(objMod, value); + } else { + typeMod[prop] = value; + } + } + + getEditElement(section, prop, value, placeholder, checked, checkable = true) { + if (value === placeholder) value = ""; + + const mods = this.modifications[this.selectedGroup]?.nodes?.[this.selectedNodeInnerIndex]?.[section]?.[prop]; + if (mods) { + if (mods.name != null) { + value = mods.name; + } + if (mods.visible != null) { + checked = mods.visible; + } + } + + return $el("div", [ + $el("input", { + value, + placeholder, + type: "text", + onchange: (e) => { + this.storeModification({ section, prop, value: { name: e.target.value } }); + }, + }), + $el("label", { textContent: "Visible" }, [ + $el("input", { + type: "checkbox", + checked, + disabled: !checkable, + onchange: (e) => { + this.storeModification({ section, prop, value: { visible: !!e.target.checked } }); + }, + }), + ]), + ]); + } + + buildWidgetsPage() { + const widgets = this.groupData.oldToNewWidgetMap[this.selectedNodeInnerIndex]; + const items = Object.keys(widgets ?? {}); + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.input; + this.widgetsPage.replaceChildren( + ...items.map((oldName) => { + return this.getEditElement("input", oldName, widgets[oldName], oldName, config?.[oldName]?.visible !== false); + }) + ); + return !!items.length; + } + + buildInputsPage() { + const inputs = this.groupData.nodeInputs[this.selectedNodeInnerIndex]; + const items = Object.keys(inputs ?? {}); + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.input; + this.inputsPage.replaceChildren( + ...items + .map((oldName) => { + let value = inputs[oldName]; + if (!value) { + return; + } + + return this.getEditElement("input", oldName, value, oldName, config?.[oldName]?.visible !== false); + }) + .filter(Boolean) + ); + return !!items.length; + } + + buildOutputsPage() { + const nodes = this.groupData.nodeData.nodes; + const innerNodeDef = this.groupData.getNodeDef(nodes[this.selectedNodeInnerIndex]); + const outputs = innerNodeDef?.output ?? []; + const groupOutputs = this.groupData.oldToNewOutputMap[this.selectedNodeInnerIndex]; + + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.output; + const node = this.groupData.nodeData.nodes[this.selectedNodeInnerIndex]; + const checkable = node.type !== "PrimitiveNode"; + this.outputsPage.replaceChildren( + ...outputs + .map((type, slot) => { + const groupOutputIndex = groupOutputs?.[slot]; + const oldName = innerNodeDef.output_name?.[slot] ?? type; + let value = config?.[slot]?.name; + const visible = config?.[slot]?.visible || groupOutputIndex != null; + if (!value || value === oldName) { + value = ""; + } + return this.getEditElement("output", slot, value, oldName, visible, checkable); + }) + .filter(Boolean) + ); + return !!outputs.length; + } + + show(type) { + const groupNodes = Object.keys(app.graph.extra?.groupNodes ?? {}).sort((a, b) => a.localeCompare(b)); + + this.innerNodesList = $el("ul.comfy-group-manage-list-items"); + this.widgetsPage = $el("section.comfy-group-manage-node-page"); + this.inputsPage = $el("section.comfy-group-manage-node-page"); + this.outputsPage = $el("section.comfy-group-manage-node-page"); + const pages = $el("div", [this.widgetsPage, this.inputsPage, this.outputsPage]); + + this.tabs = [ + ["Inputs", this.inputsPage], + ["Widgets", this.widgetsPage], + ["Outputs", this.outputsPage], + ].reduce((p, [name, page]) => { + p[name] = { + tab: $el("a", { + onclick: () => { + this.changeTab(name); + }, + textContent: name, + }), + page, + }; + return p; + }, {}); + + const outer = $el("div.comfy-group-manage-outer", [ + $el("header", [ + $el("h2", "Group Nodes"), + $el( + "select", + { + onchange: (e) => { + this.changeGroup(e.target.value); + }, + }, + groupNodes.map((g) => + $el("option", { + textContent: g, + selected: "workflow/" + g === type, + value: g, + }) + ) + ), + ]), + $el("main", [ + $el("section.comfy-group-manage-list", this.innerNodesList), + $el("section.comfy-group-manage-node", [ + $el( + "header", + Object.values(this.tabs).map((t) => t.tab) + ), + pages, + ]), + ]), + $el("footer", [ + $el( + "button.comfy-btn", + { + onclick: (e) => { + const node = app.graph._nodes.find((n) => n.type === "workflow/" + this.selectedGroup); + if (node) { + alert("This group node is in use in the current workflow, please first remove these."); + return; + } + if (confirm(`Are you sure you want to remove the node: "${this.selectedGroup}"`)) { + delete app.graph.extra.groupNodes[this.selectedGroup]; + LiteGraph.unregisterNodeType("workflow/" + this.selectedGroup); + } + this.show(); + }, + }, + "Delete Group Node" + ), + $el( + "button.comfy-btn", + { + onclick: async () => { + let nodesByType; + let recreateNodes = []; + const types = {}; + for (const g in this.modifications) { + const type = app.graph.extra.groupNodes[g]; + let config = (type.config ??= {}); + + let nodeMods = this.modifications[g]?.nodes; + if (nodeMods) { + const keys = Object.keys(nodeMods); + if (nodeMods[keys[0]][ORDER]) { + // If any node is reordered, they will all need sequencing + const orderedNodes = []; + const orderedMods = {}; + const orderedConfig = {}; + + for (const n of keys) { + const order = nodeMods[n][ORDER].order; + orderedNodes[order] = type.nodes[+n]; + orderedMods[order] = nodeMods[n]; + orderedNodes[order].index = order; + } + + // Rewrite links + for (const l of type.links) { + if (l[0] != null) l[0] = type.nodes[l[0]].index; + if (l[2] != null) l[2] = type.nodes[l[2]].index; + } + + // Rewrite externals + if (type.external) { + for (const ext of type.external) { + ext[0] = type.nodes[ext[0]]; + } + } + + // Rewrite modifications + for (const id of keys) { + if (config[id]) { + orderedConfig[type.nodes[id].index] = config[id]; + } + delete config[id]; + } + + type.nodes = orderedNodes; + nodeMods = orderedMods; + type.config = config = orderedConfig; + } + + merge(config, nodeMods); + } + + types[g] = type; + + if (!nodesByType) { + nodesByType = app.graph._nodes.reduce((p, n) => { + p[n.type] ??= []; + p[n.type].push(n); + return p; + }, {}); + } + + const nodes = nodesByType["workflow/" + g]; + if (nodes) recreateNodes.push(...nodes); + } + + await GroupNodeConfig.registerFromWorkflow(types, {}); + + for (const node of recreateNodes) { + node.recreate(); + } + + this.modifications = {}; + this.app.graph.setDirtyCanvas(true, true); + this.changeGroup(this.selectedGroup, false); + }, + }, + "Save" + ), + $el("button.comfy-btn", { onclick: () => this.element.close() }, "Close"), + ]), + ]); + + this.element.replaceChildren(outer); + this.changeGroup(type ? groupNodes.find((g) => "workflow/" + g === type) : groupNodes[0]); + this.element.showModal(); + + this.element.addEventListener("close", () => { + this.draggable?.dispose(); + }); + } +} \ No newline at end of file diff --git a/web/jsconfig.json b/web/jsconfig.json index 57403d8cf..b65fa2746 100644 --- a/web/jsconfig.json +++ b/web/jsconfig.json @@ -3,7 +3,8 @@ "baseUrl": ".", "paths": { "/*": ["./*"] - } + }, + "lib": ["DOM", "ES2022"] }, "include": ["."] } diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js index eb0742d38..d5eeebdbd 100644 --- a/web/scripts/domWidget.js +++ b/web/scripts/domWidget.js @@ -239,7 +239,8 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { node.flags?.collapsed || (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || widget.computedHeight <= 0 || - widget.type === "converted-widget"; + widget.type === "converted-widget"|| + widget.type === "hidden"; element.hidden = hidden; element.style.display = hidden ? "none" : null; if (hidden) { diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 6887f7028..d07d69dc8 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -4,6 +4,19 @@ import { ComfySettingsDialog } from "./ui/settings.js"; export const ComfyDialog = _ComfyDialog; +/** + * + * @param { string } tag HTML Element Tag and optional classes e.g. div.class1.class2 + * @param { string | Element | Element[] | { + * parent?: Element, + * $?: (el: Element) => void, + * dataset?: DOMStringMap, + * style?: CSSStyleDeclaration, + * for?: string + * } | undefined } propsOrChildren + * @param { Element[] | undefined } [children] + * @returns + */ export function $el(tag, propsOrChildren, children) { const split = tag.split("."); const element = document.createElement(split.shift()); @@ -12,6 +25,11 @@ export function $el(tag, propsOrChildren, children) { } if (propsOrChildren) { + if (typeof propsOrChildren === "string") { + propsOrChildren = { textContent: propsOrChildren }; + } else if (propsOrChildren instanceof Element) { + propsOrChildren = [propsOrChildren]; + } if (Array.isArray(propsOrChildren)) { element.append(...propsOrChildren); } else { @@ -35,7 +53,7 @@ export function $el(tag, propsOrChildren, children) { Object.assign(element, propsOrChildren); if (children) { - element.append(...children); + element.append(...(children instanceof Array ? children : [children])); } if (parent) { diff --git a/web/scripts/ui/draggableList.js b/web/scripts/ui/draggableList.js new file mode 100644 index 000000000..d53594886 --- /dev/null +++ b/web/scripts/ui/draggableList.js @@ -0,0 +1,287 @@ +// @ts-check +/* + Original implementation: + https://github.com/TahaSh/drag-to-reorder + MIT License + + Copyright (c) 2023 Taha Shashtari + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ + +import { $el } from "../ui.js"; + +$el("style", { + parent: document.head, + textContent: ` + .draggable-item { + position: relative; + will-change: transform; + user-select: none; + } + .draggable-item.is-idle { + transition: 0.25s ease transform; + } + .draggable-item.is-draggable { + z-index: 10; + } + ` +}); + +export class DraggableList extends EventTarget { + listContainer; + draggableItem; + pointerStartX; + pointerStartY; + scrollYMax; + itemsGap = 0; + items = []; + itemSelector; + handleClass = "drag-handle"; + off = []; + offDrag = []; + + constructor(element, itemSelector) { + super(); + this.listContainer = element; + this.itemSelector = itemSelector; + + if (!this.listContainer) return; + + this.off.push(this.on(this.listContainer, "mousedown", this.dragStart)); + this.off.push(this.on(this.listContainer, "touchstart", this.dragStart)); + this.off.push(this.on(document, "mouseup", this.dragEnd)); + this.off.push(this.on(document, "touchend", this.dragEnd)); + } + + getAllItems() { + if (!this.items?.length) { + this.items = Array.from(this.listContainer.querySelectorAll(this.itemSelector)); + this.items.forEach((element) => { + element.classList.add("is-idle"); + }); + } + return this.items; + } + + getIdleItems() { + return this.getAllItems().filter((item) => item.classList.contains("is-idle")); + } + + isItemAbove(item) { + return item.hasAttribute("data-is-above"); + } + + isItemToggled(item) { + return item.hasAttribute("data-is-toggled"); + } + + on(source, event, listener, options) { + listener = listener.bind(this); + source.addEventListener(event, listener, options); + return () => source.removeEventListener(event, listener); + } + + dragStart(e) { + if (e.target.classList.contains(this.handleClass)) { + this.draggableItem = e.target.closest(this.itemSelector); + } + + if (!this.draggableItem) return; + + this.pointerStartX = e.clientX || e.touches[0].clientX; + this.pointerStartY = e.clientY || e.touches[0].clientY; + this.scrollYMax = this.listContainer.scrollHeight - this.listContainer.clientHeight; + + this.setItemsGap(); + this.initDraggableItem(); + this.initItemsState(); + + this.offDrag.push(this.on(document, "mousemove", this.drag)); + this.offDrag.push(this.on(document, "touchmove", this.drag, { passive: false })); + + this.dispatchEvent( + new CustomEvent("dragstart", { + detail: { element: this.draggableItem, position: this.getAllItems().indexOf(this.draggableItem) }, + }) + ); + } + + setItemsGap() { + if (this.getIdleItems().length <= 1) { + this.itemsGap = 0; + return; + } + + const item1 = this.getIdleItems()[0]; + const item2 = this.getIdleItems()[1]; + + const item1Rect = item1.getBoundingClientRect(); + const item2Rect = item2.getBoundingClientRect(); + + this.itemsGap = Math.abs(item1Rect.bottom - item2Rect.top); + } + + initItemsState() { + this.getIdleItems().forEach((item, i) => { + if (this.getAllItems().indexOf(this.draggableItem) > i) { + item.dataset.isAbove = ""; + } + }); + } + + initDraggableItem() { + this.draggableItem.classList.remove("is-idle"); + this.draggableItem.classList.add("is-draggable"); + } + + drag(e) { + if (!this.draggableItem) return; + + e.preventDefault(); + + const clientX = e.clientX || e.touches[0].clientX; + const clientY = e.clientY || e.touches[0].clientY; + + const listRect = this.listContainer.getBoundingClientRect(); + + if (clientY > listRect.bottom) { + if (this.listContainer.scrollTop < this.scrollYMax) { + this.listContainer.scrollBy(0, 10); + this.pointerStartY -= 10; + } + } else if (clientY < listRect.top && this.listContainer.scrollTop > 0) { + this.pointerStartY += 10; + this.listContainer.scrollBy(0, -10); + } + + const pointerOffsetX = clientX - this.pointerStartX; + const pointerOffsetY = clientY - this.pointerStartY; + + this.updateIdleItemsStateAndPosition(); + this.draggableItem.style.transform = `translate(${pointerOffsetX}px, ${pointerOffsetY}px)`; + } + + updateIdleItemsStateAndPosition() { + const draggableItemRect = this.draggableItem.getBoundingClientRect(); + const draggableItemY = draggableItemRect.top + draggableItemRect.height / 2; + + // Update state + this.getIdleItems().forEach((item) => { + const itemRect = item.getBoundingClientRect(); + const itemY = itemRect.top + itemRect.height / 2; + if (this.isItemAbove(item)) { + if (draggableItemY <= itemY) { + item.dataset.isToggled = ""; + } else { + delete item.dataset.isToggled; + } + } else { + if (draggableItemY >= itemY) { + item.dataset.isToggled = ""; + } else { + delete item.dataset.isToggled; + } + } + }); + + // Update position + this.getIdleItems().forEach((item) => { + if (this.isItemToggled(item)) { + const direction = this.isItemAbove(item) ? 1 : -1; + item.style.transform = `translateY(${direction * (draggableItemRect.height + this.itemsGap)}px)`; + } else { + item.style.transform = ""; + } + }); + } + + dragEnd() { + if (!this.draggableItem) return; + + this.applyNewItemsOrder(); + this.cleanup(); + } + + applyNewItemsOrder() { + const reorderedItems = []; + + let oldPosition = -1; + this.getAllItems().forEach((item, index) => { + if (item === this.draggableItem) { + oldPosition = index; + return; + } + if (!this.isItemToggled(item)) { + reorderedItems[index] = item; + return; + } + const newIndex = this.isItemAbove(item) ? index + 1 : index - 1; + reorderedItems[newIndex] = item; + }); + + for (let index = 0; index < this.getAllItems().length; index++) { + const item = reorderedItems[index]; + if (typeof item === "undefined") { + reorderedItems[index] = this.draggableItem; + } + } + + reorderedItems.forEach((item) => { + this.listContainer.appendChild(item); + }); + + this.items = reorderedItems; + + this.dispatchEvent( + new CustomEvent("dragend", { + detail: { element: this.draggableItem, oldPosition, newPosition: reorderedItems.indexOf(this.draggableItem) }, + }) + ); + } + + cleanup() { + this.itemsGap = 0; + this.items = []; + this.unsetDraggableItem(); + this.unsetItemState(); + + this.offDrag.forEach((f) => f()); + this.offDrag = []; + } + + unsetDraggableItem() { + this.draggableItem.style = null; + this.draggableItem.classList.remove("is-draggable"); + this.draggableItem.classList.add("is-idle"); + this.draggableItem = null; + } + + unsetItemState() { + this.getIdleItems().forEach((item, i) => { + delete item.dataset.isAbove; + delete item.dataset.isToggled; + item.style.transform = ""; + }); + } + + dispose() { + this.off.forEach((f) => f()); + } +} diff --git a/web/style.css b/web/style.css index c57fb7f73..5c1133495 100644 --- a/web/style.css +++ b/web/style.css @@ -145,6 +145,12 @@ body { } .comfy-menu span.drag-handle { + position: absolute; + top: 0; + left: 0; +} + +span.drag-handle { width: 10px; height: 20px; display: inline-block; @@ -160,12 +166,9 @@ body { letter-spacing: 2px; color: var(--drag-text); text-shadow: 1px 0 1px black; - position: absolute; - top: 0; - left: 0; } -.comfy-menu span.drag-handle::after { +span.drag-handle::after { content: '.. .. ..'; } From 9bddc9d94b478202c6ea3e8ff8007efefcc047c1 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 13 Jan 2024 21:02:51 +0000 Subject: [PATCH 29/54] Fix crash on group render --- web/extensions/core/groupNode.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 4cf1f7621..80b836b71 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -843,6 +843,7 @@ export class GroupNodeHandler { const r = onDrawForeground?.apply?.(this, arguments); if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) { const n = groupData.nodes[this.runningInternalNodeId]; + if(!n) return; const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`; ctx.save(); ctx.font = "12px sans-serif"; From 270daa02a8f94ad8600a1c3bf562c981fd02f416 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 14 Jan 2024 19:53:52 +0000 Subject: [PATCH 30/54] Adds copy image option if browser feature available (#2544) * Adds copy image option if browser feature available * refactor --- web/scripts/app.js | 118 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 96 insertions(+), 22 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index d131045d7..e6c010617 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -269,6 +269,71 @@ export class ComfyApp { * @param {*} node The node to add the menu handler */ #addNodeContextMenuHandler(node) { + function getCopyImageOption(img) { + if (typeof window.ClipboardItem === "undefined") return []; + return [ + { + content: "Copy Image", + callback: async () => { + const url = new URL(img.src); + url.searchParams.delete("preview"); + + const writeImage = async (blob) => { + await navigator.clipboard.write([ + new ClipboardItem({ + [blob.type]: blob, + }), + ]); + }; + + try { + const data = await fetch(url); + const blob = await data.blob(); + try { + await writeImage(blob); + } catch (error) { + // Chrome seems to only support PNG on write, convert and try again + if (blob.type !== "image/png") { + const canvas = $el("canvas", { + width: img.naturalWidth, + height: img.naturalHeight, + }); + const ctx = canvas.getContext("2d"); + let image; + if (typeof window.createImageBitmap === "undefined") { + image = new Image(); + const p = new Promise((resolve, reject) => { + image.onload = resolve; + image.onerror = reject; + }).finally(() => { + URL.revokeObjectURL(image.src); + }); + image.src = URL.createObjectURL(blob); + await p; + } else { + image = await createImageBitmap(blob); + } + try { + ctx.drawImage(image, 0, 0); + canvas.toBlob(writeImage, "image/png"); + } finally { + if (typeof image.close === "function") { + image.close(); + } + } + + return; + } + throw error; + } + } catch (error) { + alert("Error copying image: " + (error.message ?? error)); + } + }, + }, + ]; + } + node.prototype.getExtraMenuOptions = function (_, options) { if (this.imgs) { // If this node has images then we add an open in new tab item @@ -286,16 +351,17 @@ export class ComfyApp { content: "Open Image", callback: () => { let url = new URL(img.src); - url.searchParams.delete('preview'); - window.open(url, "_blank") + url.searchParams.delete("preview"); + window.open(url, "_blank"); }, }, + ...getCopyImageOption(img), { content: "Save Image", callback: () => { const a = document.createElement("a"); let url = new URL(img.src); - url.searchParams.delete('preview'); + url.searchParams.delete("preview"); a.href = url; a.setAttribute("download", new URLSearchParams(url.search).get("filename")); document.body.append(a); @@ -308,33 +374,41 @@ export class ComfyApp { } options.push({ - content: "Bypass", - callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); } - }); + content: "Bypass", + callback: (obj) => { + if (this.mode === 4) this.mode = 0; + else this.mode = 4; + this.graph.change(); + }, + }); // prevent conflict of clipspace content - if(!ComfyApp.clipspace_return_node) { + if (!ComfyApp.clipspace_return_node) { options.push({ - content: "Copy (Clipspace)", - callback: (obj) => { ComfyApp.copyToClipspace(this); } - }); + content: "Copy (Clipspace)", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + }, + }); - if(ComfyApp.clipspace != null) { + if (ComfyApp.clipspace != null) { options.push({ - content: "Paste (Clipspace)", - callback: () => { ComfyApp.pasteFromClipspace(this); } - }); + content: "Paste (Clipspace)", + callback: () => { + ComfyApp.pasteFromClipspace(this); + }, + }); } - if(ComfyApp.isImageNode(this)) { + if (ComfyApp.isImageNode(this)) { options.push({ - content: "Open in MaskEditor", - callback: (obj) => { - ComfyApp.copyToClipspace(this); - ComfyApp.clipspace_return_node = this; - ComfyApp.open_maskeditor(); - } - }); + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); + }, + }); } } }; From 2395ae740a4dd45534e8ba21031deaefdcade1b4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 Jan 2024 17:25:21 -0500 Subject: [PATCH 31/54] Make unclip more deterministic. Pass a seed argument note that this might make old unclip images different. --- comfy/ldm/modules/encoders/noise_aug_modules.py | 4 ++-- comfy/model_base.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index 66767b587..a5d866030 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -23,13 +23,13 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) return x - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) x = self.scale(x) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) z = self.unscale(z) noise_level = self.time_embed(noise_level) return z, noise_level diff --git a/comfy/model_base.py b/comfy/model_base.py index 52c87ede6..b2ea6590f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -210,7 +210,7 @@ class BaseModel(torch.nn.Module): return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) -def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): +def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): adm_inputs = [] weights = [] noise_aug = [] @@ -219,7 +219,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge weight = unclip_cond["strength"] noise_augment = unclip_cond["noise_augmentation"] noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed) adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight weights.append(weight) noise_aug.append(noise_augment) @@ -245,11 +245,11 @@ class SD21UNCLIP(BaseModel): if unclip_conditioning is None: return torch.zeros((1, self.adm_channels)) else: - return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05)) + return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: - return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280] + return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: return args["pooled_output"] From 1dab412c79eccdd055265475a0bd31ba97e2cf9c Mon Sep 17 00:00:00 2001 From: TFWol <9045213+TFWol@users.noreply.github.com> Date: Sun, 14 Jan 2024 15:06:33 -0800 Subject: [PATCH 32/54] Add error handling to initial fix to keep cache intact --- folder_paths.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/folder_paths.py b/folder_paths.py index 4abed783b..ef9b8ccfa 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -139,6 +139,13 @@ def recursive_search(directory, excluded_dir_names=None): result = [] dirs = {} + + # Attempt to add the initial directory to dirs with error handling + try: + dirs[directory] = os.path.getmtime(directory) + except FileNotFoundError: + print(f"Warning: Unable to access {directory}. Skipping this path.") + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: From f9e55d8463da692954d84f51ca354161396fe1b8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 Jan 2024 03:10:22 -0500 Subject: [PATCH 33/54] Only auto enable bf16 VAE on nvidia GPUs that actually support it. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index fefd3c8c9..e12146d11 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -175,7 +175,7 @@ try: if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if torch.cuda.is_bf16_supported(): + if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: VAE_DTYPE = torch.bfloat16 if is_intel_xpu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: From 23687da9a94f862b0e372db3c61412c82b95f398 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 15 Jan 2024 17:45:48 +0000 Subject: [PATCH 34/54] Fix logging not checking onChange --- web/scripts/logging.js | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/web/scripts/logging.js b/web/scripts/logging.js index c73462e1e..875dd970b 100644 --- a/web/scripts/logging.js +++ b/web/scripts/logging.js @@ -269,6 +269,9 @@ export class ComfyLogging { id: settingId, name: settingId, defaultValue: true, + onChange: (value) => { + this.enabled = value; + }, type: (name, setter, value) => { return $el("tr", [ $el("td", [ @@ -283,7 +286,7 @@ export class ComfyLogging { type: "checkbox", checked: value, onchange: (event) => { - setter((this.enabled = event.target.checked)); + setter(event.target.checked); }, }), $el("button", { From 93bbe3f4c0cac7165a523185c7f373893f3292cb Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 16 Jan 2024 13:27:40 +0000 Subject: [PATCH 35/54] Auto queue on change (#2542) * Add toggle to enable auto queue when graph is changed * type fix * better * better alignment * Change undoredo to not ignore inputs when autoqueue in change mode --- web/extensions/core/undoRedo.js | 31 +++++++++++++---- web/scripts/app.js | 1 + web/scripts/ui.js | 42 ++++++++++++++++++++--- web/scripts/ui/toggleSwitch.js | 60 +++++++++++++++++++++++++++++++++ web/style.css | 38 +++++++++++++++++++++ 5 files changed, 161 insertions(+), 11 deletions(-) create mode 100644 web/scripts/ui/toggleSwitch.js diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js index ff976c74c..900eed2a7 100644 --- a/web/extensions/core/undoRedo.js +++ b/web/extensions/core/undoRedo.js @@ -1,4 +1,5 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js" const MAX_HISTORY = 50; @@ -15,6 +16,7 @@ function checkState() { } activeState = clone(currentState); redo.length = 0; + api.dispatchEvent(new CustomEvent("graphChanged", { detail: activeState })); } } @@ -92,7 +94,7 @@ const undoRedo = async (e) => { }; const bindInput = (activeEl) => { - if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") { + if (activeEl && activeEl.tagName !== "CANVAS" && activeEl.tagName !== "BODY") { for (const evt of ["change", "input", "blur"]) { if (`on${evt}` in activeEl) { const listener = () => { @@ -111,12 +113,16 @@ window.addEventListener( "keydown", (e) => { requestAnimationFrame(async () => { - const activeEl = document.activeElement; - if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { - // Ignore events on inputs, they have their native history - return; + let activeEl; + // If we are auto queue in change mode then we do want to trigger on inputs + if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") { + activeEl = document.activeElement; + if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { + // Ignore events on inputs, they have their native history + return; + } } - + keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta"; if (keyIgnored) return; @@ -143,6 +149,11 @@ window.addEventListener("mouseup", () => { checkState(); }); +// Handle prompt queue event for dynamic widget changes +api.addEventListener("promptQueued", () => { + checkState(); +}); + // Handle litegraph clicks const processMouseUp = LGraphCanvas.prototype.processMouseUp; LGraphCanvas.prototype.processMouseUp = function (e) { @@ -156,3 +167,11 @@ LGraphCanvas.prototype.processMouseDown = function (e) { checkState(); return v; }; + +// Handle litegraph context menu for COMBO widgets +const close = LiteGraph.ContextMenu.prototype.close; +LiteGraph.ContextMenu.prototype.close = function(e) { + const v = close.apply(this, arguments); + checkState(); + return v; +} \ No newline at end of file diff --git a/web/scripts/app.js b/web/scripts/app.js index e6c010617..a2cd1e316 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2068,6 +2068,7 @@ export class ComfyApp { } finally { this.#processingQueue = false; } + api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } })); } /** diff --git a/web/scripts/ui.js b/web/scripts/ui.js index d07d69dc8..4437345a3 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -1,5 +1,6 @@ import { api } from "./api.js"; import { ComfyDialog as _ComfyDialog } from "./ui/dialog.js"; +import { toggleSwitch } from "./ui/toggleSwitch.js"; import { ComfySettingsDialog } from "./ui/settings.js"; export const ComfyDialog = _ComfyDialog; @@ -368,6 +369,31 @@ export class ComfyUI { }, }); + const autoQueueModeEl = toggleSwitch( + "autoQueueMode", + [ + { text: "instant", tooltip: "A new prompt will be queued as soon as the queue reaches 0" }, + { text: "change", tooltip: "A new prompt will be queued when the queue is at 0 and the graph is/has changed" }, + ], + { + onChange: (value) => { + this.autoQueueMode = value.item.value; + }, + } + ); + autoQueueModeEl.style.display = "none"; + + api.addEventListener("graphChanged", () => { + if (this.autoQueueMode === "change") { + if (this.lastQueueSize === 0) { + this.graphHasChanged = false; + app.queuePrompt(0, this.batchCount); + } else { + this.graphHasChanged = true; + } + } + }); + this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [ $el("div.drag-handle", { style: { @@ -394,6 +420,7 @@ export class ComfyUI { document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none"; this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1; document.getElementById("autoQueueCheckbox").checked = false; + this.autoQueueEnabled = false; }, }), ]), @@ -425,20 +452,22 @@ export class ComfyUI { }, }), ]), - $el("div",[ $el("label",{ for:"autoQueueCheckbox", innerHTML: "Auto Queue" - // textContent: "Auto Queue" }), $el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, title: "Automatically queue prompt when the queue size hits 0", - + onchange: (e) => { + this.autoQueueEnabled = e.target.checked; + autoQueueModeEl.style.display = this.autoQueueEnabled ? "" : "none"; + } }), + autoQueueModeEl ]) ]), $el("div.comfy-menu-btns", [ @@ -572,10 +601,13 @@ export class ComfyUI { if ( this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && - document.getElementById("autoQueueCheckbox").checked && - ! app.lastExecutionError + this.autoQueueEnabled && + (this.autoQueueMode === "instant" || this.graphHasChanged) && + !app.lastExecutionError ) { app.queuePrompt(0, this.batchCount); + status.exec_info.queue_remaining += this.batchCount; + this.graphHasChanged = false; } this.lastQueueSize = status.exec_info.queue_remaining; } diff --git a/web/scripts/ui/toggleSwitch.js b/web/scripts/ui/toggleSwitch.js new file mode 100644 index 000000000..59597ef90 --- /dev/null +++ b/web/scripts/ui/toggleSwitch.js @@ -0,0 +1,60 @@ +import { $el } from "../ui.js"; + +/** + * @typedef { { text: string, value?: string, tooltip?: string } } ToggleSwitchItem + */ +/** + * Creates a toggle switch element + * @param { string } name + * @param { Array void } [opts.onChange] + */ +export function toggleSwitch(name, items, { onChange } = {}) { + let selectedIndex; + let elements; + + function updateSelected(index) { + if (selectedIndex != null) { + elements[selectedIndex].classList.remove("comfy-toggle-selected"); + } + onChange?.({ item: items[index], prev: selectedIndex == null ? undefined : items[selectedIndex] }); + selectedIndex = index; + elements[selectedIndex].classList.add("comfy-toggle-selected"); + } + + elements = items.map((item, i) => { + if (typeof item === "string") item = { text: item }; + if (!item.value) item.value = item.text; + + const toggle = $el( + "label", + { + textContent: item.text, + title: item.tooltip ?? "", + }, + $el("input", { + name, + type: "radio", + value: item.value ?? item.text, + checked: item.selected, + onchange: () => { + updateSelected(i); + }, + }) + ); + if (item.selected) { + updateSelected(i); + } + return toggle; + }); + + const container = $el("div.comfy-toggle-switch", elements); + + if (selectedIndex == null) { + elements[0].children[0].checked = true; + updateSelected(0); + } + + return container; +} diff --git a/web/style.css b/web/style.css index 5c1133495..44ee60198 100644 --- a/web/style.css +++ b/web/style.css @@ -121,6 +121,7 @@ body { width: 100%; } +.comfy-toggle-switch, .comfy-btn, .comfy-menu > button, .comfy-menu-btns button, @@ -434,6 +435,43 @@ dialog::backdrop { margin-left: 5px; } +.comfy-toggle-switch { + border-width: 2px; + display: flex; + background-color: var(--comfy-input-bg); + margin: 2px 0; + white-space: nowrap; +} + +.comfy-toggle-switch label { + padding: 2px 0px 3px 6px; + flex: auto; + border-radius: 8px; + align-items: center; + display: flex; + justify-content: center; +} + +.comfy-toggle-switch label:first-child { + border-top-left-radius: 8px; + border-bottom-left-radius: 8px; +} +.comfy-toggle-switch label:last-child { + border-top-right-radius: 8px; + border-bottom-right-radius: 8px; +} + +.comfy-toggle-switch .comfy-toggle-selected { + background-color: var(--comfy-menu-bg); +} + +#extraOptions { + padding: 4px; + background-color: var(--bg-color); + margin-bottom: 4px; + border-radius: 4px; +} + /* Search box */ .litegraph.litesearchbox { From ee2c5fa72d4fc1714576fac7ba64aa5d607303d0 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 16 Jan 2024 13:58:54 +0000 Subject: [PATCH 36/54] Fix renaming upload widget (#2554) * Fix renaming upload widget * Allow custom name --- web/extensions/core/groupNode.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 18b0e6aa6..335bddbb7 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -349,7 +349,7 @@ export class GroupNodeConfig { } if (config[0] === "IMAGEUPLOAD") { if (!extra) extra = {}; - extra.widget = `${prefix}${config[1]?.widget ?? "image"}`; + extra.widget = this.oldToNewWidgetMap[node.index]?.[config[1]?.widget ?? "image"] ?? "image"; } if (extra) { From fad02dc2df0b9c4a0ab84dd1c559197cba752449 Mon Sep 17 00:00:00 2001 From: realazthat Date: Wed, 17 Jan 2024 17:16:34 -0500 Subject: [PATCH 37/54] Don't use PEP 604 type hints, to stay compatible with Python<3.10. --- execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/execution.py b/execution.py index e91e9a410..bc5cfe55c 100644 --- a/execution.py +++ b/execution.py @@ -751,7 +751,7 @@ class PromptQueue: if len(self.history) > MAXIMUM_HISTORY_SIZE: self.history.pop(next(iter(self.history))) - status_dict: dict|None = None + status_dict: Optional[dict] = None if status is not None: status_dict = copy.deepcopy(status._asdict()) From d76a04b6ea61306349861a7c4657567507385947 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 Jan 2024 19:37:19 -0500 Subject: [PATCH 38/54] Add unfinished ImageOnlyCheckpointSave node to save a SVD checkpoint. This node is unfinished, SVD checkpoints saved with this node will work with ComfyUI but not with anything else. --- comfy/clip_vision.py | 9 +++- comfy/model_base.py | 21 +++++--- comfy/sd.py | 13 +++-- comfy/supported_models_base.py | 6 +++ comfy_extras/nodes_model_merging.py | 83 +++++++++++++++-------------- comfy_extras/nodes_video_model.py | 17 ++++++ 6 files changed, 99 insertions(+), 50 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 4564fcfb2..200e1c6ef 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,4 +1,4 @@ -from .utils import load_torch_file, transformers_convert, common_upscale +from .utils import load_torch_file, transformers_convert, common_upscale, state_dict_prefix_replace import os import torch import contextlib @@ -41,9 +41,13 @@ class ClipVisionModel(): self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) + def get_sd(self): + return self.model.state_dict() + def encode_image(self, image): comfy.model_management.load_model_gpu(self.patcher) pixel_values = clip_preprocess(image.to(self.load_device)).float() @@ -76,6 +80,9 @@ def convert_to_transformers(sd, prefix): sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) sd = transformers_convert(sd, prefix, "vision_model.", 48) + else: + replace_prefix = {prefix: ""} + sd = state_dict_prefix_replace(sd, replace_prefix) return sd def load_clipvision_from_sd(sd, prefix="", convert_keys=False): diff --git a/comfy/model_base.py b/comfy/model_base.py index b2ea6590f..847687be4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -179,19 +179,28 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict, vae_state_dict): - clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + extra_sds = [] + if clip_state_dict is not None: + extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) + if vae_state_dict is not None: + extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) + if clip_vision_state_dict is not None: + extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) + unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) + if self.get_dtype() == torch.float16: - clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) - vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16) + extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds) if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) - return {**unet_state_dict, **vae_state_dict, **clip_state_dict} + for sd in extra_sds: + unet_state_dict.update(sd) + + return unet_state_dict def set_inpaint(self): self.inpaint_model = True diff --git a/comfy/sd.py b/comfy/sd.py index 1ff25bec6..f49e87b15 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -534,7 +534,14 @@ def load_unet(unet_path): raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model -def save_checkpoint(output_path, model, clip, vae, metadata=None): - model_management.load_models_gpu([model, clip.load_model()]) - sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) +def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): + clip_sd = None + load_models = [model] + if clip is not None: + load_models.append(clip.load_model()) + clip_sd = clip.get_sd() + + model_management.load_models_gpu(load_models) + clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None + sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) comfy.utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 49087d23e..5baf4bca6 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -65,6 +65,12 @@ class BASE: replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def process_clip_vision_state_dict_for_saving(self, state_dict): + replace_prefix = {} + if self.clip_vision_prefix is not None: + replace_prefix[""] = self.clip_vision_prefix + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def process_unet_state_dict_for_saving(self, state_dict): replace_prefix = {"": "model.diffusion_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index dad1dd637..d594cf490 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -119,6 +119,48 @@ class ModelMergeBlocks: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + + enable_modelspec = True + if isinstance(model.model, comfy.model_base.SDXL): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + elif isinstance(model.model, comfy.model_base.SDXLRefiner): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + else: + enable_modelspec = False + + if enable_modelspec: + metadata["modelspec.sai_model_spec"] = "1.0.0" + metadata["modelspec.implementation"] = "sgm" + metadata["modelspec.title"] = "{} {}".format(filename, counter) + + #TODO: + # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", + # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", + # "v2-inpainting" + + if model.model.model_type == comfy.model_base.ModelType.EPS: + metadata["modelspec.predict_key"] = "epsilon" + elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: + metadata["modelspec.predict_key"] = "v" + + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) + class CheckpointSave: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -137,46 +179,7 @@ class CheckpointSave: CATEGORY = "advanced/model_merging" def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - prompt_info = "" - if prompt is not None: - prompt_info = json.dumps(prompt) - - metadata = {} - - enable_modelspec = True - if isinstance(model.model, comfy.model_base.SDXL): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" - elif isinstance(model.model, comfy.model_base.SDXLRefiner): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" - else: - enable_modelspec = False - - if enable_modelspec: - metadata["modelspec.sai_model_spec"] = "1.0.0" - metadata["modelspec.implementation"] = "sgm" - metadata["modelspec.title"] = "{} {}".format(filename, counter) - - #TODO: - # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", - # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", - # "v2-inpainting" - - if model.model.model_type == comfy.model_base.ModelType.EPS: - metadata["modelspec.predict_key"] = "epsilon" - elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: - metadata["modelspec.predict_key"] = "v" - - if not args.disable_metadata: - metadata["prompt"] = prompt_info - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - output_checkpoint = f"{filename}_{counter:05}_.safetensors" - output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - - comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) + save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) return {} class CLIPSave: diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index 26a717a38..a52625652 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -3,6 +3,7 @@ import torch import comfy.utils import comfy.sd import folder_paths +import comfy_extras.nodes_model_merging class ImageOnlyCheckpointLoader: @@ -78,10 +79,26 @@ class VideoLinearCFGGuidance: m.set_model_sampler_cfg_function(linear_cfg) return (m, ) +class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): + CATEGORY = "_for_testing" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "clip_vision": ("CLIP_VISION",), + "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + + def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None): + comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + return {} + NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, } NODE_DISPLAY_NAME_MAPPINGS = { From 9fff3c46b44c7dcf9769cb702c7fab9d5ba87e77 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 18 Jan 2024 15:57:35 -0500 Subject: [PATCH 39/54] Move some nodes to model_patches section. --- comfy_extras/nodes_freelunch.py | 4 ++-- comfy_extras/nodes_hypertile.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 7512b841d..7764aa0b0 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -34,7 +34,7 @@ class FreeU: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] @@ -73,7 +73,7 @@ class FreeU_V2: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index 544ae1b22..ae55d23dd 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -32,7 +32,7 @@ class HyperTile: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, tile_size, swap_size, max_depth, scale_depth): model_channels = model.model.model_config.unet_config["model_channels"] From 78a70fda877acab8e094d11697de51e979c54bc5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 Jan 2024 15:38:05 -0500 Subject: [PATCH 40/54] Remove useless import. --- comfy/ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index f6f85de60..f674b47f7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,5 +1,4 @@ import torch -from contextlib import contextmanager import comfy.model_management def cast_bias_weight(s, input): From 5823f18a79d46446e07f418b9ede1939e1499ecc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 Jan 2024 23:08:15 -0500 Subject: [PATCH 41/54] Fix for the extracting issue on windows. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 167214c05..093873921 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,8 @@ There is a portable standalone build for Windows that should work for running on Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints +If you have trouble extracting it, right click the file -> properties -> unblock + #### How do I share models between another UI and ComfyUI? See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. From 45bf88d8ef90f3ce77d6fcb8a86a3b0e45f4dd16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristjan=20P=C3=A4rt?= Date: Mon, 22 Jan 2024 04:34:39 +0200 Subject: [PATCH 42/54] Fix queue on change to respect auto queue checkbox (#2608) * Fix render on change not respecting auto queue checkbox Fix issue where autoQueueEnabled checkbox is ignored for changes if autoQueueMode is left on `change` * Make check more specific --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 4437345a3..d4835c6e4 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -384,7 +384,7 @@ export class ComfyUI { autoQueueModeEl.style.display = "none"; api.addEventListener("graphChanged", () => { - if (this.autoQueueMode === "change") { + if (this.autoQueueMode === "change" && this.autoQueueEnabled === true) { if (this.lastQueueSize === 0) { this.graphHasChanged = false; app.queuePrompt(0, this.batchCount); From 4871a36458e7cd4af1a7f46dd6738c406e831413 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 Jan 2024 21:51:22 -0500 Subject: [PATCH 43/54] Cleanup some unused imports. --- comfy/clip_vision.py | 3 +-- comfy/conds.py | 1 - comfy/controlnet.py | 1 - comfy/diffusers_load.py | 1 - comfy/gligen.py | 2 +- comfy/model_base.py | 1 - comfy/sd.py | 3 --- comfy/sd1_clip.py | 1 - 8 files changed, 2 insertions(+), 11 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 200e1c6ef..8c77ee7a9 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,7 +1,6 @@ -from .utils import load_torch_file, transformers_convert, common_upscale, state_dict_prefix_replace +from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os import torch -import contextlib import json import comfy.ops diff --git a/comfy/conds.py b/comfy/conds.py index 6cff25184..23fa48872 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,4 +1,3 @@ -import enum import torch import math import comfy.utils diff --git a/comfy/controlnet.py b/comfy/controlnet.py index df474201c..82170431e 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,7 +1,6 @@ import torch import math import os -import contextlib import comfy.utils import comfy.model_management import comfy.model_detection diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index c0b420e79..98b888a19 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -1,4 +1,3 @@ -import json import os import comfy.sd diff --git a/comfy/gligen.py b/comfy/gligen.py index 8d182839e..71892dfb1 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -1,5 +1,5 @@ import torch -from torch import nn, einsum +from torch import nn from .ldm.modules.attention import CrossAttention from inspect import isfunction diff --git a/comfy/model_base.py b/comfy/model_base.py index 847687be4..8a843a98c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -6,7 +6,6 @@ import comfy.model_management import comfy.conds import comfy.ops from enum import Enum -import contextlib from . import utils class ModelType(Enum): diff --git a/comfy/sd.py b/comfy/sd.py index f49e87b15..9ca9d1d12 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,9 +1,6 @@ import torch -import contextlib -import math from comfy import model_management -from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine import yaml diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 6ffef515e..65ea909fe 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -6,7 +6,6 @@ import torch import traceback import zipfile from . import model_management -import contextlib import comfy.clip_model import json From f2d432f9a754f2bacb85a70f7e2b817e6dba44f8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 Jan 2024 00:28:13 -0500 Subject: [PATCH 44/54] Fix potential turbo scheduler model patching issue. --- comfy_extras/nodes_custom_sampler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index f02cb5ef7..99f9ea7dc 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -105,9 +105,8 @@ class SDTurboScheduler: def get_sigmas(self, model, steps, denoise): start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] - inner_model = model.patch_model(patch_weights=False) - sigmas = inner_model.model_sampling.sigma(timesteps) - model.unpatch_model() + comfy.model_management.load_models_gpu([model]) + sigmas = model.model.model_sampling.sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, ) From 05cd00695a84cebd5603a31f665eb7301fba2beb Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 23 Jan 2024 17:47:01 +0900 Subject: [PATCH 45/54] typo fix - calculate_sigmas_scheduler (#2619) self.scheduler -> scheduler_name Co-authored-by: Lt.Dr.Data --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 89d8d4f28..66a3feb10 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -639,7 +639,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): elif scheduler_name == "sgm_uniform": sigmas = normal_scheduler(model, steps, sgm=True) else: - print("error invalid scheduler", self.scheduler) + print("error invalid scheduler", scheduler_name) return sigmas def sampler_object(name): From 3762e676a93fcfb3ebdf6baf845971af72d4daf8 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 23 Jan 2024 19:15:52 +0000 Subject: [PATCH 46/54] Support refresh on group node combos (#2625) * Support refresh on group node combos * fix check --- web/extensions/core/groupNode.js | 28 ++++++++++++++++++++++++++++ web/scripts/app.js | 2 ++ 2 files changed, 30 insertions(+) diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 335bddbb7..0f041fcd2 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -966,6 +966,26 @@ export class GroupNodeHandler { api.removeEventListener("executing", executing); api.removeEventListener("executed", executed); }; + + this.node.refreshComboInNode = (defs) => { + // Update combo widget options + for (const widgetName in this.groupData.newToOldWidgetMap) { + const widget = this.node.widgets.find((w) => w.name === widgetName); + if (widget?.type === "combo") { + const old = this.groupData.newToOldWidgetMap[widgetName]; + const def = defs[old.node.type]; + const input = def?.input?.required?.[old.inputName] ?? def?.input?.optional?.[old.inputName]; + if (!input) continue; + + widget.options.values = input[0]; + + if (old.inputName !== "image" && !widget.options.values.includes(widget.value)) { + widget.value = widget.options.values[0]; + widget.callback(widget.value); + } + } + } + }; } updateInnerWidgets() { @@ -1245,6 +1265,14 @@ const ext = { node[GROUP] = new GroupNodeHandler(node); } }, + async refreshComboInNodes(defs) { + // Re-register group nodes so new ones are created with the correct options + Object.assign(globalDefs, defs); + const nodes = app.graph.extra?.groupNodes; + if (nodes) { + await GroupNodeConfig.registerFromWorkflow(nodes, {}); + } + } }; app.registerExtension(ext); diff --git a/web/scripts/app.js b/web/scripts/app.js index a2cd1e316..6df393ba6 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2212,6 +2212,8 @@ export class ComfyApp { } } } + + await this.#invokeExtensionsAsync("refreshComboInNodes", defs); } /** From b9911dcb2f48273479871d018823e11d03642d92 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 Jan 2024 20:01:37 -0500 Subject: [PATCH 47/54] Sync litegraph with repo. https://github.com/comfyanonymous/litegraph.js/pull/4 --- web/lib/litegraph.core.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 434c4a83b..080e0ef47 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -11496,7 +11496,7 @@ LGraphNode.prototype.executeAction = function(action) } timeout_close = setTimeout(function() { dialog.close(); - }, 500); + }, typeof options.hide_on_mouse_leave === "number" ? options.hide_on_mouse_leave : 500); }); // if filtering, check focus changed to comboboxes and prevent closing if (options.do_type_filter){ From d1533d9c0f1dde192f738ef1b745b15f49f41e02 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 24 Jan 2024 09:49:57 -0500 Subject: [PATCH 48/54] Add experimental photomaker nodes. Put the model file in models/photomaker and use PhotoMakerLoader. Then use PhotoMakerEncode with the keyword "photomaker" to apply the image --- comfy_extras/nodes_photomaker.py | 187 +++++++++++++++++++ folder_paths.py | 2 + models/photomaker/put_photomaker_models_here | 0 nodes.py | 1 + 4 files changed, 190 insertions(+) create mode 100644 comfy_extras/nodes_photomaker.py create mode 100644 models/photomaker/put_photomaker_models_here diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py new file mode 100644 index 000000000..90130142b --- /dev/null +++ b/comfy_extras/nodes_photomaker.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +import folder_paths +import comfy.clip_model +import comfy.clip_vision +import comfy.ops + +# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 +VISION_CONFIG_DICT = { + "hidden_size": 1024, + "image_size": 224, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "hidden_act": "quick_gelu", +} + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = operations.LayerNorm(in_dim) + self.fc1 = operations.Linear(in_dim, hidden_dim) + self.fc2 = operations.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class FuseModule(nn.Module): + def __init__(self, embed_dim, operations): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations) + self.layer_norm = operations.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + ) -> torch.Tensor: + # id_embeds shape: [b, max_num_inputs, 1, 2048] + id_embeds = id_embeds.to(prompt_embeds.dtype) + num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case + batch_size, max_num_inputs = id_embeds.shape[:2] + # seq_length: 77 + seq_length = prompt_embeds.shape[1] + # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] + flat_id_embeds = id_embeds.view( + -1, id_embeds.shape[-2], id_embeds.shape[-1] + ) + # valid_id_mask [b*max_num_inputs] + valid_id_mask = ( + torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] + < num_inputs[:, None] + ) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) + class_tokens_mask = class_tokens_mask.view(-1) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) + # slice out the image token embeddings + image_token_embeds = prompt_embeds[class_tokens_mask] + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + return updated_prompt_embeds + +class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): + def __init__(self): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + dtype = comfy.model_management.text_encoder_dtype(self.load_device) + + super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast) + self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False) + self.fuse_module = FuseModule(2048, comfy.ops.manual_cast) + + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[2] + id_embeds = self.visual_projection(shared_id_embeds) + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + + return updated_prompt_embeds + + +class PhotoMakerLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + + RETURN_TYPES = ("PHOTOMAKER",) + FUNCTION = "load_photomaker_model" + + CATEGORY = "_for_testing/photomaker" + + def load_photomaker_model(self, photomaker_model_name): + photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) + photomaker_model = PhotoMakerIDEncoder() + data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) + if "id_encoder" in data: + data = data["id_encoder"] + photomaker_model.load_state_dict(data) + return (photomaker_model,) + + +class PhotoMakerEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker": ("PHOTOMAKER",), + "image": ("IMAGE",), + "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_photomaker" + + CATEGORY = "_for_testing/photomaker" + + def apply_photomaker(self, photomaker, image, clip, text): + special_token = "photomaker" + pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() + try: + index = text.split(" ").index(special_token) + 1 + except ValueError: + index = -1 + tokens = clip.tokenize(text, return_word_ids=True) + out_tokens = {} + for k in tokens: + out_tokens[k] = [] + for t in tokens[k]: + f = list(filter(lambda x: x[2] != index, t)) + while len(f) < len(t): + f.append(t[-1]) + out_tokens[k].append(f) + + cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True) + + if index > 0: + token_index = index - 1 + num_id_images = 1 + class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)] + out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device), + class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0)) + else: + out = cond + + return ([[out, {"pooled_output": pooled}]], ) + + +NODE_CLASS_MAPPINGS = { + "PhotoMakerLoader": PhotoMakerLoader, + "PhotoMakerEncode": PhotoMakerEncode, +} + diff --git a/folder_paths.py b/folder_paths.py index ef9b8ccfa..f1bf40f8c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -29,6 +29,8 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) +folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions) + folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") diff --git a/models/photomaker/put_photomaker_models_here b/models/photomaker/put_photomaker_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 6c7317b69..4ad35f79b 100644 --- a/nodes.py +++ b/nodes.py @@ -1943,6 +1943,7 @@ def init_custom_nodes(): "nodes_perpneg.py", "nodes_stable3d.py", "nodes_sdupscale.py", + "nodes_photomaker.py", ] for node_file in extras_files: From 89507f8adff4aff4507b6f35a67717badaecd4ac Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 Jan 2024 21:09:51 -0500 Subject: [PATCH 49/54] Remove some unused imports. --- comfy/ldm/modules/attention.py | 3 --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 3 --- comfy/samplers.py | 4 ---- 3 files changed, 10 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index fd8888d0e..9c9cb761d 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1,12 +1,9 @@ -from inspect import isfunction import math import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any -from functools import partial - from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index ea936e066..998afd977 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -1,12 +1,9 @@ from abc import abstractmethod -import math -import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from functools import partial from .util import ( checkpoint, diff --git a/comfy/samplers.py b/comfy/samplers.py index 66a3feb10..f4c3e268f 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,13 +1,9 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch -import enum import collections from comfy import model_management import math -from comfy import model_base -import comfy.utils -import comfy.conds def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) From 2d105066df97a283fa155e39e0cc34ebbe58f55f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 26 Jan 2024 21:31:13 -0500 Subject: [PATCH 50/54] Cleanups. --- execution.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/execution.py b/execution.py index bc5cfe55c..00908eadd 100644 --- a/execution.py +++ b/execution.py @@ -1,12 +1,9 @@ -import os import sys import copy -import json import logging import threading import heapq import traceback -import gc import inspect from typing import List, Literal, NamedTuple, Optional From fc196aac80fd4bf6c8a39d85d1e809902871cade Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 26 Jan 2024 23:13:02 -0500 Subject: [PATCH 51/54] Add a LatentBatchSeedBehavior node. This lets you set it so the latents can use the same seed for the sampling on every image in the batch. --- comfy_extras/nodes_latent.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 2eefc4c55..b7fd8cd68 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -122,10 +122,34 @@ class LatentBatch: samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) return (samples_out,) +class LatentBatchSeedBehavior: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "seed_behavior": (["random", "fixed"],),}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, seed_behavior): + samples_out = samples.copy() + latent = samples["samples"] + if seed_behavior == "random": + if 'batch_index' in samples_out: + samples_out.pop('batch_index') + elif seed_behavior == "fixed": + batch_number = samples_out.get("batch_index", [0])[0] + samples_out["batch_index"] = [batch_number] * latent.shape[0] + + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, "LatentBatch": LatentBatch, + "LatentBatchSeedBehavior": LatentBatchSeedBehavior, } From 7f4725f6b3f72dd8bdb60dae5dd2c3e943263bcf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 27 Jan 2024 02:51:27 -0500 Subject: [PATCH 52/54] Fix some issues with --gpu-only --- comfy_extras/nodes_post_processing.py | 1 + comfy_extras/nodes_stable3d.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 71660f8a5..cb5c7d228 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -33,6 +33,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index e02a9875a..4375d8f96 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -46,7 +46,7 @@ class StableZero123_Conditioning: encode_pixels = pixels[:,:,:,:3] t = vae.encode(encode_pixels) cam_embeds = camera_embeddings(elevation, azimuth) - cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1) + cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1) positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] From 079dbf919874e6fce170d316e409366bd409cfb9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 Jan 2024 19:36:32 -0500 Subject: [PATCH 53/54] Remove useless code. --- server.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/server.py b/server.py index 1135fff88..dca06f6fc 100644 --- a/server.py +++ b/server.py @@ -632,8 +632,6 @@ class PromptServer(): site = web.TCPSite(runner, address, port) await site.start() - if address == '': - address = '0.0.0.0' if verbose: print("Starting server\n") print("To see the GUI go to: http://{}:{}".format(address, port)) From 9321198da6c191fab0ab3ea39e626c9e3d9053d3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 Jan 2024 00:24:53 -0500 Subject: [PATCH 54/54] Add node to set only the conditioning area strength. --- nodes.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nodes.py b/nodes.py index 4ad35f79b..fe38be9df 100644 --- a/nodes.py +++ b/nodes.py @@ -184,6 +184,26 @@ class ConditioningSetAreaPercentage: c.append(n) return (c, ) +class ConditioningSetAreaStrength: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, strength): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['strength'] = strength + c.append(n) + return (c, ) + + class ConditioningSetMask: @classmethod def INPUT_TYPES(s): @@ -1754,6 +1774,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, + "ConditioningSetAreaStrength": ConditioningSetAreaStrength, "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask,