Merge branch 'comfyanonymous:master' into feature/blockweights

commit 82a584cc88
ltdrdata, 2023-04-04 09:06:32 +09:00, committed by GitHub
7 changed files with 207 additions and 31 deletions


@@ -1,4 +1,4 @@
 #Taken from: https://github.com/dbolya/tomesd
 import torch
 from typing import Tuple, Callable
@@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
     return x
 
+def mps_gather_workaround(input, dim, index):
+    if input.shape[-1] == 1:
+        return torch.gather(
+            input.unsqueeze(-1),
+            dim - 1 if dim < 0 else dim,
+            index.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        return torch.gather(input, dim, index)
+
 def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                      w: int, h: int, sx: int, sy: int, r: int,
                                      no_rand: bool = False) -> Tuple[Callable, Callable]:
     """
     Partitions the tokens into src and dst and merges r tokens from src to dst.
     Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
 
     Args:
      - metric [B, N, C]: metric to use for similarity
      - w: image width in tokens
@@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
     if r <= 0:
         return do_nothing, do_nothing
 
+    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
+
     with torch.no_grad():
         hsy, wsx = h // sy, w // sx
 
         # For each sy by sx kernel, randomly assign one token to be dst and the rest src
-        idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
-
         if no_rand:
-            rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
+            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
         else:
-            rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
+            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
 
-        idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
-        idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
-        rand_idx = idx_buffer.argsort(dim=1)
+        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
+        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
+        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
+        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
 
-        num_dst = int((1 / (sx*sy)) * N)
+        # Image is not divisible by sx or sy so we need to move it into a new buffer
+        if (hsy * sy) < h or (wsx * sx) < w:
+            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
+            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
+        else:
+            idx_buffer = idx_buffer_view
+
+        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
+        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
+
+        # We're finished with these
+        del idx_buffer, idx_buffer_view
+
+        # rand_idx is currently dst|src, so split them
+        num_dst = hsy * wsx
         a_idx = rand_idx[:, num_dst:, :] # src
         b_idx = rand_idx[:, :num_dst, :] # dst
 
         def split(x):
             C = x.shape[-1]
-            src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
-            dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
+            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
+            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
             return src, dst
 
+        # Cosine similarity between A and B
         metric = metric / metric.norm(dim=-1, keepdim=True)
         a, b = split(metric)
         scores = a @ b.transpose(-1, -2)
@@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         # Can't reduce more than the # tokens in src
         r = min(a.shape[1], r)
 
+        # Find the most similar greedily
         node_max, node_idx = scores.max(dim=-1)
         edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
 
         unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
         src_idx = edge_idx[..., :r, :]  # Merged Tokens
-        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
+        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
 
     def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         src, dst = split(x)
         n, t1, c = src.shape
-        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
-        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
+        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
+        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
         dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
 
         return torch.cat([unm, dst], dim=1)
@@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
         _, _, c = unm.shape
 
-        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
+        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
 
         # Combine back to the original shape
         out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
         out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
 
         return out
@@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
 def get_functions(x, ratio, original_shape):
     b, c, original_h, original_w = original_shape
     original_tokens = original_h * original_w
-    downsample = int(math.sqrt(original_tokens // x.shape[1]))
+    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
     stride_x = 2
     stride_y = 2
     max_downsample = 1
 
     if downsample <= max_downsample:
-        w = original_w // downsample
-        h = original_h // downsample
+        w = int(math.ceil(original_w / downsample))
+        h = int(math.ceil(original_h / downsample))
         r = int(x.shape[1] * ratio)
         no_rand = False
         m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
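
A minimal usage sketch of the merge/unmerge pair returned above (not part of the commit; it assumes the file's functions are in scope, and the shapes are illustrative):

import torch

B, h, w, C = 1, 16, 16, 64           # token grid is h x w, so N = 256
metric = torch.randn(B, h * w, C)    # [B, N, C] similarity metric

# Merge r=64 src tokens into their best-matching dst tokens, then undo it.
m, u = bipartite_soft_matching_random2d(metric, w, h, 2, 2, 64)
merged = m(metric)                   # [B, 192, C] after merging
restored = u(merged)                 # [B, 256, C], merged values copied back to src slots

On MPS devices every gather is now routed through mps_gather_workaround, which appears to dodge a torch.gather issue for tensors whose last dimension is 1 by gathering on an unsqueezed view.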


@@ -348,17 +348,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
         if 'adm' in x[1]:
             adm_inputs = []
             weights = []
+            noise_aug = []
             adm_in = x[1]["adm"]
             for adm_c in adm_in:
                 adm_cond = adm_c[0].image_embeds
                 weight = adm_c[1]
-                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device))
+                noise_augment = adm_c[2]
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
                 adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
                 weights.append(weight)
+                noise_aug.append(noise_augment)
                 adm_inputs.append(adm_out)
 
-            adm_out = torch.stack(adm_inputs).sum(0)
-            #TODO: Apply Noise to Embedding Mix
+            if len(noise_aug) > 1:
+                adm_out = torch.stack(adm_inputs).sum(0)
+                #TODO: add a way to control this
+                noise_augment = 0.05
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                print(noise_level)
+                c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1)
         else:
             adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
         x[1] = x[1].copy()
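
A hedged illustration of the new noise-level mapping (not from the commit; max_noise_level = 1000 is an assumed value for the unCLIP noise augmentor):

max_noise_level = 1000  # assumed; the real value comes from the loaded augmentor

def to_noise_level(noise_augment: float) -> int:
    # noise_augmentation in [0, 1] maps linearly onto the augmentor's timestep range
    return round((max_noise_level - 1) * noise_augment)

print(to_noise_level(0.0))   # 0   -> no augmentation noise
print(to_noise_level(0.05))  # 50  -> the hardcoded level used when mixing multiple embeds
print(to_noise_level(1.0))   # 999 -> maximum noise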


@@ -521,17 +521,18 @@ class unCLIPConditioning:
         return {"required": {"conditioning": ("CONDITIONING", ),
                              "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                             "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_adm"
 
     CATEGORY = "_for_testing/unclip"
 
-    def apply_adm(self, conditioning, clip_vision_output, strength):
+    def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
         c = []
         for t in conditioning:
             o = t[1].copy()
-            x = (clip_vision_output, strength)
+            x = (clip_vision_output, strength, noise_augmentation)
             if "adm" in o:
                 o["adm"] = o["adm"][:] + [x]
             else:
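
A sketch of the data flow this node change creates (not from the commit; the embedding size and stand-in object are illustrative): each conditioning entry's "adm" list now carries a 3-tuple instead of a 2-tuple, which encode_adm above unpacks as adm_c[0], adm_c[1], adm_c[2].

from types import SimpleNamespace
import torch

clip_vision_output = SimpleNamespace(image_embeds=torch.randn(1, 1024))  # stand-in

x = (clip_vision_output, 1.0, 0.25)   # (clip_vision_output, strength, noise_augmentation)

adm_cond = x[0].image_embeds          # image embedding used as adm conditioning
weight = x[1]                         # strength multiplier
noise_augment = x[2]                  # per-embed noise augmentation, new in this commit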


@@ -0,0 +1,137 @@
+import { app } from "/scripts/app.js";
+
+// Adds filtering to combo context menus
+
+const id = "Comfy.ContextMenuFilter";
+app.registerExtension({
+    name: id,
+    init() {
+        const ctxMenu = LiteGraph.ContextMenu;
+
+        LiteGraph.ContextMenu = function (values, options) {
+            const ctx = ctxMenu.call(this, values, options);
+
+            // If we are a dark menu (only used for combo boxes) then add a filter input
+            if (options?.className === "dark" && values?.length > 10) {
+                const filter = document.createElement("input");
+                Object.assign(filter.style, {
+                    width: "calc(100% - 10px)",
+                    boxSizing: "border-box",
+                    background: "#333",
+                    border: "1px solid #999",
+                    margin: "0 0 5px 5px",
+                    color: "#fff",
+                });
+                filter.placeholder = "Filter list";
+                this.root.prepend(filter);
+
+                let selectedIndex = 0;
+                let items = this.root.querySelectorAll(".litemenu-entry");
+                let itemCount = items.length;
+                let selectedItem;
+
+                // Apply highlighting to the selected item
+                function updateSelected() {
+                    if (selectedItem) {
+                        selectedItem.style.setProperty("background-color", "");
+                        selectedItem.style.setProperty("color", "");
+                    }
+                    selectedItem = items[selectedIndex];
+                    if (selectedItem) {
+                        selectedItem.style.setProperty("background-color", "#ccc", "important");
+                        selectedItem.style.setProperty("color", "#000", "important");
+                    }
+                }
+
+                const positionList = () => {
+                    const rect = this.root.getBoundingClientRect();
+
+                    // If the top is off screen then shift the element with scaling applied
+                    if (rect.top < 0) {
+                        const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
+                        const shift = (this.root.clientHeight * scale) / 2;
+                        this.root.style.top = -shift + "px";
+                    }
+                }
+
+                updateSelected();
+
+                // Arrow up/down to select items
+                filter.addEventListener("keydown", (e) => {
+                    if (e.key === "ArrowUp") {
+                        if (selectedIndex === 0) {
+                            selectedIndex = itemCount - 1;
+                        } else {
+                            selectedIndex--;
+                        }
+                        updateSelected();
+                        e.preventDefault();
+                    } else if (e.key === "ArrowDown") {
+                        if (selectedIndex === itemCount - 1) {
+                            selectedIndex = 0;
+                        } else {
+                            selectedIndex++;
+                        }
+                        updateSelected();
+                        e.preventDefault();
+                    } else if (selectedItem && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
+                        selectedItem.click();
+                    } else if (e.key === "Escape") {
+                        this.close();
+                    }
+                });
+
+                filter.addEventListener("input", () => {
+                    // Hide all items that don't match our filter
+                    const term = filter.value.toLocaleLowerCase();
+                    items = this.root.querySelectorAll(".litemenu-entry");
+
+                    // When filtering, recompute which items are visible for arrow up/down.
+                    // Try and maintain selection.
+                    let visibleItems = [];
+                    for (const item of items) {
+                        const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
+                        if (visible) {
+                            item.style.display = "block";
+                            if (item === selectedItem) {
+                                selectedIndex = visibleItems.length;
+                            }
+                            visibleItems.push(item);
+                        } else {
+                            item.style.display = "none";
+                            if (item === selectedItem) {
+                                selectedIndex = 0;
+                            }
+                        }
+                    }
+                    items = visibleItems;
+                    itemCount = visibleItems.length;
+                    updateSelected();
+
+                    // If we have an event then we can try and position the list under the source
+                    if (options.event) {
+                        let top = options.event.clientY - 10;
+
+                        const bodyRect = document.body.getBoundingClientRect();
+                        const rootRect = this.root.getBoundingClientRect();
+                        if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
+                            top = Math.max(0, bodyRect.height - rootRect.height - 10);
+                        }
+
+                        this.root.style.top = top + "px";
+                        positionList();
+                    }
+                });
+
+                requestAnimationFrame(() => {
+                    // Focus the filter box when opening
+                    filter.focus();
+
+                    positionList();
+                });
+            }
+
+            return ctx;
+        };
+
+        LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
+    },
+});


@@ -30,7 +30,8 @@ app.registerExtension({
             }
 
             // Overwrite the value in the serialized workflow pnginfo
-            workflowNode.widgets_values[widgetIndex] = prompt;
+            if (workflowNode?.widgets_values)
+                workflowNode.widgets_values[widgetIndex] = prompt;
 
             return prompt;
         };


@@ -3,10 +3,10 @@ import { app } from "/scripts/app.js";
 // Inverts the scrolling of context menus
 
 const id = "Comfy.InvertMenuScrolling";
-const ctxMenu = LiteGraph.ContextMenu;
 app.registerExtension({
     name: id,
     init() {
+        const ctxMenu = LiteGraph.ContextMenu;
         const replace = () => {
             LiteGraph.ContextMenu = function (values, options) {
                 options = options || {};


@@ -50,12 +50,12 @@ app.registerExtension({
                 }
                 else {
                     // Move the previous node
                     currentNode = node;
                 }
             } else {
                 // We've found the end
                 inputNode = currentNode;
-                inputType = node.outputs[link.origin_slot].type;
+                inputType = node.outputs[link.origin_slot]?.type ?? null;
                 break;
             }
         } else {
@@ -87,7 +87,7 @@ app.registerExtension({
                 updateNodes.push(node);
             } else {
                 // We've found an output
-                const nodeOutType = node.inputs[link.target_slot].type;
+                const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
                 if (inputType && nodeOutType !== inputType) {
                     // The output doesn't match our input so disconnect it
                     node.disconnectInput(link.target_slot);