diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py
index 1eafcd0aa..6a13b80c9 100644
--- a/comfy/ldm/modules/tomesd.py
+++ b/comfy/ldm/modules/tomesd.py
@@ -1,4 +1,4 @@
-
+#Taken from: https://github.com/dbolya/tomesd
 import torch
 from typing import Tuple, Callable
 
@@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
     return x
 
 
+def mps_gather_workaround(input, dim, index):
+    if input.shape[-1] == 1:
+        return torch.gather(
+            input.unsqueeze(-1),
+            dim - 1 if dim < 0 else dim,
+            index.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        return torch.gather(input, dim, index)
+
+
 def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                      w: int, h: int, sx: int, sy: int, r: int,
                                      no_rand: bool = False) -> Tuple[Callable, Callable]:
     """
     Partitions the tokens into src and dst and merges r tokens from src to dst.
     Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
-
     Args:
      - metric [B, N, C]: metric to use for similarity
      - w: image width in tokens
@@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
     if r <= 0:
         return do_nothing, do_nothing
 
+    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
+
     with torch.no_grad():
         hsy, wsx = h // sy, w // sx
 
         # For each sy by sx kernel, randomly assign one token to be dst and the rest src
-        idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
-
         if no_rand:
-            rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
+            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
         else:
-            rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
+            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
 
-        idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
-        idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
-        rand_idx = idx_buffer.argsort(dim=1)
+        # The image might not be divisible by sx and sy, so we need to work on a view of the top left of the idx buffer instead
+        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
+        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
+        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
 
-        num_dst = int((1 / (sx*sy)) * N)
+        # If the image is not divisible by sx or sy, we need to move the view into a new buffer
+        if (hsy * sy) < h or (wsx * sx) < w:
+            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
+            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
+        else:
+            idx_buffer = idx_buffer_view
+
+        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
+        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
+
+        # We're finished with these
+        del idx_buffer, idx_buffer_view
+
+        # rand_idx is currently dst|src, so split them
+        num_dst = hsy * wsx
         a_idx = rand_idx[:, num_dst:, :] # src
         b_idx = rand_idx[:, :num_dst, :] # dst
 
         def split(x):
             C = x.shape[-1]
-            src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
-            dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
+            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
+            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
             return src, dst
 
+        # Cosine similarity between A and B
         metric = metric / metric.norm(dim=-1, keepdim=True)
         a, b = split(metric)
         scores = a @ b.transpose(-1, -2)
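The partitioning hunk above relies on a compact trick: every token starts as 0 in the index buffer, one randomly chosen token per (sy, sx) region is set to -1, and a single argsort over the flattened buffer then yields all dst indices followed by all src indices. A minimal standalone sketch of that trick, with toy sizes and a fixed rand_idx invented purely for illustration:

import torch

# Toy grid: 4x4 tokens, 2x2 regions (sy = sx = 2), so hsy = wsx = 2 and 4 dst tokens.
h = w = 4
sy = sx = 2
hsy, wsx = h // sy, w // sx

# One chosen dst slot per region (fixed here so the example is reproducible).
rand_idx = torch.tensor([[[0], [3]], [[1], [2]]])  # [hsy, wsx, 1], values in [0, sy*sx)

# Mark the chosen slot in each region with -1; everything else stays 0.
idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx))

# Fold the per-region slots back into image layout, flatten, and argsort:
# the -1 entries (dst) sort in front of the 0 entries (src).
idx_buffer = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(h, w)
order = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

num_dst = hsy * wsx
b_idx = order[:, :num_dst, :]  # dst token indices
a_idx = order[:, num_dst:, :]  # src token indices
print(b_idx.flatten().tolist())  # four dst indices, one per 2x2 region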
@@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         # Can't reduce more than the # tokens in src
         r = min(a.shape[1], r)
 
+        # Greedily find the most similar src/dst token pairs
        node_max, node_idx = scores.max(dim=-1)
         edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
 
         unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
         src_idx = edge_idx[..., :r, :] # Merged Tokens
-        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
+        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
 
     def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         src, dst = split(x)
         n, t1, c = src.shape
-        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
-        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
+        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
+        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
         dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
 
         return torch.cat([unm, dst], dim=1)
 
@@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
         _, _, c = unm.shape
 
-        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
+        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
 
         # Combine back to the original shape
         out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
         out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
 
         return out
 
@@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
 def get_functions(x, ratio, original_shape):
     b, c, original_h, original_w = original_shape
     original_tokens = original_h * original_w
-    downsample = int(math.sqrt(original_tokens // x.shape[1]))
+    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
     stride_x = 2
     stride_y = 2
     max_downsample = 1
 
     if downsample <= max_downsample:
-        w = original_w // downsample
-        h = original_h // downsample
+        w = int(math.ceil(original_w / downsample))
+        h = int(math.ceil(original_h / downsample))
         r = int(x.shape[1] * ratio)
         no_rand = False
         m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
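The get_functions change matters whenever the latent resolution doesn't divide evenly: the strided convolutions that produce the downsampled layers round up, so floor division undercounts the token grid, and int(math.sqrt(...)) can even misdetect the downsample factor. A quick worked check with toy numbers (invented for illustration, not taken from the patch):

import math

# Hypothetical latent grid of 64x75 tokens (e.g. a 512x600 image divided by 8).
original_h, original_w = 64, 75
original_tokens = original_h * original_w  # 4800

# A layer downsampled 2x by a strided conv rounds up: 32 * 38 = 1216 tokens.
layer_tokens = math.ceil(original_h / 2) * math.ceil(original_w / 2)

# 4800 // 1216 == 3, and the old int(math.sqrt(3)) == 1 would misdetect the
# factor; the added math.ceil gives the correct downsample of 2.
downsample = int(math.ceil(math.sqrt(original_tokens // layer_tokens)))
assert downsample == 2

w_old, h_old = original_w // downsample, original_h // downsample  # 37, 32
w_new = int(math.ceil(original_w / downsample))                    # 38
h_new = int(math.ceil(original_h / downsample))                    # 32

assert w_old * h_old != layer_tokens  # 1184 != 1216: floor undercounts
assert w_new * h_new == layer_tokens  # ceil matches the layer's real grid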
diff --git a/comfy/samplers.py b/comfy/samplers.py
index ddec99007..59dbab53d 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -348,17 +348,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
         if 'adm' in x[1]:
             adm_inputs = []
             weights = []
+            noise_aug = []
             adm_in = x[1]["adm"]
             for adm_c in adm_in:
                 adm_cond = adm_c[0].image_embeds
                 weight = adm_c[1]
-                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device))
+                noise_augment = adm_c[2]
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
                 adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
                 weights.append(weight)
+                noise_aug.append(noise_augment)
                 adm_inputs.append(adm_out)
 
-            adm_out = torch.stack(adm_inputs).sum(0)
-            #TODO: Apply Noise to Embedding Mix
+            if len(noise_aug) > 1:
+                adm_out = torch.stack(adm_inputs).sum(0)
+                #TODO: add a way to control this
+                noise_augment = 0.05
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                print(noise_level)
+                c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1)
         else:
             adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
         x[1] = x[1].copy()
diff --git a/nodes.py b/nodes.py
index 21b8f5a57..a373bc2f7 100644
--- a/nodes.py
+++ b/nodes.py
@@ -445,17 +445,18 @@ class unCLIPConditioning:
         return {"required": {"conditioning": ("CONDITIONING", ),
                              "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                             "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_adm"
 
     CATEGORY = "_for_testing/unclip"
 
-    def apply_adm(self, conditioning, clip_vision_output, strength):
+    def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
         c = []
         for t in conditioning:
             o = t[1].copy()
-            x = (clip_vision_output, strength)
+            x = (clip_vision_output, strength, noise_augmentation)
             if "adm" in o:
                 o["adm"] = o["adm"][:] + [x]
             else:
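Both hunks above map the new 0-1 noise_augmentation widget linearly onto the augmentor's discrete noise-level range via round((max_noise_level - 1) * noise_augment). A small sanity sketch of that mapping; max_noise_level = 1000 is an assumption based on typical unCLIP noise augmentors, not a value taken from this patch:

# Hypothetical check of the widget-to-noise-level mapping used in encode_adm.
max_noise_level = 1000  # assumed; in practice read it from the actual noise_augmentor

def to_noise_level(noise_augment: float) -> int:
    # 0.0 -> 0 (no extra noise), 1.0 -> max_noise_level - 1 (fully noised)
    return round((max_noise_level - 1) * noise_augment)

assert to_noise_level(0.0) == 0
assert to_noise_level(0.05) == 50  # the hardcoded value used when mixing embeds
assert to_noise_level(1.0) == 999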
diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js
new file mode 100644
index 000000000..51e66f924
--- /dev/null
+++ b/web/extensions/core/contextMenuFilter.js
@@ -0,0 +1,137 @@
+import { app } from "/scripts/app.js";
+
+// Adds filtering to combo context menus
+
+const id = "Comfy.ContextMenuFilter";
+app.registerExtension({
+	name: id,
+	init() {
+		const ctxMenu = LiteGraph.ContextMenu;
+		LiteGraph.ContextMenu = function (values, options) {
+			const ctx = ctxMenu.call(this, values, options);
+
+			// If we are a dark menu (only used for combo boxes) then add a filter input
+			if (options?.className === "dark" && values?.length > 10) {
+				const filter = document.createElement("input");
+				Object.assign(filter.style, {
+					width: "calc(100% - 10px)",
+					border: "0",
+					boxSizing: "border-box",
+					background: "#333",
+					border: "1px solid #999",
+					margin: "0 0 5px 5px",
+					color: "#fff",
+				});
+				filter.placeholder = "Filter list";
+				this.root.prepend(filter);
+
+				let selectedIndex = 0;
+				let items = this.root.querySelectorAll(".litemenu-entry");
+				let itemCount = items.length;
+				let selectedItem;
+
+				// Apply highlighting to the selected item
+				function updateSelected() {
+					if (selectedItem) {
+						selectedItem.style.setProperty("background-color", "");
+						selectedItem.style.setProperty("color", "");
+					}
+					selectedItem = items[selectedIndex];
+					if (selectedItem) {
+						selectedItem.style.setProperty("background-color", "#ccc", "important");
+						selectedItem.style.setProperty("color", "#000", "important");
+					}
+				}
+
+				const positionList = () => {
+					const rect = this.root.getBoundingClientRect();
+
+					// If the top is off screen then shift the element with scaling applied
+					if (rect.top < 0) {
+						const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
+						const shift = (this.root.clientHeight * scale) / 2;
+						this.root.style.top = -shift + "px";
+					}
+				}
+
+				updateSelected();
+
+				// Arrow up/down to select items
+				filter.addEventListener("keydown", (e) => {
+					if (e.key === "ArrowUp") {
+						if (selectedIndex === 0) {
+							selectedIndex = itemCount - 1;
+						} else {
+							selectedIndex--;
+						}
+						updateSelected();
+						e.preventDefault();
+					} else if (e.key === "ArrowDown") {
+						if (selectedIndex === itemCount - 1) {
+							selectedIndex = 0;
+						} else {
+							selectedIndex++;
+						}
+						updateSelected();
+						e.preventDefault();
+					} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
+						selectedItem.click();
+					} else if (e.key === "Escape") {
+						this.close();
+					}
+				});
+
+				filter.addEventListener("input", () => {
+					// Hide all items that don't match our filter
+					const term = filter.value.toLocaleLowerCase();
+					items = this.root.querySelectorAll(".litemenu-entry");
+					// When filtering, recompute which items are visible for arrow up/down
+					// and try to maintain the current selection
+					let visibleItems = [];
+					for (const item of items) {
+						const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
+						if (visible) {
+							item.style.display = "block";
+							if (item === selectedItem) {
+								selectedIndex = visibleItems.length;
+							}
+							visibleItems.push(item);
+						} else {
+							item.style.display = "none";
+							if (item === selectedItem) {
+								selectedIndex = 0;
+							}
+						}
+					}
+					items = visibleItems;
+					updateSelected();
+
+					// If we have an event then we can try to position the list under the source
+					if (options.event) {
+						let top = options.event.clientY - 10;
+
+						const bodyRect = document.body.getBoundingClientRect();
+						const rootRect = this.root.getBoundingClientRect();
+						if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
+							top = Math.max(0, bodyRect.height - rootRect.height - 10);
+						}
+
+						this.root.style.top = top + "px";
+						positionList();
+					}
+				});
+
+				requestAnimationFrame(() => {
+					// Focus the filter box when opening
+					filter.focus();
+
+					positionList();
+				});
+			}
+
+			return ctx;
+		};
+
+		LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
+	},
+});
diff --git a/web/extensions/core/dynamicPrompts.js b/web/extensions/core/dynamicPrompts.js
index 8528201d3..7dae07f4d 100644
--- a/web/extensions/core/dynamicPrompts.js
+++ b/web/extensions/core/dynamicPrompts.js
@@ -30,7 +30,8 @@ app.registerExtension({
 			}
 
 			// Overwrite the value in the serialized workflow pnginfo
-			workflowNode.widgets_values[widgetIndex] = prompt;
+			if (workflowNode?.widgets_values)
+				workflowNode.widgets_values[widgetIndex] = prompt;
 
 			return prompt;
 		};
diff --git a/web/extensions/core/invertMenuScrolling.js b/web/extensions/core/invertMenuScrolling.js
index 34523d55c..f900fccf4 100644
--- a/web/extensions/core/invertMenuScrolling.js
+++ b/web/extensions/core/invertMenuScrolling.js
@@ -3,10 +3,10 @@ import { app } from "/scripts/app.js";
 // Inverts the scrolling of context menus
 
 const id = "Comfy.InvertMenuScrolling";
-const ctxMenu = LiteGraph.ContextMenu;
 app.registerExtension({
 	name: id,
 	init() {
+		const ctxMenu = LiteGraph.ContextMenu;
 		const replace = () => {
 			LiteGraph.ContextMenu = function (values, options) {
 				options = options || {};
diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js
index 1342cae92..aee5d147c 100644
--- a/web/extensions/core/rerouteNode.js
+++ b/web/extensions/core/rerouteNode.js
@@ -50,12 +50,12 @@ app.registerExtension({
 					}
 					else {
 						// Move the previous node
-						currentNode = node;	
+						currentNode = node;
 					}
 				} else {
 					// We've found the end
 					inputNode = currentNode;
-					inputType = node.outputs[link.origin_slot].type;
+					inputType = node.outputs[link.origin_slot]?.type ?? null;
 					break;
 				}
 			} else {
@@ -87,7 +87,7 @@ app.registerExtension({
 							updateNodes.push(node);
 						} else {
 							// We've found an output
-							const nodeOutType = node.inputs[link.target_slot].type;
+							const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
 							if (inputType && nodeOutType !== inputType) {
 								// The output doesn't match our input so disconnect it
 								node.disconnectInput(link.target_slot);