Merge branch 'comfyanonymous:master' into feature/blockweights

This commit is contained in:
ltdrdata 2023-04-04 09:06:32 +09:00 committed by GitHub
commit 82a584cc88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 207 additions and 31 deletions

View File

@ -1,4 +1,4 @@
#Taken from: https://github.com/dbolya/tomesd
import torch
from typing import Tuple, Callable
@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
return x
def mps_gather_workaround(input, dim, index):
if input.shape[-1] == 1:
return torch.gather(
input.unsqueeze(-1),
dim - 1 if dim < 0 else dim,
index.unsqueeze(-1)
).squeeze(-1)
else:
return torch.gather(input, dim, index)
def bipartite_soft_matching_random2d(metric: torch.Tensor,
w: int, h: int, sx: int, sy: int, r: int,
no_rand: bool = False) -> Tuple[Callable, Callable]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Args:
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
if r <= 0:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
with torch.no_grad():
hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
if no_rand:
rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
rand_idx = idx_buffer.argsort(dim=1)
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
num_dst = int((1 / (sx*sy)) * N)
# Image is not divisible by sx or sy so we need to move it into a new buffer
if (hsy * sy) < h or (wsx * sx) < w:
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
else:
idx_buffer = idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
# We're finished with these
del idx_buffer, idx_buffer_view
# rand_idx is currently dst|src, so split them
num_dst = hsy * wsx
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst
def split(x):
C = x.shape[-1]
src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
return src, dst
# Cosine similarity between A and B
metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = split(metric)
scores = a @ b.transpose(-1, -2)
@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
# Can't reduce more than the # tokens in src
r = min(a.shape[1], r)
# Find the most similar greedily
node_max, node_idx = scores.max(dim=-1)
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
src_idx = edge_idx[..., :r, :] # Merged Tokens
dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape
unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
return torch.cat([unm, dst], dim=1)
@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
_, _, c = unm.shape
src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
# Combine back to the original shape
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
return out
@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
def get_functions(x, ratio, original_shape):
b, c, original_h, original_w = original_shape
original_tokens = original_h * original_w
downsample = int(math.sqrt(original_tokens // x.shape[1]))
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
stride_x = 2
stride_y = 2
max_downsample = 1
if downsample <= max_downsample:
w = original_w // downsample
h = original_h // downsample
w = int(math.ceil(original_w / downsample))
h = int(math.ceil(original_h / downsample))
r = int(x.shape[1] * ratio)
no_rand = False
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)

View File

@ -348,17 +348,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device))
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: Apply Noise to Embedding Mix
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
print(noise_level)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()

View File

@ -521,17 +521,18 @@ class unCLIPConditioning:
return {"required": {"conditioning": ("CONDITIONING", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_adm"
CATEGORY = "_for_testing/unclip"
def apply_adm(self, conditioning, clip_vision_output, strength):
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
c = []
for t in conditioning:
o = t[1].copy()
x = (clip_vision_output, strength)
x = (clip_vision_output, strength, noise_augmentation)
if "adm" in o:
o["adm"] = o["adm"][:] + [x]
else:

View File

@ -0,0 +1,137 @@
import { app } from "/scripts/app.js";
// Adds filtering to combo context menus
const id = "Comfy.ContextMenuFilter";
app.registerExtension({
name: id,
init() {
const ctxMenu = LiteGraph.ContextMenu;
LiteGraph.ContextMenu = function (values, options) {
const ctx = ctxMenu.call(this, values, options);
// If we are a dark menu (only used for combo boxes) then add a filter input
if (options?.className === "dark" && values?.length > 10) {
const filter = document.createElement("input");
Object.assign(filter.style, {
width: "calc(100% - 10px)",
border: "0",
boxSizing: "border-box",
background: "#333",
border: "1px solid #999",
margin: "0 0 5px 5px",
color: "#fff",
});
filter.placeholder = "Filter list";
this.root.prepend(filter);
let selectedIndex = 0;
let items = this.root.querySelectorAll(".litemenu-entry");
let itemCount = items.length;
let selectedItem;
// Apply highlighting to the selected item
function updateSelected() {
if (selectedItem) {
selectedItem.style.setProperty("background-color", "");
selectedItem.style.setProperty("color", "");
}
selectedItem = items[selectedIndex];
if (selectedItem) {
selectedItem.style.setProperty("background-color", "#ccc", "important");
selectedItem.style.setProperty("color", "#000", "important");
}
}
const positionList = () => {
const rect = this.root.getBoundingClientRect();
// If the top is off screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
updateSelected();
// Arrow up/down to select items
filter.addEventListener("keydown", (e) => {
if (e.key === "ArrowUp") {
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
e.preventDefault();
} else if (e.key === "ArrowDown") {
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
e.preventDefault();
} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
selectedItem.click();
} else if(e.key === "Escape") {
this.close();
}
});
filter.addEventListener("input", () => {
// Hide all items that dont match our filter
const term = filter.value.toLocaleLowerCase();
items = this.root.querySelectorAll(".litemenu-entry");
// When filtering recompute which items are visible for arrow up/down
// Try and maintain selection
let visibleItems = [];
for (const item of items) {
const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
if (visible) {
item.style.display = "block";
if (item === selectedItem) {
selectedIndex = visibleItems.length;
}
visibleItems.push(item);
} else {
item.style.display = "none";
if (item === selectedItem) {
selectedIndex = 0;
}
}
}
items = visibleItems;
updateSelected();
// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;
const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
}
});
requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
positionList();
});
}
return ctx;
};
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
},
});

View File

@ -30,7 +30,8 @@ app.registerExtension({
}
// Overwrite the value in the serialized workflow pnginfo
workflowNode.widgets_values[widgetIndex] = prompt;
if (workflowNode?.widgets_values)
workflowNode.widgets_values[widgetIndex] = prompt;
return prompt;
};

View File

@ -3,10 +3,10 @@ import { app } from "/scripts/app.js";
// Inverts the scrolling of context menus
const id = "Comfy.InvertMenuScrolling";
const ctxMenu = LiteGraph.ContextMenu;
app.registerExtension({
name: id,
init() {
const ctxMenu = LiteGraph.ContextMenu;
const replace = () => {
LiteGraph.ContextMenu = function (values, options) {
options = options || {};

View File

@ -50,12 +50,12 @@ app.registerExtension({
}
else {
// Move the previous node
currentNode = node;
currentNode = node;
}
} else {
// We've found the end
inputNode = currentNode;
inputType = node.outputs[link.origin_slot].type;
inputType = node.outputs[link.origin_slot]?.type ?? null;
break;
}
} else {
@ -87,7 +87,7 @@ app.registerExtension({
updateNodes.push(node);
} else {
// We've found an output
const nodeOutType = node.inputs[link.target_slot].type;
const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
if (inputType && nodeOutType !== inputType) {
// The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot);