From dba4f3b4fce575994ed718ac31888620e8d6e733 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 19 Nov 2023 06:09:01 -0500 Subject: [PATCH 01/32] Add a RepeatImageBatch node. --- comfy_extras/nodes_images.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 2b8e93001..8cb322327 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -23,7 +23,22 @@ class ImageCrop: img = image[:,y:to_y, x:to_x, :] return (img,) +class RepeatImageBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "repeat" + + CATEGORY = "image/batch" + + def repeat(self, image, amount): + s = image.repeat((amount, 1,1,1)) + return (s,) NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, + "RepeatImageBatch": RepeatImageBatch, } From 31c5ea7b2c79f36d3ebc729acf946ba47b4e5785 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 03:55:51 -0500 Subject: [PATCH 02/32] Add LatentInterpolate to interpolate between latents. --- comfy_extras/nodes_latent.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 001de39fc..cedf39d63 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -1,4 +1,5 @@ import comfy.utils +import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: @@ -67,8 +68,43 @@ class LatentMultiply: samples_out["samples"] = s1 * multiplier return (samples_out,) +class LatentInterpolate: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), + "samples2": ("LATENT",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2, ratio): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + + m1 = torch.linalg.vector_norm(s1, dim=(1)) + m2 = torch.linalg.vector_norm(s2, dim=(1)) + + s1 = torch.nan_to_num(s1 / m1) + s2 = torch.nan_to_num(s2 / m2) + + t = (s1 * ratio + s2 * (1.0 - ratio)) + mt = torch.linalg.vector_norm(t, dim=(1)) + st = torch.nan_to_num(t / mt) + + samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, + "LatentInterpolate": LatentInterpolate, } From a03dde190ede39675736e746c3045ecfc4baa79b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 16:38:39 -0500 Subject: [PATCH 03/32] Cap maximum history size at 10000. Delete oldest entry when reached. --- execution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/execution.py b/execution.py index 918c2bc5c..9a2ca5b9d 100644 --- a/execution.py +++ b/execution.py @@ -681,6 +681,7 @@ def validate_prompt(prompt): return (True, None, list(good_outputs), node_errors) +MAXIMUM_HISTORY_SIZE = 10000 class PromptQueue: def __init__(self, server): @@ -713,6 +714,8 @@ class PromptQueue: def task_done(self, item_id, outputs): with self.mutex: prompt = self.currently_running.pop(item_id) + if len(self.history) > MAXIMUM_HISTORY_SIZE: + self.history.pop(next(iter(self.history))) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: self.history[prompt[1]]["outputs"][o] = outputs[o] From 2dd5b4dd78fc0a30f3d5baa0b99a6b10f002d917 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 16:51:41 -0500 Subject: [PATCH 04/32] Only show last 200 elements in the UI history tab. --- execution.py | 14 ++++++++++++-- server.py | 5 ++++- web/scripts/api.js | 2 +- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/execution.py b/execution.py index 9a2ca5b9d..bca48a785 100644 --- a/execution.py +++ b/execution.py @@ -750,10 +750,20 @@ class PromptQueue: return True return False - def get_history(self, prompt_id=None): + def get_history(self, prompt_id=None, max_items=None, offset=-1): with self.mutex: if prompt_id is None: - return copy.deepcopy(self.history) + out = {} + i = 0 + if offset < 0 and max_items is not None: + offset = len(self.history) - max_items + for k in self.history: + if i >= offset: + out[k] = self.history[k] + if max_items is not None and len(out) >= max_items: + break + i += 1 + return out elif prompt_id in self.history: return {prompt_id: copy.deepcopy(self.history[prompt_id])} else: diff --git a/server.py b/server.py index 11bd2a0fb..1a8e92b8f 100644 --- a/server.py +++ b/server.py @@ -431,7 +431,10 @@ class PromptServer(): @routes.get("/history") async def get_history(request): - return web.json_response(self.prompt_queue.get_history()) + max_items = request.rel_url.query.get("max_items", None) + if max_items is not None: + max_items = int(max_items) + return web.json_response(self.prompt_queue.get_history(max_items=max_items)) @routes.get("/history/{prompt_id}") async def get_history(request): diff --git a/web/scripts/api.js b/web/scripts/api.js index b1d245d73..de56b2310 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -256,7 +256,7 @@ class ComfyApi extends EventTarget { */ async getHistory() { try { - const res = await this.fetchApi("/history"); + const res = await this.fetchApi("/history?max_items=200"); return { History: Object.values(await res.json()) }; } catch (error) { console.error(error); From ce67dcbcdabe2edf1497e37ecf1b6f976a3ecdf6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 22:27:36 -0500 Subject: [PATCH 05/32] Make it easy for models to process the unet state dict on load. --- comfy/model_base.py | 1 + comfy/supported_models_base.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 37bf24bb8..772e26934 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -121,6 +121,7 @@ class BaseModel(torch.nn.Module): if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) + to_load = self.model_config.process_unet_state_dict(to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: print("unet missing:", m) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 88a1d7fde..6dfae0343 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -53,6 +53,9 @@ class BASE: def process_clip_state_dict(self, state_dict): return state_dict + def process_unet_state_dict(self, state_dict): + return state_dict + def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) From 6ff06fa7960524749d8e584100a0e50594485f29 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 21 Nov 2023 06:33:58 +0000 Subject: [PATCH 06/32] Animated image output support (#2008) * Refactor multiline widget into generic DOM widget * wip webp preview * webp support * fix check * fix sizing * show image when zoomed out * Swap webp checkto generic animated image flag * remove duplicate * Fix falsy check --- web/scripts/app.js | 78 +++++---- web/scripts/domWidget.js | 312 +++++++++++++++++++++++++++++++++ web/scripts/ui/imagePreview.js | 97 ++++++++++ web/scripts/widgets.js | 166 ++---------------- web/style.css | 15 ++ 5 files changed, 482 insertions(+), 186 deletions(-) create mode 100644 web/scripts/domWidget.js create mode 100644 web/scripts/ui/imagePreview.js diff --git a/web/scripts/app.js b/web/scripts/app.js index 4507527f6..601e486e6 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -4,7 +4,10 @@ import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; +import { addDomClippingSetting } from "./domWidget.js"; +import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js" +export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview" function sanitizeNodeName(string) { let entityMap = { @@ -405,7 +408,9 @@ export class ComfyApp { return shiftY; } - node.prototype.setSizeForImage = function () { + node.prototype.setSizeForImage = function (force) { + if(!force && this.animatedImages) return; + if (this.inputHeight) { this.setSize(this.size); return; @@ -422,13 +427,20 @@ export class ComfyApp { let imagesChanged = false const output = app.nodeOutputs[this.id + ""]; - if (output && output.images) { + if (output?.images) { + this.animatedImages = output?.animated?.find(Boolean); if (this.images !== output.images) { this.images = output.images; imagesChanged = true; - imgURLs = imgURLs.concat(output.images.map(params => { - return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam()); - })) + imgURLs = imgURLs.concat( + output.images.map((params) => { + return api.apiURL( + "/view?" + + new URLSearchParams(params).toString() + + (this.animatedImages ? "" : app.getPreviewFormatParam()) + ); + }) + ); } } @@ -507,7 +519,34 @@ export class ComfyApp { return true; } - if (this.imgs && this.imgs.length) { + if (this.imgs?.length) { + const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET); + + if(this.animatedImages) { + // Instead of using the canvas we'll use a IMG + if(widgetIdx > -1) { + // Replace content + const widget = this.widgets[widgetIdx]; + widget.options.host.updateImages(this.imgs); + } else { + const host = createImageHost(this); + this.setSizeForImage(true); + const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, { + host, + getHeight: host.getHeight, + onDraw: host.onDraw, + hideOnZoom: false + }); + widget.serializeValue = () => undefined; + widget.options.host.updateImages(this.imgs); + } + return; + } + + if (widgetIdx > -1) { + this.widgets.splice(widgetIdx, 1); + } + const canvas = app.graph.list_of_graphcanvas[0]; const mouse = canvas.graph_mouse; if (!canvas.pointer_is_down && this.pointerDown) { @@ -547,31 +586,7 @@ export class ComfyApp { } else { cell_padding = 0; - let best = 0; - let w = this.imgs[0].naturalWidth; - let h = this.imgs[0].naturalHeight; - - // compact style - for (let c = 1; c <= numImages; c++) { - const rows = Math.ceil(numImages / c); - const cW = dw / c; - const cH = dh / rows; - const scaleX = cW / w; - const scaleY = cH / h; - - const scale = Math.min(scaleX, scaleY, 1); - const imageW = w * scale; - const imageH = h * scale; - const area = imageW * imageH * numImages; - - if (area > best) { - best = area; - cellWidth = imageW; - cellHeight = imageH; - cols = c; - shiftX = c * ((cW - imageW) / 2); - } - } + ({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh)); } let anyHovered = false; @@ -1272,6 +1287,7 @@ export class ComfyApp { canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); + addDomClippingSetting(); this.#addProcessMouseHandler(); this.#addProcessKeyHandler(); this.#addConfigureHandler(); diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js new file mode 100644 index 000000000..16f4e192e --- /dev/null +++ b/web/scripts/domWidget.js @@ -0,0 +1,312 @@ +import { app, ANIM_PREVIEW_WIDGET } from "./app.js"; + +const SIZE = Symbol(); + +function intersect(a, b) { + const x = Math.max(a.x, b.x); + const num1 = Math.min(a.x + a.width, b.x + b.width); + const y = Math.max(a.y, b.y); + const num2 = Math.min(a.y + a.height, b.y + b.height); + if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y]; + else return null; +} + +function getClipPath(node, element, elRect) { + const selectedNode = Object.values(app.canvas.selected_nodes)[0]; + if (selectedNode && selectedNode !== node) { + const MARGIN = 7; + const scale = app.canvas.ds.scale; + + const intersection = intersect( + { x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale }, + { + x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN, + y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN, + width: selectedNode.size[0] + MARGIN + MARGIN, + height: selectedNode.size[1] + LiteGraph.NODE_TITLE_HEIGHT + MARGIN + MARGIN, + } + ); + + if (!intersection) { + return ""; + } + + const widgetRect = element.getBoundingClientRect(); + const clipX = intersection[0] - widgetRect.x / scale + "px"; + const clipY = intersection[1] - widgetRect.y / scale + "px"; + const clipWidth = intersection[2] + "px"; + const clipHeight = intersection[3] + "px"; + const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`; + return path; + } + return ""; +} + +function computeSize(size) { + if (this.widgets?.[0].last_y == null) return; + + let y = this.widgets[0].last_y; + let freeSpace = size[1] - y; + + let widgetHeight = 0; + let dom = []; + for (const w of this.widgets) { + if (w.type === "converted-widget") { + // Ignore + delete w.computedHeight; + } else if (w.computeSize) { + widgetHeight += w.computeSize()[1] + 4; + } else if (w.element) { + // Extract DOM widget size info + const styles = getComputedStyle(w.element); + let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height")); + let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height")); + + let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height"); + if (prefHeight.endsWith?.("%")) { + prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100); + } else { + prefHeight = parseInt(prefHeight); + if (isNaN(minHeight)) { + minHeight = prefHeight; + } + } + if (isNaN(minHeight)) { + minHeight = 50; + } + if (!isNaN(maxHeight)) { + if (!isNaN(prefHeight)) { + prefHeight = Math.min(prefHeight, maxHeight); + } else { + prefHeight = maxHeight; + } + } + dom.push({ + minHeight, + prefHeight, + w, + }); + } else { + widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } + + freeSpace -= widgetHeight; + + // Calculate sizes with all widgets at their min height + const prefGrow = []; // Nodes that want to grow to their prefd size + const canGrow = []; // Nodes that can grow to auto size + let growBy = 0; + for (const d of dom) { + freeSpace -= d.minHeight; + if (isNaN(d.prefHeight)) { + canGrow.push(d); + d.w.computedHeight = d.minHeight; + } else { + const diff = d.prefHeight - d.minHeight; + if (diff > 0) { + prefGrow.push(d); + growBy += diff; + d.diff = diff; + } else { + d.w.computedHeight = d.minHeight; + } + } + } + + if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) { + // Allocate space for image + freeSpace -= 220; + } + + if (freeSpace < 0) { + // Not enough space for all widgets so we need to grow + size[1] -= freeSpace; + this.graph.setDirtyCanvas(true); + } else { + // Share the space between each + const growDiff = freeSpace - growBy; + if (growDiff > 0) { + // All pref sizes can be fulfilled + freeSpace = growDiff; + for (const d of prefGrow) { + d.w.computedHeight = d.prefHeight; + } + } else { + // We need to grow evenly + const shared = -growDiff / prefGrow.length; + for (const d of prefGrow) { + d.w.computedHeight = d.prefHeight - shared; + } + freeSpace = 0; + } + + if (freeSpace > 0 && canGrow.length) { + // Grow any that are auto height + const shared = freeSpace / canGrow.length; + for (const d of canGrow) { + d.w.computedHeight += shared; + } + } + } + + // Position each of the widgets + for (const w of this.widgets) { + w.y = y; + if (w.computedHeight) { + y += w.computedHeight; + } else if (w.computeSize) { + y += w.computeSize()[1] + 4; + } else { + y += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } +} + +// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen +const elementWidgets = new Set(); +const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes; +LGraphCanvas.prototype.computeVisibleNodes = function () { + const visibleNodes = computeVisibleNodes.apply(this, arguments); + for (const node of app.graph._nodes) { + if (elementWidgets.has(node)) { + const hidden = visibleNodes.indexOf(node) === -1; + for (const w of node.widgets) { + if (w.element) { + w.element.hidden = hidden; + if (hidden) { + w.options.onHide?.(w); + } + } + } + } + } + + return visibleNodes; +}; + +let enableDomClipping = true; + +export function addDomClippingSetting() { + app.ui.settings.addSetting({ + id: "Comfy.DOMClippingEnabled", + name: "Enable DOM element clipping (enabling may reduce performance)", + type: "boolean", + defaultValue: enableDomClipping, + onChange(value) { + console.log("enableDomClipping", enableDomClipping); + enableDomClipping = !!value; + }, + }); +} + +LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { + options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options }; + + if (!element.parentElement) { + document.body.append(element); + } + + let mouseDownHandler; + if (element.blur) { + mouseDownHandler = (event) => { + if (!element.contains(event.target)) { + element.blur(); + } + }; + document.addEventListener("mousedown", mouseDownHandler); + } + + const widget = { + type, + name, + get value() { + return options.getValue?.() ?? undefined; + }, + set value(v) { + options.setValue?.(v); + widget.callback?.(widget.value); + }, + draw: function (ctx, node, widgetWidth, y, widgetHeight) { + if (widget.computedHeight == null) { + computeSize.call(node, node.size); + } + + const hidden = + (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || + widget.computedHeight <= 0 || + widget.type === "converted-widget"; + element.hidden = hidden; + element.style.display = hidden ? "none" : null; + if (hidden) { + widget.options.onHide?.(widget); + return; + } + + const margin = 10; + const elRect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + + const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); + + Object.assign(element.style, { + transformOrigin: "0 0", + transform: scale, + left: `${transform.a + transform.e}px`, + top: `${transform.d + transform.f}px`, + width: `${widgetWidth - margin * 2}px`, + height: `${(widget.computedHeight ?? 50) - margin * 2}px`, + position: "absolute", + zIndex: app.graph._nodes.indexOf(node), + }); + + if (enableDomClipping) { + element.style.clipPath = getClipPath(node, element, elRect); + element.style.willChange = "clip-path"; + } + + this.options.onDraw?.(widget); + }, + element, + options, + onRemove() { + if (mouseDownHandler) { + document.removeEventListener("mousedown", mouseDownHandler); + } + element.remove(); + }, + }; + + for (const evt of options.selectOn) { + element.addEventListener(evt, () => { + app.canvas.selectNode(this); + app.canvas.bringToFront(this); + }); + } + + this.addCustomWidget(widget); + elementWidgets.add(this); + + const onRemoved = this.onRemoved; + this.onRemoved = function () { + element.remove(); + elementWidgets.delete(this); + onRemoved?.apply(this, arguments); + }; + + if (!this[SIZE]) { + this[SIZE] = true; + const onResize = this.onResize; + this.onResize = function (size) { + options.beforeResize?.call(widget, this); + computeSize.call(this, size); + onResize?.apply(this, arguments); + options.afterResize?.call(widget, this); + }; + } + + return widget; +}; diff --git a/web/scripts/ui/imagePreview.js b/web/scripts/ui/imagePreview.js new file mode 100644 index 000000000..2a7f66b8f --- /dev/null +++ b/web/scripts/ui/imagePreview.js @@ -0,0 +1,97 @@ +import { $el } from "../ui.js"; + +export function calculateImageGrid(imgs, dw, dh) { + let best = 0; + let w = imgs[0].naturalWidth; + let h = imgs[0].naturalHeight; + const numImages = imgs.length; + + let cellWidth, cellHeight, cols, rows, shiftX; + // compact style + for (let c = 1; c <= numImages; c++) { + const r = Math.ceil(numImages / c); + const cW = dw / c; + const cH = dh / r; + const scaleX = cW / w; + const scaleY = cH / h; + + const scale = Math.min(scaleX, scaleY, 1); + const imageW = w * scale; + const imageH = h * scale; + const area = imageW * imageH * numImages; + + if (area > best) { + best = area; + cellWidth = imageW; + cellHeight = imageH; + cols = c; + rows = r; + shiftX = c * ((cW - imageW) / 2); + } + } + + return { cellWidth, cellHeight, cols, rows, shiftX }; +} + +export function createImageHost(node) { + const el = $el("div.comfy-img-preview"); + let currentImgs; + let first = true; + + function updateSize() { + let w = null; + let h = null; + + if (currentImgs) { + let elH = el.clientHeight; + if (first) { + first = false; + // On first run, if we are small then grow a bit + if (elH < 190) { + elH = 190; + } + el.style.setProperty("--comfy-widget-min-height", elH); + } else { + el.style.setProperty("--comfy-widget-min-height", null); + } + + const nw = node.size[0]; + ({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH)); + w += "px"; + h += "px"; + + el.style.setProperty("--comfy-img-preview-width", w); + el.style.setProperty("--comfy-img-preview-height", h); + } + } + return { + el, + updateImages(imgs) { + if (imgs !== currentImgs) { + if (currentImgs == null) { + requestAnimationFrame(() => { + updateSize(); + }); + } + el.replaceChildren(...imgs); + currentImgs = imgs; + node.onResize(node.size); + node.graph.setDirtyCanvas(true, true); + } + }, + getHeight() { + updateSize(); + }, + onDraw() { + // Element from point uses a hittest find elements so we need to toggle pointer events + el.style.pointerEvents = "all"; + const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]); + el.style.pointerEvents = "none"; + + if(!over) return; + // Set the overIndex so Open Image etc work + const idx = currentImgs.indexOf(over); + node.overIndex = idx; + }, + }; +} diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 36bc7ff7f..ccddc0bc4 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,4 +1,5 @@ import { api } from "./api.js" +import "./domWidget.js"; function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; @@ -97,166 +98,21 @@ function seedWidget(node, inputName, inputData, app) { seed.widget.linkedWidgets = [seedControl]; return seed; } - -const MultilineSymbol = Symbol(); -const MultilineResizeSymbol = Symbol(); - function addMultilineWidget(node, name, opts, app) { - const MIN_SIZE = 50; + const inputEl = document.createElement("textarea"); + inputEl.className = "comfy-multiline-input"; + inputEl.value = opts.defaultVal; + inputEl.placeholder = opts.placeholder || ""; - function computeSize(size) { - if (node.widgets[0].last_y == null) return; - - let y = node.widgets[0].last_y; - let freeSpace = size[1] - y; - - // Compute the height of all non customtext widgets - let widgetHeight = 0; - const multi = []; - for (let i = 0; i < node.widgets.length; i++) { - const w = node.widgets[i]; - if (w.type === "customtext") { - multi.push(w); - } else { - if (w.computeSize) { - widgetHeight += w.computeSize()[1] + 4; - } else { - widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4; - } - } - } - - // See how large each text input can be - freeSpace -= widgetHeight; - freeSpace /= multi.length + (!!node.imgs?.length); - - if (freeSpace < MIN_SIZE) { - // There isnt enough space for all the widgets, increase the size of the node - freeSpace = MIN_SIZE; - node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length)); - node.graph.setDirtyCanvas(true); - } - - // Position each of the widgets - for (const w of node.widgets) { - w.y = y; - if (w.type === "customtext") { - y += freeSpace; - w.computedHeight = freeSpace - multi.length*4; - } else if (w.computeSize) { - y += w.computeSize()[1] + 4; - } else { - y += LiteGraph.NODE_WIDGET_HEIGHT + 4; - } - } - - node.inputHeight = freeSpace; - } - - const widget = { - type: "customtext", - name, - get value() { - return this.inputEl.value; + const widget = node.addDOMWidget(name, "customtext", inputEl, { + getValue() { + return inputEl.value; }, - set value(x) { - this.inputEl.value = x; + setValue(v) { + inputEl.value = v; }, - draw: function (ctx, _, widgetWidth, y, widgetHeight) { - if (!this.parent.inputHeight) { - // If we are initially offscreen when created we wont have received a resize event - // Calculate it here instead - computeSize(node.size); - } - const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext"; - const margin = 10; - const elRect = ctx.canvas.getBoundingClientRect(); - const transform = new DOMMatrix() - .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) - .multiplySelf(ctx.getTransform()) - .translateSelf(margin, margin + y); - - const scale = new DOMMatrix().scaleSelf(transform.a, transform.d) - Object.assign(this.inputEl.style, { - transformOrigin: "0 0", - transform: scale, - left: `${transform.a + transform.e}px`, - top: `${transform.d + transform.f}px`, - width: `${widgetWidth - (margin * 2)}px`, - height: `${this.parent.inputHeight - (margin * 2)}px`, - position: "absolute", - background: (!node.color)?'':node.color, - color: (!node.color)?'':'white', - zIndex: app.graph._nodes.indexOf(node), - }); - this.inputEl.hidden = !visible; - }, - }; - widget.inputEl = document.createElement("textarea"); - widget.inputEl.className = "comfy-multiline-input"; - widget.inputEl.value = opts.defaultVal; - widget.inputEl.placeholder = opts.placeholder || ""; - document.addEventListener("mousedown", function (event) { - if (!widget.inputEl.contains(event.target)) { - widget.inputEl.blur(); - } }); - widget.parent = node; - document.body.appendChild(widget.inputEl); - - node.addCustomWidget(widget); - - app.canvas.onDrawBackground = function () { - // Draw node isnt fired once the node is off the screen - // if it goes off screen quickly, the input may not be removed - // this shifts it off screen so it can be moved back if the node is visible. - for (let n in app.graph._nodes) { - n = graph._nodes[n]; - for (let w in n.widgets) { - let wid = n.widgets[w]; - if (Object.hasOwn(wid, "inputEl")) { - wid.inputEl.style.left = -8000 + "px"; - wid.inputEl.style.position = "absolute"; - } - } - } - }; - - node.onRemoved = function () { - // When removing this node we need to remove the input from the DOM - for (let y in this.widgets) { - if (this.widgets[y].inputEl) { - this.widgets[y].inputEl.remove(); - } - } - }; - - widget.onRemove = () => { - widget.inputEl?.remove(); - - // Restore original size handler if we are the last - if (!--node[MultilineSymbol]) { - node.onResize = node[MultilineResizeSymbol]; - delete node[MultilineSymbol]; - delete node[MultilineResizeSymbol]; - } - }; - - if (node[MultilineSymbol]) { - node[MultilineSymbol]++; - } else { - node[MultilineSymbol] = 1; - const onResize = (node[MultilineResizeSymbol] = node.onResize); - - node.onResize = function (size) { - computeSize(size); - - // Call original resizer handler - if (onResize) { - onResize.apply(this, arguments); - } - }; - } + widget.inputEl = inputEl; return { minWidth: 400, minHeight: 200, widget }; } diff --git a/web/style.css b/web/style.css index 692fa31d6..378fe0a48 100644 --- a/web/style.css +++ b/web/style.css @@ -409,6 +409,21 @@ dialog::backdrop { width: calc(100% - 10px); } +.comfy-img-preview { + pointer-events: none; + overflow: hidden; + display: flex; + flex-wrap: wrap; + align-content: flex-start; + justify-content: center; +} + +.comfy-img-preview img { + object-fit: contain; + width: var(--comfy-img-preview-width); + height: var(--comfy-img-preview-height); +} + /* Search box */ .litegraph.litesearchbox { From 89e31abc46df00d10d48b8a4e36256fefd5973ed Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:54:01 +0000 Subject: [PATCH 07/32] Fix clipping of collapsed nodes --- web/scripts/domWidget.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js index 16f4e192e..2f73e573e 100644 --- a/web/scripts/domWidget.js +++ b/web/scripts/domWidget.js @@ -17,13 +17,14 @@ function getClipPath(node, element, elRect) { const MARGIN = 7; const scale = app.canvas.ds.scale; + const bounding = selectedNode.getBounding(); const intersection = intersect( { x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale }, { x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN, y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN, - width: selectedNode.size[0] + MARGIN + MARGIN, - height: selectedNode.size[1] + LiteGraph.NODE_TITLE_HEIGHT + MARGIN + MARGIN, + width: bounding[2] + MARGIN + MARGIN, + height: bounding[3] + MARGIN + MARGIN, } ); From cd4fc77d5f83867cdfb806f0c96c65ce8a84322c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 21 Nov 2023 12:54:19 -0500 Subject: [PATCH 08/32] Add taesd and taesdxl to VAELoader node. They will show up if both the taesd_encoder and taesd_decoder or taesdxl model files are present in the models/vae_approx directory. --- comfy/sd.py | 17 ++++++++++---- comfy/taesd/taesd.py | 19 +++++++++++---- latent_preview.py | 5 +--- nodes.py | 55 +++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 79 insertions(+), 17 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c3cc8e720..0f83cc581 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -23,6 +23,7 @@ import comfy.model_patcher import comfy.lora import comfy.t2i_adapter.adapter import comfy.supported_models_base +import comfy.taesd.taesd def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -154,10 +155,16 @@ class VAE: if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) + self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking + self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7 + if config is None: - #default SD1.x/SD2.x VAE parameters - ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) + if "taesd_decoder.1.weight" in sd: + self.first_stage_model = comfy.taesd.taesd.TAESD() + else: + #default SD1.x/SD2.x VAE parameters + ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() @@ -206,7 +213,7 @@ class VAE: def decode(self, samples_in): self.first_stage_model = self.first_stage_model.to(self.device) try: - memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7 + memory_used = self.memory_used_decode(samples_in.shape) model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) @@ -234,7 +241,7 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: - memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 8df1f1609..46f3097a2 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -46,15 +46,16 @@ class TAESD(nn.Module): latent_magnitude = 3 latent_shift = 0.5 - def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"): + def __init__(self, encoder_path=None, decoder_path=None): """Initialize pretrained TAESD on the given device from the given checkpoints.""" super().__init__() - self.encoder = Encoder() - self.decoder = Decoder() + self.taesd_encoder = Encoder() + self.taesd_decoder = Decoder() + self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) if encoder_path is not None: - self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) + self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) if decoder_path is not None: - self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) + self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) @staticmethod def scale_latents(x): @@ -65,3 +66,11 @@ class TAESD(nn.Module): def unscale_latents(x): """[0, 1] -> raw latents""" return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + + def decode(self, x): + x_sample = self.taesd_decoder(x * self.vae_scale) + x_sample = x_sample.sub(0.5).mul(2) + return x_sample + + def encode(self, x): + return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale diff --git a/latent_preview.py b/latent_preview.py index 6e758a1a9..61754751e 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer): self.taesd = taesd def decode_latent_to_preview(self, x0): - x_sample = self.taesd.decoder(x0[:1])[0].detach() - # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] - x_sample = x_sample.sub(0.5).mul(2) - + x_sample = self.taesd.decode(x0[:1])[0].detach() x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) diff --git a/nodes.py b/nodes.py index 2adc5e073..2de468da7 100644 --- a/nodes.py +++ b/nodes.py @@ -573,9 +573,55 @@ class LoraLoader: return (model_lora, clip_lora) class VAELoader: + @staticmethod + def vae_list(): + vaes = folder_paths.get_filename_list("vae") + approx_vaes = folder_paths.get_filename_list("vae_approx") + sdxl_taesd_enc = False + sdxl_taesd_dec = False + sd1_taesd_enc = False + sd1_taesd_dec = False + + for v in approx_vaes: + if v.startswith("taesd_decoder."): + sd1_taesd_dec = True + elif v.startswith("taesd_encoder."): + sd1_taesd_enc = True + elif v.startswith("taesdxl_decoder."): + sdxl_taesd_dec = True + elif v.startswith("taesdxl_encoder."): + sdxl_taesd_enc = True + if sd1_taesd_dec and sd1_taesd_enc: + vaes.append("taesd") + if sdxl_taesd_dec and sdxl_taesd_enc: + vaes.append("taesdxl") + return vaes + + @staticmethod + def load_taesd(name): + sd = {} + approx_vaes = folder_paths.get_filename_list("vae_approx") + + encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) + decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) + + enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder)) + for k in enc: + sd["taesd_encoder.{}".format(k)] = enc[k] + + dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder)) + for k in dec: + sd["taesd_decoder.{}".format(k)] = dec[k] + + if name == "taesd": + sd["vae_scale"] = torch.tensor(0.18215) + elif name == "taesdxl": + sd["vae_scale"] = torch.tensor(0.13025) + return sd + @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}} + return {"required": { "vae_name": (s.vae_list(), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -583,8 +629,11 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): - vae_path = folder_paths.get_full_path("vae", vae_name) - sd = comfy.utils.load_torch_file(vae_path) + if vae_name in ["taesd", "taesdxl"]: + sd = self.load_taesd(vae_name) + else: + vae_path = folder_paths.get_full_path("vae", vae_name) + sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) return (vae,) From 6a491ebe2729c675322491e255a72d5ac0ef5bf6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 21 Nov 2023 16:29:18 -0500 Subject: [PATCH 09/32] Allow model config to preprocess the vae state dict on load. --- comfy/sd.py | 1 + comfy/supported_models_base.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 0f83cc581..c006a0362 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -448,6 +448,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = model_config.process_vae_state_dict(vae_sd) vae = VAE(sd=vae_sd) if output_clip: diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 6dfae0343..b073eb4fc 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -56,6 +56,9 @@ class BASE: def process_unet_state_dict(self, state_dict): return state_dict + def process_vae_state_dict(self, state_dict): + return state_dict + def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) From 72741105a687c67137eb5d7a38840b8373d82e61 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 21 Nov 2023 17:18:49 -0500 Subject: [PATCH 10/32] Remove useless code. --- .../modules/diffusionmodules/openaimodel.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 10eb68d73..e8f35a540 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -28,25 +28,6 @@ class TimestepBlock(nn.Module): Apply the module to `x` given `emb` timestep embeddings. """ - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None, transformer_options={}, output_shape=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context, transformer_options) - elif isinstance(layer, Upsample): - x = layer(x, output_shape=output_shape) - else: - x = layer(x) - return x - #This is needed because accelerate makes a copy of transformer_options which breaks "current_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): for layer in ts: @@ -54,13 +35,23 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out x = layer(x, emb) elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) - transformer_options["current_index"] += 1 + if "current_index" in transformer_options: + transformer_options["current_index"] += 1 elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: x = layer(x) return x +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, *args, **kwargs): + return forward_timestep_embed(self, *args, **kwargs) + class Upsample(nn.Module): """ An upsampling layer with an optional convolution. From c3ae99a749fa1e9a6dbb96c69c65c6fcf2507af3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Nov 2023 03:23:16 -0500 Subject: [PATCH 11/32] Allow controlling downscale and upscale methods in PatchModelAddDownscale. --- comfy/utils.py | 6 ++++-- comfy_extras/nodes_model_downscale.py | 10 +++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index f4c0ab419..294bbb425 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -318,7 +318,9 @@ def bislerp(samples, width, height): coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") coords_2 = coords_2.to(torch.int64) return ratios, coords_1, coords_2 - + + orig_dtype = samples.dtype + samples = samples.float() n,c,h,w = samples.shape h_new, w_new = (height, width) @@ -347,7 +349,7 @@ def bislerp(samples, width, height): result = slerp(pass_1, pass_2, ratios) result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) - return result + return result.to(orig_dtype) def lanczos(samples, width, height): images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f65ef05e1..48bcc6892 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -1,6 +1,8 @@ import torch +import comfy.utils class PatchModelAddDownscale: + upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -9,13 +11,15 @@ class PatchModelAddDownscale: "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), "downscale_after_skip": ("BOOLEAN", {"default": True}), + "downscale_method": (s.upscale_methods,), + "upscale_method": (s.upscale_methods,), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "_for_testing" - def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): + def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method): sigma_start = model.model.model_sampling.percent_to_sigma(start_percent) sigma_end = model.model.model_sampling.percent_to_sigma(end_percent) @@ -23,12 +27,12 @@ class PatchModelAddDownscale: if transformer_options["block"][1] == block_number: sigma = transformer_options["sigmas"][0].item() if sigma <= sigma_start and sigma >= sigma_end: - h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False) + h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") return h def output_block_patch(h, hsp, transformer_options): if h.shape[2] != hsp.shape[2]: - h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False) + h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") return h, hsp m = model.clone() From ab7d4f784892c275e888d71aa80a3a2ed59d9b83 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 22 Nov 2023 13:53:30 +0000 Subject: [PATCH 12/32] Handle collapsing to hide element --- web/scripts/domWidget.js | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js index 16f4e192e..0f8a2eb01 100644 --- a/web/scripts/domWidget.js +++ b/web/scripts/domWidget.js @@ -233,6 +233,7 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { } const hidden = + node.flags?.collapsed || (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || widget.computedHeight <= 0 || widget.type === "converted-widget"; @@ -290,6 +291,15 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { this.addCustomWidget(widget); elementWidgets.add(this); + const collapse = this.collapse; + this.collapse = function() { + collapse.apply(this, arguments); + if(this.flags?.collapsed) { + element.hidden = true; + element.style.display = "none"; + } + } + const onRemoved = this.onRemoved; this.onRemoved = function () { element.remove(); From 70d2ea0faa28e1727f7535466ac5378e786b32cb Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:52:20 +0000 Subject: [PATCH 13/32] Control filter list (#2009) * Add control_filter_list to filter items after queue * fix regex * backwards compatibility * formatting * revert * Add and fix test --- tests-ui/tests/widgetInputs.test.js | 96 ++++++++++++++++++++++++++--- web/extensions/core/widgetInputs.js | 8 ++- web/scripts/widgets.js | 56 ++++++++++++++--- 3 files changed, 141 insertions(+), 19 deletions(-) diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index 022e54926..e1873105a 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -14,10 +14,10 @@ const lg = require("../utils/litegraph"); * @param { InstanceType } graph * @param { InstanceType } input * @param { string } widgetType - * @param { boolean } hasControlWidget + * @param { number } controlWidgetCount * @returns */ -async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) { +async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) { // Connect to primitive and ensure its still connected after let primitive = ez.PrimitiveNode(); primitive.outputs[0].connectTo(input); @@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro expect(valueWidget.widget.type).toBe(widgetType); // Check if control_after_generate should be added - if (hasControlWidget) { + if (controlWidgetCount) { const controlWidget = primitive.widgets.control_after_generate; expect(controlWidget.widget.type).toBe("combo"); + if(widgetType === "combo") { + const filterWidget = primitive.widgets.control_filter_list; + expect(filterWidget.widget.type).toBe("string"); + } } // Ensure we dont have other widgets - expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget); + expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount); }); return primitive; @@ -55,8 +59,8 @@ describe("widget inputs", () => { }); [ - { name: "int", type: "INT", widget: "number", control: true }, - { name: "float", type: "FLOAT", widget: "number", control: true }, + { name: "int", type: "INT", widget: "number", control: 1 }, + { name: "float", type: "FLOAT", widget: "number", control: 1 }, { name: "text", type: "STRING" }, { name: "customtext", @@ -64,7 +68,7 @@ describe("widget inputs", () => { opt: { multiline: true }, }, { name: "toggle", type: "BOOLEAN" }, - { name: "combo", type: ["a", "b", "c"], control: true }, + { name: "combo", type: ["a", "b", "c"], control: 2 }, ].forEach((c) => { test(`widget conversion + primitive works on ${c.name}`, async () => { const { ez, graph } = await start({ @@ -106,7 +110,7 @@ describe("widget inputs", () => { n.widgets.ckpt_name.convertToInput(); expect(n.inputs.length).toEqual(inputCount + 1); - const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true); + const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2); // Disconnect & reconnect primitive.outputs[0].connections[0].disconnect(); @@ -226,7 +230,7 @@ describe("widget inputs", () => { // Reload and ensure it still only has 1 converted widget if (!assertNotNullOrUndefined(input)) return; - await connectPrimitiveAndReload(ez, graph, input, "number", true); + await connectPrimitiveAndReload(ez, graph, input, "number", 1); n = graph.find(n); expect(n.widgets).toHaveLength(1); w = n.widgets.example; @@ -258,7 +262,7 @@ describe("widget inputs", () => { // Reload and ensure it still only has 1 converted widget if (assertNotNullOrUndefined(input)) { - await connectPrimitiveAndReload(ez, graph, input, "number", true); + await connectPrimitiveAndReload(ez, graph, input, "number", 1); n = graph.find(n); expect(n.widgets).toHaveLength(1); expect(n.widgets.example.isConvertedToInput).toBeTruthy(); @@ -316,4 +320,76 @@ describe("widget inputs", () => { n1.outputs[0].connectTo(n2.inputs[0]); expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow(); }); + + test("combo primitive can filter list when control_after_generate called", async () => { + const { ez } = await start({ + mockNodeDefs: { + ...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }), + }, + }); + + const n1 = ez.TestNode1(); + n1.widgets.example.convertToInput(); + const p = ez.PrimitiveNode() + p.outputs[0].connectTo(n1.inputs[0]); + + const value = p.widgets.value; + const control = p.widgets.control_after_generate.widget; + const filter = p.widgets.control_filter_list; + + expect(p.widgets.length).toBe(3); + control.value = "increment"; + expect(value.value).toBe("A"); + + // Manually trigger after queue when set to increment + control["afterQueued"](); + expect(value.value).toBe("B"); + + // Filter to items containing D + filter.value = "D"; + control["afterQueued"](); + expect(value.value).toBe("D"); + control["afterQueued"](); + expect(value.value).toBe("DD"); + + // Check decrement + value.value = "BBB"; + control.value = "decrement"; + filter.value = "B"; + control["afterQueued"](); + expect(value.value).toBe("BB"); + control["afterQueued"](); + expect(value.value).toBe("B"); + + // Check regex works + value.value = "BBB"; + filter.value = "/[AB]|^C$/"; + control["afterQueued"](); + expect(value.value).toBe("AAA"); + control["afterQueued"](); + expect(value.value).toBe("BB"); + control["afterQueued"](); + expect(value.value).toBe("AA"); + control["afterQueued"](); + expect(value.value).toBe("C"); + control["afterQueued"](); + expect(value.value).toBe("B"); + control["afterQueued"](); + expect(value.value).toBe("A"); + + // Check random + control.value = "randomize"; + filter.value = "/D/"; + for(let i = 0; i < 100; i++) { + control["afterQueued"](); + expect(value.value === "D" || value.value === "DD").toBeTruthy(); + } + + // Ensure it doesnt apply when fixed + control.value = "fixed"; + value.value = "B"; + filter.value = "C"; + control["afterQueued"](); + expect(value.value).toBe("B"); + }); }); diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index bad3ac3a7..5c8fbc9b2 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -1,4 +1,4 @@ -import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js"; +import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js"; import { app } from "../../scripts/app.js"; const CONVERTED_TYPE = "converted-widget"; @@ -467,7 +467,11 @@ app.registerExtension({ if (!control_value) { control_value = "fixed"; } - addValueControlWidget(this, widget, control_value); + addValueControlWidgets(this, widget, control_value); + let filter = this.widgets_values?.[2]; + if(filter && this.widgets.length === 3) { + this.widgets[2].value = filter; + } } // When our value changes, update other widgets to reflect our changes diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index ccddc0bc4..fbc1d0fc3 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -24,17 +24,58 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { } export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) { - const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { + const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, { + addFilterList: false, + }); + return widgets[0]; +} + +export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) { + if (!options) options = {}; + + const widgets = []; + const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { values: ["fixed", "increment", "decrement", "randomize"], serialize: false, // Don't include this in prompt. }); - valueControl.afterQueued = () => { + widgets.push(valueControl); + const isCombo = targetWidget.type === "combo"; + let comboFilter; + if (isCombo && options.addFilterList !== false) { + comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, { + serialize: false, // Don't include this in prompt. + }); + widgets.push(comboFilter); + } + + valueControl.afterQueued = () => { var v = valueControl.value; - if (targetWidget.type == "combo" && v !== "fixed") { - let current_index = targetWidget.options.values.indexOf(targetWidget.value); - let current_length = targetWidget.options.values.length; + if (isCombo && v !== "fixed") { + let values = targetWidget.options.values; + const filter = comboFilter?.value; + if (filter) { + let check; + if (filter.startsWith("/") && filter.endsWith("/")) { + try { + const regex = new RegExp(filter.substring(1, filter.length - 1)); + check = (item) => regex.test(item); + } catch (error) { + console.error("Error constructing RegExp filter for node " + node.id, filter, error); + } + } + if (!check) { + const lower = filter.toLocaleLowerCase(); + check = (item) => item.toLocaleLowerCase().includes(lower); + } + values = values.filter(item => check(item)); + if (!values.length && targetWidget.options.values.length) { + console.warn("Filter for node " + node.id + " has filtered out all items", filter); + } + } + let current_index = values.indexOf(targetWidget.value); + let current_length = values.length; switch (v) { case "increment": @@ -51,7 +92,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random current_index = Math.max(0, current_index); current_index = Math.min(current_length - 1, current_index); if (current_index >= 0) { - let value = targetWidget.options.values[current_index]; + let value = values[current_index]; targetWidget.value = value; targetWidget.callback(value); } @@ -88,7 +129,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random targetWidget.callback(targetWidget.value); } } - return valueControl; + + return widgets; }; function seedWidget(node, inputName, inputData, app) { From 32447f0c392be6a6b64fbac09fd7e7f33eb451f8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Nov 2023 17:23:37 -0500 Subject: [PATCH 14/32] Add sampling_settings so models can specify specific sampling settings. --- comfy/model_sampling.py | 2 +- comfy/supported_models_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 37a3ac725..9e2a1c1af 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -24,7 +24,7 @@ class ModelSamplingDiscrete(torch.nn.Module): super().__init__() beta_schedule = "linear" if model_config is not None: - beta_schedule = model_config.beta_schedule + beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule) self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) self.sigma_data = 1.0 diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index b073eb4fc..3412cfea0 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -19,7 +19,7 @@ class BASE: clip_prefix = [] clip_vision_prefix = None noise_aug_config = None - beta_schedule = "linear" + sampling_settings = {} latent_format = latent_formats.LatentFormat @classmethod From 410bf0777197c7005fe13aa4f6717d6dc63e2b22 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Nov 2023 18:16:02 -0500 Subject: [PATCH 15/32] Make VAE memory estimation take dtype into account. --- comfy/sd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c006a0362..a8df3bdd4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -155,8 +155,8 @@ class VAE: if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) - self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking - self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7 + self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) + self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) if config is None: if "taesd_decoder.1.weight" in sd: @@ -213,7 +213,7 @@ class VAE: def decode(self, samples_in): self.first_stage_model = self.first_stage_model.to(self.device) try: - memory_used = self.memory_used_decode(samples_in.shape) + memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) @@ -241,7 +241,7 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: - memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) From d03d8aa2e348c6ba3333150eb18aa76f5180a7f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 01:09:15 -0500 Subject: [PATCH 16/32] Fix loading groups. --- web/lib/litegraph.core.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 0ca203842..f571edb30 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -4928,7 +4928,9 @@ LGraphNode.prototype.executeAction = function(action) this.title = o.title; this._bounding.set(o.bounding); this.color = o.color; - this.font_size = o.font_size; + if (o.font_size) { + this.font_size = o.font_size; + } }; LGraphGroup.prototype.serialize = function() { From 87031a1945278abe6b8a8058dfe6f38a5138655c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 11:59:11 -0500 Subject: [PATCH 17/32] Update readme with link to LCM example page. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d622c9072..f87c0404f 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) +- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. - Works fully offline: will never download anything. From a657f96c5cd9d72725352d6b00def82d9ce5d556 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 13:55:29 -0500 Subject: [PATCH 18/32] Add a node to save animated webp. --- comfy_extras/nodes_images.py | 76 ++++++++++++++++++++++++++++++++++++ web/scripts/pnginfo.js | 4 +- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 8cb322327..18c579190 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -1,4 +1,12 @@ import nodes +import folder_paths +from comfy.cli_args import args + +from PIL import Image +import numpy as np +import json +import os + MAX_RESOLUTION = nodes.MAX_RESOLUTION class ImageCrop: @@ -38,7 +46,75 @@ class RepeatImageBatch: s = image.repeat((amount, 1,1,1)) return (s,) +class SaveAnimatedWEBP: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + methods = {"default": 4, "fastest": 0, "slowest": 6} + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "lossless": ("BOOLEAN", {"default": True}), + "quality": ("INT", {"default": 80, "min": 0, "max": 100}), + "method": (list(s.methods.keys()),), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): + method = self.methods.get(method, "aoeu") + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + pil_images = [] + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pil_images.append(img) + + metadata = None + if not args.disable_metadata: + metadata = pil_images[0].getexif() + if prompt is not None: + metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) + if extra_pnginfo is not None: + inital_exif = 0x010f + for x in extra_pnginfo: + metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) + inital_exif -= 1 + + if num_frames == 0: + num_frames = len(pil_images) + + c = len(pil_images) + for i in range(0, c, num_frames): + file = f"{filename}_{counter:05}_.webp" + pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + counter += 1 + + animated = num_frames != 1 + return { "ui": { "images": results, "animated": (animated,) } } + NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, + "SaveAnimatedWEBP": SaveAnimatedWEBP, } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 491caed79..f8cbe7a3c 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -50,7 +50,6 @@ export function getPngMetadata(file) { function parseExifData(exifData) { // Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian) const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949; - console.log(exifData); // Function to read 16-bit and 32-bit integers from binary data function readInt(offset, isLittleEndian, length) { @@ -126,6 +125,9 @@ export function getWebpMetadata(file) { const chunk_length = dataView.getUint32(offset + 4, true); const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4)); if (chunk_type === "EXIF") { + if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") { + offset += 6; + } let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length)); for (var key in data) { var value = data[key]; From 4d2437e68165cf12989dafe1ef0a26c3a0abc7f5 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 23 Nov 2023 19:43:55 +0000 Subject: [PATCH 19/32] Call widget onRemove to remove element --- web/scripts/app.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 601e486e6..180416ef9 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -544,6 +544,7 @@ export class ComfyApp { } if (widgetIdx > -1) { + this.widgets[widgetIdx].onRemove?.(); this.widgets.splice(widgetIdx, 1); } From 022033a0e75901c7c357ab96e1c804fd5da05770 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 15:06:35 -0500 Subject: [PATCH 20/32] Fix SaveAnimatedWEBP not working when metadata is disabled. --- comfy_extras/nodes_images.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 18c579190..8c6ae5387 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -75,7 +75,7 @@ class SaveAnimatedWEBP: CATEGORY = "_for_testing" def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): - method = self.methods.get(method, "aoeu") + method = self.methods.get(method) filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() @@ -85,9 +85,8 @@ class SaveAnimatedWEBP: img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) pil_images.append(img) - metadata = None + metadata = pil_images[0].getexif() if not args.disable_metadata: - metadata = pil_images[0].getexif() if prompt is not None: metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) if extra_pnginfo is not None: From 871cc20e13e9ef2629e3b5faa6af64207e86d6d2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 19:41:33 -0500 Subject: [PATCH 21/32] Support SVD img2vid model. --- comfy/cldm/cldm.py | 1 + comfy/ldm/modules/attention.py | 271 ++++++++++++-- .../modules/diffusionmodules/openaimodel.py | 348 +++++++++++++++--- comfy/ldm/modules/diffusionmodules/util.py | 69 +++- comfy/ldm/modules/temporal_ae.py | 244 ++++++++++++ comfy/model_base.py | 56 ++- comfy/model_detection.py | 18 +- comfy/model_sampling.py | 46 ++- comfy/sd.py | 10 +- comfy/supported_models.py | 36 +- comfy_extras/nodes_model_advanced.py | 31 ++ 11 files changed, 1030 insertions(+), 100 deletions(-) create mode 100644 comfy/ldm/modules/temporal_ae.py diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 9a63202ab..76a525b37 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -54,6 +54,7 @@ class ControlNet(nn.Module): transformer_depth_output=None, device=None, operations=comfy.ops, + **kwargs, ): super().__init__() assert use_spatial_transformer == True, "use_spatial_transformer has to be true" diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 016795a59..947e2008c 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,8 +5,10 @@ import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any +from functools import partial -from .diffusionmodules.util import checkpoint + +from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management @@ -370,21 +372,45 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False, dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops): super().__init__() + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) + self.disable_self_attn = disable_self_attn - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) - self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device) - self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + context_dim_attn2 = None + if not switch_temporal_ca_to_sa: + context_dim_attn2 = context_dim + + self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, + heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none + self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + + self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa def forward(self, x, context=None, transformer_options={}): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) @@ -418,6 +444,12 @@ class BasicTransformerBlock(nn.Module): else: transformer_patches_replace = {} + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + n = self.norm1(x) if self.disable_self_attn: context_attn1 = context @@ -465,31 +497,34 @@ class BasicTransformerBlock(nn.Module): for p in patch: x = p(x, extra_options) - n = self.norm2(x) - - context_attn2 = context - value_attn2 = None - if "attn2_patch" in transformer_patches: - patch = transformer_patches["attn2_patch"] - value_attn2 = context_attn2 - for p in patch: - n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) - - attn2_replace_patch = transformer_patches_replace.get("attn2", {}) - block_attn2 = transformer_block - if block_attn2 not in attn2_replace_patch: - block_attn2 = block - - if block_attn2 in attn2_replace_patch: - if value_attn2 is None: + if self.attn2 is not None: + n = self.norm2(x) + if self.switch_temporal_ca_to_sa: + context_attn2 = n + else: + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] value_attn2 = context_attn2 - n = self.attn2.to_q(n) - context_attn2 = self.attn2.to_k(context_attn2) - value_attn2 = self.attn2.to_v(value_attn2) - n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) - n = self.attn2.to_out(n) - else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + for p in patch: + n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) + + attn2_replace_patch = transformer_patches_replace.get("attn2", {}) + block_attn2 = transformer_block + if block_attn2 not in attn2_replace_patch: + block_attn2 = block + + if block_attn2 in attn2_replace_patch: + if value_attn2 is None: + value_attn2 = context_attn2 + n = self.attn2.to_q(n) + context_attn2 = self.attn2.to_k(context_attn2) + value_attn2 = self.attn2.to_v(value_attn2) + n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) + n = self.attn2.to_out(n) + else: + n = self.attn2(n, context=context_attn2, value=value_attn2) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] @@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module): n = p(n, extra_options) x += n - x = self.ff(self.norm3(x)) + x + if self.is_res: + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + return x @@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module): x = self.proj_out(x) return x + x_in + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + dtype=None, device=None, operations=comfy.ops + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + dtype=dtype, device=device, operations=operations + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + # timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + dtype=dtype, device=device, operations=operations + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device), + ) + + self.time_mixer = AlphaBlender( + alpha=merge_factor, merge_strategy=merge_strategy + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + transformer_options={} + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert ( + context.ndim == 3 + ), f"n dims of spatial context should be 3 but are {context.ndim}" + + if time_context is None: + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat( + time_context_first_timestep, "b ... -> (b n) ...", n=h * w + ) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate( + zip(self.transformer_blocks, self.time_stack) + ): + transformer_options["block_index"] = it_ + x = block( + x, + context=spatial_context, + transformer_options=transformer_options, + ) + + x_mix = x + x_mix = x_mix + emb + + B, S, C = x_mix.shape + x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) + x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = rearrange( + x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps + ) + + x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out + + diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index e8f35a540..a497ed344 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -5,6 +5,8 @@ import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F +from einops import rearrange +from functools import partial from .util import ( checkpoint, @@ -12,8 +14,9 @@ from .util import ( zero_module, normalization, timestep_embedding, + AlphaBlender, ) -from ..attention import SpatialTransformer +from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists import comfy.ops @@ -29,10 +32,15 @@ class TimestepBlock(nn.Module): """ #This is needed because accelerate makes a copy of transformer_options which breaks "current_index" -def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): +def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: - if isinstance(layer, TimestepBlock): + if isinstance(layer, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(layer, TimestepBlock): x = layer(x, emb) + elif isinstance(layer, SpatialVideoTransformer): + x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) + transformer_options["current_index"] += 1 elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) if "current_index" in transformer_options: @@ -145,6 +153,9 @@ class ResBlock(TimestepBlock): use_checkpoint=False, up=False, down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, dtype=None, device=None, operations=comfy.ops @@ -157,11 +168,17 @@ class ResBlock(TimestepBlock): self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, list): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 self.in_layers = nn.Sequential( nn.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), + operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), ) self.updown = up or down @@ -175,19 +192,24 @@ class ResBlock(TimestepBlock): else: self.h_upd = self.x_upd = nn.Identity() - self.emb_layers = nn.Sequential( - nn.SiLU(), - operations.Linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device - ), - ) + self.skip_t_emb = skip_t_emb + if self.skip_t_emb: + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + operations.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device + ), + ) self.out_layers = nn.Sequential( nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), zero_module( - operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) + operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) ), ) @@ -195,7 +217,7 @@ class ResBlock(TimestepBlock): self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = operations.conv_nd( - dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device + dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device ) else: self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) @@ -221,19 +243,110 @@ class ResBlock(TimestepBlock): h = in_conv(h) else: h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] + + emb_out = None + if not self.skip_t_emb: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift + h = out_norm(h) + if emb_out is not None: + scale, shift = th.chunk(emb_out, 2, dim=1) + h *= (1 + scale) + h += shift h = out_rest(h) else: - h = h + emb_out + if emb_out is not None: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h + +class VideoResBlock(ResBlock): + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + video_kernel_size=3, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + out_channels=None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + dtype=None, + device=None, + operations=comfy.ops + ): + super().__init__( + channels, + emb_channels, + dropout, + out_channels=out_channels, + use_conv=use_conv, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + up=up, + down=down, + dtype=dtype, + device=device, + operations=operations + ) + + self.time_stack = ResBlock( + default(out_channels, channels), + emb_channels, + dropout=dropout, + dims=3, + out_channels=default(out_channels, channels), + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=use_checkpoint, + exchange_temb_dims=True, + dtype=dtype, + device=device, + operations=operations + ) + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + rearrange_pattern="b t -> b 1 t 1 1", + ) + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + num_video_frames: int, + image_only_indicator = None, + ) -> th.Tensor: + x = super().forward(x, emb) + + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + + x = self.time_stack( + x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames) + ) + x = self.time_mixer( + x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator + ) + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + class Timestep(nn.Module): def __init__(self, dim): super().__init__() @@ -310,6 +423,16 @@ class UNetModel(nn.Module): adm_in_channels=None, transformer_depth_middle=None, transformer_depth_output=None, + use_temporal_resblock=False, + use_temporal_attention=False, + time_context_dim=None, + extra_ff_mix_layer=False, + use_spatial_context=False, + merge_strategy=None, + merge_factor=0.0, + video_kernel_size=None, + disable_temporal_crossattention=False, + max_ddpm_temb_period=10000, device=None, operations=comfy.ops, ): @@ -364,8 +487,12 @@ class UNetModel(nn.Module): self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample + self.use_temporal_resblocks = use_temporal_resblock self.predict_codebook_ids = n_embed is not None + self.default_num_video_frames = None + self.default_image_only_indicator = None + time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), @@ -402,13 +529,104 @@ class UNetModel(nn.Module): input_block_chans = [model_channels] ch = model_channels ds = 1 + + def get_attention_layer( + ch, + num_heads, + dim_head, + depth=1, + context_dim=None, + use_checkpoint=False, + disable_self_attn=False, + ): + if use_temporal_attention: + return SpatialVideoTransformer( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=merge_strategy, + merge_factor=merge_factor, + checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + max_time_embed_period=max_ddpm_temb_period, + dtype=self.dtype, device=device, operations=operations + ) + else: + return SpatialTransformer( + ch, num_heads, dim_head, depth=depth, context_dim=context_dim, + disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + ) + + def get_resblock( + merge_factor, + merge_strategy, + video_kernel_size, + ch, + time_embed_dim, + dropout, + out_channels, + dims, + use_checkpoint, + use_scale_shift_norm, + down=False, + up=False, + dtype=None, + device=None, + operations=comfy.ops + ): + if self.use_temporal_resblocks: + return VideoResBlock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + dtype=dtype, + device=device, + operations=operations + ) + else: + return ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + use_checkpoint=use_checkpoint, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + dtype=dtype, + device=device, + operations=operations + ) + for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, @@ -435,11 +653,9 @@ class UNetModel(nn.Module): disabled_sa = False if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: - layers.append(SpatialTransformer( + layers.append(get_attention_layer( ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations - ) + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch @@ -448,10 +664,13 @@ class UNetModel(nn.Module): out_ch = ch self.input_blocks.append( TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, @@ -481,10 +700,14 @@ class UNetModel(nn.Module): #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels mid_block = [ - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, @@ -493,15 +716,18 @@ class UNetModel(nn.Module): operations=operations )] if transformer_depth_middle >= 0: - mid_block += [SpatialTransformer( # always uses a self-attn + mid_block += [get_attention_layer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint ), - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, @@ -517,10 +743,13 @@ class UNetModel(nn.Module): for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch + ich, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, @@ -548,19 +777,21 @@ class UNetModel(nn.Module): if not exists(num_attention_blocks) or i < num_attention_blocks[level]: layers.append( - SpatialTransformer( + get_attention_layer( ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint ) ) if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, @@ -602,6 +833,10 @@ class UNetModel(nn.Module): transformer_options["current_index"] = 0 transformer_patches = transformer_options.get("patches", {}) + num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) + image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + time_context = kwargs.get("time_context", None) + assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -616,7 +851,7 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) - h = forward_timestep_embed(module, h, emb, context, transformer_options) + h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'input') if "input_block_patch" in transformer_patches: patch = transformer_patches["input_block_patch"] @@ -630,9 +865,10 @@ class UNetModel(nn.Module): h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') + for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) hsp = hs.pop() @@ -649,7 +885,7 @@ class UNetModel(nn.Module): output_shape = hs[-1].shape else: output_shape = None - h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 0298ca99d..704bbe574 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -13,11 +13,78 @@ import math import torch import torch.nn as nn import numpy as np -from einops import repeat +from einops import repeat, rearrange from comfy.ldm.util import instantiate_from_config import comfy.ops +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert ( + merge_strategy in self.strategies + ), f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif ( + self.merge_strategy == "learned" + or self.merge_strategy == "learned_with_images" + ): + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) + if self.merge_strategy == "fixed": + # make shape compatible + # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + # make shape compatible + # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + # make shape compatible + # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) + else: + raise NotImplementedError() + return alpha + + def forward( + self, + x_spatial, + x_temporal, + image_only_indicator=None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + x = ( + alpha.to(x_spatial.dtype) * x_spatial + + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + ) + return x + + def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": betas = ( diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py new file mode 100644 index 000000000..11ae049f3 --- /dev/null +++ b/comfy/ldm/modules/temporal_ae.py @@ -0,0 +1,244 @@ +import functools +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +import comfy.ops + +from .diffusionmodules.model import ( + AttnBlock, + Decoder, + ResnetBlock, +) +from .diffusionmodules.openaimodel import ResBlock, timestep_embedding +from .attention import BasicTransformerBlock + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + def forward(self, x, temb, skip_video=False, timesteps=None): + b, c, h, w = x.shape + if timesteps is None: + timesteps = b + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps=None, skip_video=False): + if timesteps is None: + timesteps = input.shape[0] + x = super().forward(input) + if skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class AttnVideoBlock(AttnBlock): + def __init__( + self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" + ): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = BasicTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + comfy.ops.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + comfy.ops.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps=None, skip_time_block=False): + if skip_time_block: + return super().forward(x) + + if timesteps is None: + timesteps = x.shape[0] + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + return partialclass( + AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy + ) + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + + if self.time_mode != "attn-only": + kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + if self.time_mode not in ["conv-only", "only-last-conv"]: + kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy) + if self.time_mode not in ["attn-only", "only-last-conv"]: + kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy) + + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return ( + self.conv_out.time_mix_conv.weight + if not skip_time_mix + else self.conv_out.weight + ) diff --git a/comfy/model_base.py b/comfy/model_base.py index 772e26934..34274c4ae 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -10,17 +10,22 @@ from . import utils class ModelType(Enum): EPS = 1 V_PREDICTION = 2 + V_PREDICTION_EDM = 3 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete +from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM + def model_sampling(model_config, model_type): + s = ModelSamplingDiscrete + if model_type == ModelType.EPS: c = EPS elif model_type == ModelType.V_PREDICTION: c = V_PREDICTION - - s = ModelSamplingDiscrete + elif model_type == ModelType.V_PREDICTION_EDM: + c = V_PREDICTION + s = ModelSamplingContinuousEDM class ModelSampling(s, c): pass @@ -262,3 +267,48 @@ class SDXL(BaseModel): out.append(self.embedder(torch.Tensor([target_width]))) flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + +class SVD_img2vid(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder = Timestep(256) + + def encode_adm(self, **kwargs): + fps_id = kwargs.get("fps", 6) - 1 + motion_bucket_id = kwargs.get("motion_bucket_id", 127) + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.Tensor([fps_id]))) + out.append(self.embedder(torch.Tensor([motion_bucket_id]))) + out.append(self.embedder(torch.Tensor([augmentation]))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + + def extra_conds(self, **kwargs): + out = {} + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = comfy.conds.CONDRegular(adm) + + latent_image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + if latent_image is None: + latent_image = torch.zeros_like(noise) + + if latent_image.shape[1:] != noise.shape[1:]: + latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + + if "time_conditioning" in kwargs: + out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) + + out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) + out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index d65d91e7c..45d603a0c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 - return last_transformer_depth, context_dim, use_linear_in_transformer + time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict + return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack return None def detect_unet_config(state_dict, key_prefix, dtype): @@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): context_dim = None use_linear_in_transformer = False + video_model = False current_res = 1 count = 0 @@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): if context_dim is None: context_dim = out[1] use_linear_in_transformer = out[2] + video_model = out[3] else: transformer_depth.append(0) @@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype): unet_config["transformer_depth_middle"] = transformer_depth_middle unet_config['use_linear_in_transformer'] = use_linear_in_transformer unet_config["context_dim"] = context_dim + + if video_model: + unet_config["extra_ff_mix_layer"] = True + unet_config["use_spatial_context"] = True + unet_config["merge_strategy"] = "learned_with_images" + unet_config["merge_factor"] = 0.0 + unet_config["video_kernel_size"] = [3, 1, 1] + unet_config["use_temporal_resblock"] = True + unet_config["use_temporal_attention"] = True + else: + unet_config["use_temporal_resblock"] = False + unet_config["use_temporal_attention"] = False + return unet_config def model_config_from_unet_config(unet_config): diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 9e2a1c1af..fac5c995e 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -1,7 +1,7 @@ import torch import numpy as np from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule - +import math class EPS: def calculate_input(self, sigma, noise): @@ -83,3 +83,47 @@ class ModelSamplingDiscrete(torch.nn.Module): percent = 1.0 - percent return self.sigma(torch.tensor(percent * 999.0)).item() + +class ModelSamplingContinuousEDM(torch.nn.Module): + def __init__(self, model_config=None): + super().__init__() + self.sigma_data = 1.0 + + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + sigma_min = sampling_settings.get("sigma_min", 0.002) + sigma_max = sampling_settings.get("sigma_max", 120.0) + self.set_sigma_range(sigma_min, sigma_max) + + def set_sigma_range(self, sigma_min, sigma_max): + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() + + self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return 0.25 * sigma.log() + + def sigma(self, timestep): + return (timestep / 0.25).exp() + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + percent = 1.0 - percent + + log_sigma_min = math.log(self.sigma_min) + return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) diff --git a/comfy/sd.py b/comfy/sd.py index a8df3bdd4..7f85540c4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -159,7 +159,15 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) if config is None: - if "taesd_decoder.1.weight" in sd: + if "decoder.mid.block_1.mix_factor" in sd: + encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + decoder_config = encoder_config.copy() + decoder_config["video_kernel_size"] = [3, 1, 1] + decoder_config["alpha"] = 0.0 + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, + decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) + elif "taesd_decoder.1.weight" in sd: self.first_stage_model = comfy.taesd.taesd.TAESD() else: #default SD1.x/SD2.x VAE parameters diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fdd4ea4f5..7e2ac677d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE): "model_channels": 320, "use_linear_in_transformer": False, "adm_in_channels": None, + "use_temporal_attention": False, } unet_extra_config = { @@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": None, + "use_temporal_attention": False, } latent_format = latent_formats.SD15 @@ -88,6 +90,7 @@ class SD21UnclipL(SD20): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": 1536, + "use_temporal_attention": False, } clip_vision_prefix = "embedder.model.visual." @@ -100,6 +103,7 @@ class SD21UnclipH(SD20): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": 2048, + "use_temporal_attention": False, } clip_vision_prefix = "embedder.model.visual." @@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE): "context_dim": 1280, "adm_in_channels": 2560, "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0], + "use_temporal_attention": False, } latent_format = latent_formats.SDXL @@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE): "use_linear_in_transformer": True, "transformer_depth": [0, 0, 2, 2, 10, 10], "context_dim": 2048, - "adm_in_channels": 2816 + "adm_in_channels": 2816, + "use_temporal_attention": False, } latent_format = latent_formats.SDXL @@ -203,8 +209,34 @@ class SSD1B(SDXL): "use_linear_in_transformer": True, "transformer_depth": [0, 0, 2, 2, 4, 4], "context_dim": 2048, - "adm_in_channels": 2816 + "adm_in_channels": 2816, + "use_temporal_attention": False, } +class SVD_img2vid(supported_models_base.BASE): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 768, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual." + + latent_format = latent_formats.SD15 + + sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002} + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SVD_img2vid(self, device=device) + return out + + def clip_target(self): + return None models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] +models += [SVD_img2vid] diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 0f4ddd9c3..6991c9837 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -128,6 +128,36 @@ class ModelSamplingDiscrete: m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingContinuousEDM: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["v_prediction", "eps"],), + "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, sigma_max, sigma_min): + m = model.clone() + + if sampling == "eps": + sampling_type = comfy.model_sampling.EPS + elif sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced() + model_sampling.set_sigma_range(sigma_min, sigma_max) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class RescaleCFG: @classmethod def INPUT_TYPES(s): @@ -169,5 +199,6 @@ class RescaleCFG: NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, + "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, "RescaleCFG": RescaleCFG, } From 42dfae63312f443d13841a0c4a5de467f5c354c9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 19:43:09 -0500 Subject: [PATCH 22/32] Nodes to properly use the SDV img2vid checkpoint. The img2vid model is conditioned on clip vision output only which means there's no CLIP model which is why I added a ImageOnlyCheckpointLoader to load it. Note that the unClipCheckpointLoader can also load it because it also has a CLIP_VISION output. SDV_img2vid_Conditioning is the node used to pass the right conditioning to the img2vid model. VideoLinearCFGGuidance applies a linearly decreasing CFG scale to each video frame from the cfg set in the sampler node to min_cfg. SDV_img2vid_Conditioning can be found in conditioning->video_models ImageOnlyCheckpointLoader can be found in loaders->video_models VideoLinearCFGGuidance can be found in sampling->video_models --- comfy_extras/nodes_video_model.py | 89 +++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 90 insertions(+) create mode 100644 comfy_extras/nodes_video_model.py diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py new file mode 100644 index 000000000..92bd883ae --- /dev/null +++ b/comfy_extras/nodes_video_model.py @@ -0,0 +1,89 @@ +import nodes +import torch +import comfy.utils +import comfy.sd +import folder_paths + + +class ImageOnlyCheckpointLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + }} + RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") + FUNCTION = "load_checkpoint" + + CATEGORY = "loaders/video_models" + + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (out[0], out[3], out[2]) + + +class SDV_img2vid_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}), + "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}), + "fps": ("INT", {"default": 6, "min": 1, "max": 1024}), + "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}) + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + if augmentation_level > 0: + encode_pixels += torch.randn_like(pixels) * augmentation_level + t = vae.encode(encode_pixels) + positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + +class VideoLinearCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + + scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1)) + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, + "SDV_img2vid_Conditioning": SDV_img2vid_Conditioning, + "VideoLinearCFGGuidance": VideoLinearCFGGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)", +} diff --git a/nodes.py b/nodes.py index 2de468da7..bb24bc6e8 100644 --- a/nodes.py +++ b/nodes.py @@ -1850,6 +1850,7 @@ def init_custom_nodes(): "nodes_model_advanced.py", "nodes_model_downscale.py", "nodes_images.py", + "nodes_video_model.py", ] for node_file in extras_files: From 02ffbb2de3e33d9d64d38c13e70e860d9af90101 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 23:20:07 -0500 Subject: [PATCH 23/32] Fix typo. --- comfy_extras/nodes_video_model.py | 4 ++-- web/scripts/app.js | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index 92bd883ae..26a717a38 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -21,7 +21,7 @@ class ImageOnlyCheckpointLoader: return (out[0], out[3], out[2]) -class SDV_img2vid_Conditioning: +class SVD_img2vid_Conditioning: @classmethod def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), @@ -80,7 +80,7 @@ class VideoLinearCFGGuidance: NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, - "SDV_img2vid_Conditioning": SDV_img2vid_Conditioning, + "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, } diff --git a/web/scripts/app.js b/web/scripts/app.js index 180416ef9..cd20c40fd 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1523,6 +1523,7 @@ export class ComfyApp { // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix + if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix // Find missing node types if (!(n.type in LiteGraph.registered_node_types)) { From c782cf3ea95021b0d9fa95014b13e7c32f20fd6e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 00:27:08 -0500 Subject: [PATCH 24/32] Add to Readme that Stable Video Diffusion is supported. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f87c0404f..9d7e31790 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. -- Fully supports SD1.x, SD2.x and SDXL +- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) From 982338b9bb41301000ddac46d67103af9d0582cd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 02:08:08 -0500 Subject: [PATCH 25/32] Fix issue loading webp files in UI. --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 6f01aa5b2..8a58d30b3 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -599,7 +599,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png,.latent,.safetensors", + accept: ".json,image/png,.latent,.safetensors,image/webp", style: {display: "none"}, parent: document.body, onchange: () => { From 3e5ea74ad356e849ea27f1d766a7b6d90a5acfda Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 03:55:35 -0500 Subject: [PATCH 26/32] Make buggy xformers fall back on pytorch attention. --- comfy/ldm/modules/attention.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 947e2008c..d511dda16 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -278,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None): ) return r1 +BROKEN_XFORMERS = False +try: + x_vers = xformers.__version__ + #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error) + BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23") +except: + pass + def attention_xformers(q, k, v, heads, mask=None): b, _, dim_head = q.shape dim_head //= heads + if BROKEN_XFORMERS: + if b * heads > 65535: + return attention_pytorch(q, k, v, heads, mask) q, k, v = map( lambda t: t.unsqueeze(3) From eff24ea6aa4f53870f575ec34371b7db940c1cfc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 11:12:10 -0500 Subject: [PATCH 27/32] Add a node to save animated PNG files. These work in ffpmeg unlike webp. --- comfy_extras/nodes_images.py | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 8c6ae5387..450c8dc40 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -3,6 +3,8 @@ import folder_paths from comfy.cli_args import args from PIL import Image +from PIL.PngImagePlugin import PngInfo + import numpy as np import json import os @@ -112,8 +114,62 @@ class SaveAnimatedWEBP: animated = num_frames != 1 return { "ui": { "images": results, "animated": (animated,) } } +class SaveAnimatedPNG: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "fps": ("FLOAT", {"default": 12.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + pil_images = [] + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pil_images.append(img) + + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) + + file = f"{filename}_{counter:05}_.png" + pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + + return { "ui": { "images": results, "animated": (True,)} } + NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, "SaveAnimatedWEBP": SaveAnimatedWEBP, + "SaveAnimatedPNG": SaveAnimatedPNG, } From 916e9c998c5952a30e7795ccfda74186a82a2a06 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 11:19:23 -0500 Subject: [PATCH 28/32] Use same default fps as webp node. --- comfy_extras/nodes_images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 450c8dc40..4c86b2df6 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -125,7 +125,7 @@ class SaveAnimatedPNG: return {"required": {"images": ("IMAGE", ), "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 12.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) }, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, From 8ad5d494d52883e02f5745603dfd06f1a49c040b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 18:14:17 -0500 Subject: [PATCH 29/32] Fix APNG not working in ffmpeg. --- comfy_extras/nodes_images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 4c86b2df6..4b6cd3d1b 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -152,10 +152,10 @@ class SaveAnimatedPNG: if not args.disable_metadata: metadata = PngInfo() if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) + metadata.add(b"tEXt", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) if extra_pnginfo is not None: for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + metadata.add(b"tEXt", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) file = f"{filename}_{counter:05}_.png" pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) From e020ab61f97fd8bccc31e7eebd23acd5dd9e2ecd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 18:24:19 -0500 Subject: [PATCH 30/32] Fix output APNG not working with ffmpeg. --- comfy_extras/nodes_images.py | 4 ++-- web/scripts/pnginfo.js | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 4b6cd3d1b..5ad2235a5 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -152,10 +152,10 @@ class SaveAnimatedPNG: if not args.disable_metadata: metadata = PngInfo() if prompt is not None: - metadata.add(b"tEXt", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) + metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) if extra_pnginfo is not None: for x in extra_pnginfo: - metadata.add(b"tEXt", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) + metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) file = f"{filename}_{counter:05}_.png" pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index f8cbe7a3c..83a4ebc86 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -24,7 +24,7 @@ export function getPngMetadata(file) { const length = dataView.getUint32(offset); // Get the chunk type const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8)); - if (type === "tEXt") { + if (type === "tEXt" || type == "comf") { // Get the keyword let keyword_end = offset + 8; while (pngData[keyword_end] !== 0) { From 5d6dfce5481f67bcfb30b1b39ad6eb78022653af Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 20:35:29 -0500 Subject: [PATCH 31/32] Fix importing diffusers unets. --- comfy/model_detection.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 45d603a0c..c682c3e1a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -232,52 +232,62 @@ def unet_config_from_diffusers_unet(state_dict, dtype): SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, - 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10]} + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4, - 'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0]} + 'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, - 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} + 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, - 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} + 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, - 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} + 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8, - 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} + 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1, - 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1]} + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0, - 'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0]} + 'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, - 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10]} + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4], - 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64} + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] From 5b37270d3ad2227a30e15101a8d528ca77bd589d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 25 Nov 2023 02:26:50 -0500 Subject: [PATCH 32/32] Add a lora loader node for models with no CLIP. --- nodes.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/nodes.py b/nodes.py index bb24bc6e8..df40f8094 100644 --- a/nodes.py +++ b/nodes.py @@ -572,6 +572,19 @@ class LoraLoader: model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) return (model_lora, clip_lora) +class LoraLoaderModelOnly(LoraLoader): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + class VAELoader: @staticmethod def vae_list(): @@ -1703,6 +1716,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningZeroOut": ConditioningZeroOut, "ConditioningSetTimestepRange": ConditioningSetTimestepRange, + "LoraLoaderModelOnly": LoraLoaderModelOnly, } NODE_DISPLAY_NAME_MAPPINGS = {