From ae08fdb9990956f671d658aaf72a1eaf982b5b33 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 9 May 2023 03:37:36 +0900 Subject: [PATCH 001/120] Clipspace Menu and MaskEditor application. (#548) * Add clipspace feature. * feat: copy content to clipspace * feat: paste content from clipspace Extend validation to allow for validating annotated_path in addition to other parameters. Add support for annotated_filepath in folder_paths function. Generalize the '/upload/image' API to allow for uploading images to the 'input', 'temp', or 'output' directories. * rename contentClipboard -> clipspace * Do deep copy for imgs on copy to clipspace. * mask painting on clipspace * add original_imgs into clipspace * Preserve the original image when 'imgs' are modified * robust patch & refactoring folder_paths about annotated_filepath * wip * Only show the Paste menu if the ComfyApp.clipspace is not empty * clipspace feature added maskeditor feature added * instant refresh on paste force triggering 'changed' on paste action * enhance mask painting smooth drawing add brush_size +/- button * robust patch use mouseup event * robust patch again... * subfolder fix on paste logic attach subfolder if subfolder isn't empty * event listener patch add ], [ key event for brush size remove listener on close * Fix button positioning issue related to window height. Change brush size from button to slider. * clean commit * clean code * various bug fixes * paste action - prevent opening upload popup - ensure rendering after widget_value update * view api update - support annotated_filepath * maskeditor layout - prevent covering button by hidden div * remove dbg message * Add cursor functionality to display brush size * refactor: Replace brush preview feature with missionfloyd implementation * missionfloyd implementation * hiding brush preview off the canvas * change brush size on wheel event * keyup -> keydown event * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * Add support for channel-specific image data retrieval in /view API to fix alpha mask loading issue When loading an image with an alpha mask in JavaScript canvas, there is an issue where the alpha and RGB channels are premultiplied. To avoid reliance on JavaScript canvas, I added support for channel-specific image data retrieval in the "/view" API. This allows us to retrieve data for each channel separately and fix the alpha mask loading issue. The changes have been committed to the repository. 
* Enable brush preview for key and slider events * optimize * preview fix * robust patch * fix copy (clipspace) action imgs[0] copy -> whole imgs copy * support batch images on clipspace, maskeditor * copy/paste bug fixes for batch images enhance selector preview on clipspace menu add img_paste_mode option into clipspace menu * crash fix * print message if clipspace content cannot editable * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * make default img_paste_mode to 'selected' refactor space -> tab * save clipspace files to input/clipspace instead of temp * show "clipspace/filename.png" instead of 'filename.png [clipspace]' in LoadImage/LoadImageMask * refresh fix related to FILE_COMBO * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * adjust margin based on missionfloyd impelements * mouse event -> pointer event * pen, touch, mouse drawing patched and tested * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * add comment about touch event. --------- Co-authored-by: Lt.Dr.Data Co-authored-by: missionfloyd --- folder_paths.py | 9 + nodes.py | 8 +- server.py | 122 ++++++- web/extensions/core/clipspace.js | 166 +++++++++ web/extensions/core/maskeditor.js | 589 ++++++++++++++++++++++++++++++ web/scripts/app.js | 114 ++++-- web/scripts/ui.js | 1 + web/scripts/widgets.js | 14 + 8 files changed, 976 insertions(+), 47 deletions(-) create mode 100644 web/extensions/core/clipspace.js create mode 100644 web/extensions/core/maskeditor.js diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..0acd22674 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -57,6 +57,10 @@ def get_input_directory(): global input_directory return input_directory +def get_clipspace_directory(): + global input_directory + return input_directory+"/clipspace" + #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): @@ -66,6 +70,8 @@ def get_directory_by_type(type_name): return get_temp_directory() if type_name == "input": return get_input_directory() + if type_name == "clipspace": + return get_clipspace_directory() return None @@ -81,6 +87,9 @@ def annotated_filepath(name): elif name.endswith("[temp]"): base_dir = get_temp_directory() name = name[:-7] + elif name.endswith("[clipspace]"): + base_dir = get_clipspace_directory() + name = name[:-12] else: return name, None diff --git a/nodes.py b/nodes.py index ca0769ba7..1d9a5c872 100644 --- a/nodes.py +++ b/nodes.py @@ -973,8 +973,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), )}, + {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, )}, } CATEGORY = "image" @@ -1014,9 +1015,10 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), ), - "channel": (s._color_channels, ),} + {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, ), + "channel": (s._color_channels, ), } } CATEGORY = "mask" diff --git a/server.py b/server.py index 1c5c17916..48644d83a 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,9 @@ 
import execution import uuid import json import glob +from PIL import Image +from io import BytesIO + try: import aiohttp from aiohttp import web @@ -110,19 +113,26 @@ class PromptServer(): files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) + def get_dir_by_type(dir_type): + if dir_type is None: + type_dir = folder_paths.get_input_directory() + elif dir_type == "input": + type_dir = folder_paths.get_input_directory() + elif dir_type == "clipspace": + type_dir = folder_paths.get_clipspace_directory() + elif dir_type == "temp": + type_dir = folder_paths.get_temp_directory() + elif dir_type == "output": + type_dir = folder_paths.get_output_directory() + + return type_dir + @routes.post("/upload/image") async def upload_image(request): post = await request.post() image = post.get("image") - if post.get("type") is None: - upload_dir = folder_paths.get_input_directory() - elif post.get("type") == "input": - upload_dir = folder_paths.get_input_directory() - elif post.get("type") == "temp": - upload_dir = folder_paths.get_temp_directory() - elif post.get("type") == "output": - upload_dir = folder_paths.get_output_directory() + upload_dir = get_dir_by_type(post.get("type")) if not os.path.exists(upload_dir): os.makedirs(upload_dir) @@ -147,12 +157,62 @@ class PromptServer(): else: return web.Response(status=400) + @routes.post("/upload/mask") + async def upload_mask(request): + post = await request.post() + image = post.get("image") + original_image = post.get("original_image") + + upload_dir = get_dir_by_type(post.get("type")) + + if not os.path.exists(upload_dir): + os.makedirs(upload_dir) + + if image and image.file: + filename = image.filename + if not filename: + return web.Response(status=400) + + split = os.path.splitext(filename) + i = 1 + while os.path.exists(os.path.join(upload_dir, filename)): + filename = f"{split[0]} ({i}){split[1]}" + i += 1 + + filepath = os.path.join(upload_dir, filename) + + original_pil = Image.open(original_image.file).convert('RGBA') + mask_pil = Image.open(image.file).convert('RGBA') + + # alpha copy + new_alpha = mask_pil.getchannel('A') + original_pil.putalpha(new_alpha) + + original_pil.save(filepath) + + return web.json_response({"name": filename}) + else: + return web.Response(status=400) + @routes.get("/view") async def view_image(request): if "filename" in request.rel_url.query: - type = request.rel_url.query.get("type", "output") - output_dir = folder_paths.get_directory_by_type(type) + filename = request.rel_url.query["filename"] + filename,output_dir = folder_paths.annotated_filepath(filename) + + if request.rel_url.query.get("type", "input") and filename.startswith("clipspace/"): + output_dir = folder_paths.get_clipspace_directory() + filename = filename[10:] + + # validation for security: prevent accessing arbitrary path + if filename[0] == '/' or '..' 
in filename: + return web.Response(status=400) + + if output_dir is None: + type = request.rel_url.query.get("type", "output") + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) @@ -162,13 +222,49 @@ class PromptServer(): return web.Response(status=403) output_dir = full_output_dir - filename = request.rel_url.query["filename"] filename = os.path.basename(filename) file = os.path.join(output_dir, filename) if os.path.isfile(file): - return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) - + if 'channel' not in request.rel_url.query: + channel = 'rgba' + else: + channel = request.rel_url.query["channel"] + + if channel == 'rgb': + with Image.open(file) as img: + if img.mode == "RGBA": + r, g, b, a = img.split() + new_img = Image.merge('RGB', (r, g, b)) + else: + new_img = img.convert("RGB") + + buffer = BytesIO() + new_img.save(buffer, format='PNG') + buffer.seek(0) + + return web.Response(body=buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + + elif channel == 'a': + with Image.open(file) as img: + if img.mode == "RGBA": + _, _, _, a = img.split() + else: + a = Image.new('L', img.size, 255) + + # alpha img + alpha_img = Image.new('RGBA', img.size) + alpha_img.putalpha(a) + alpha_buffer = BytesIO() + alpha_img.save(alpha_buffer, format='PNG') + alpha_buffer.seek(0) + + return web.Response(body=alpha_buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + else: + return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) + return web.Response(status=404) @routes.get("/prompt") diff --git a/web/extensions/core/clipspace.js b/web/extensions/core/clipspace.js new file mode 100644 index 000000000..adb5877ea --- /dev/null +++ b/web/extensions/core/clipspace.js @@ -0,0 +1,166 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; + +export class ClipspaceDialog extends ComfyDialog { + static items = []; + static instance = null; + + static registerButton(name, contextPredicate, callback) { + const item = + $el("button", { + type: "button", + textContent: name, + contextPredicate: contextPredicate, + onclick: callback + }) + + ClipspaceDialog.items.push(item); + } + + static invalidatePreview() { + if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) { + const img_preview = document.getElementById("clipspace_preview"); + if(img_preview) { + img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + img_preview.style.maxHeight = "100%"; + img_preview.style.maxWidth = "100%"; + } + } + } + + static invalidate() { + if(ClipspaceDialog.instance) { + const self = ClipspaceDialog.instance; + // allow reconstruct controls when copying from non-image to image content. 
+ const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]); + + if(self.element) { + // update + self.element.removeChild(self.element.firstChild); + self.element.appendChild(children); + } + else { + // new + self.element = $el("div.comfy-modal", { parent: document.body }, [children,]); + } + + if(self.element.children[0].children.length <= 1) { + self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."])); + } + + ClipspaceDialog.invalidatePreview(); + } + } + + constructor() { + super(); + } + + createButtons(self) { + const buttons = []; + + for(let idx in ClipspaceDialog.items) { + const item = ClipspaceDialog.items[idx]; + if(!item.contextPredicate || item.contextPredicate()) + buttons.push(ClipspaceDialog.items[idx]); + } + + buttons.push( + $el("button", { + type: "button", + textContent: "Close", + onclick: () => { this.close(); } + }) + ); + + return buttons; + } + + createImgSettings() { + if(ComfyApp.clipspace.imgs) { + const combo_items = []; + const imgs = ComfyApp.clipspace.imgs; + + for(let i=0; i < imgs.length; i++) { + combo_items.push($el("option", {value:i}, [`${i}`])); + } + + const combo1 = $el("select", + {id:"clipspace_img_selector", onchange:(event) => { + ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex; + ClipspaceDialog.invalidatePreview(); + } }, combo_items); + + const row1 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]), + $el("td", {}, [combo1]) + ]); + + + const combo2 = $el("select", + {id:"clipspace_img_paste_mode", onchange:(event) => { + ComfyApp.clipspace['img_paste_mode'] = event.target.value; + } }, + [ + $el("option", {value:'selected'}, 'selected'), + $el("option", {value:'all'}, 'all') + ]); + combo2.value = ComfyApp.clipspace['img_paste_mode']; + + const row2 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]), + $el("td", {}, [combo2]) + ]); + + const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'}, + [ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]); + + const row3 = + $el("tr", {}, [td]); + + return $el("table", {}, [row1, row2, row3]); + } + else { + return []; + } + } + + createImgPreview() { + if(ComfyApp.clipspace.imgs) { + return $el("img",{id:"clipspace_preview", ondragstart:() => false}); + } + else + return []; + } + + show() { + const img_preview = document.getElementById("clipspace_preview"); + ClipspaceDialog.invalidate(); + + this.element.style.display = "block"; + } +} + +app.registerExtension({ + name: "Comfy.Clipspace", + init(app) { + app.openClipspace = + function () { + if(!ClipspaceDialog.instance) { + ClipspaceDialog.instance = new ClipspaceDialog(app); + ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate; + } + + if(ComfyApp.clipspace) { + ClipspaceDialog.instance.show(); + } + else + app.ui.dialog.show("Clipspace is Empty!"); + }; + } +}); \ No newline at end of file diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js new file mode 100644 index 000000000..c55f841b6 --- /dev/null +++ b/web/extensions/core/maskeditor.js @@ -0,0 +1,589 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; +import { ClipspaceDialog } from "/extensions/core/clipspace.js"; + +// Helper function to convert a data URL to a Blob object +function 
dataURLToBlob(dataURL) { + const parts = dataURL.split(';base64,'); + const contentType = parts[0].split(':')[1]; + const byteString = atob(parts[1]); + const arrayBuffer = new ArrayBuffer(byteString.length); + const uint8Array = new Uint8Array(arrayBuffer); + for (let i = 0; i < byteString.length; i++) { + uint8Array[i] = byteString.charCodeAt(i); + } + return new Blob([arrayBuffer], { type: contentType }); +} + +function loadedImageToBlob(image) { + const canvas = document.createElement('canvas'); + + canvas.width = image.width; + canvas.height = image.height; + + const ctx = canvas.getContext('2d'); + + ctx.drawImage(image, 0, 0); + + const dataURL = canvas.toDataURL('image/png', 1); + const blob = dataURLToBlob(dataURL); + + return blob; +} + +async function uploadMask(filepath, formData) { + await fetch('/upload/mask', { + method: 'POST', + body: formData + }).then(response => {}).catch(error => { + console.error('Error:', error); + }); + + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = `view?filename=${filepath.filename}&type=${filepath.type}`; + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; + + ClipspaceDialog.invalidatePreview(); +} + +function prepareRGB(image, backupCanvas, backupCtx) { + // paste mask data into alpha channel + backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); + const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); +} + +class MaskEditorDialog extends ComfyDialog { + static instance = null; + constructor() { + super(); + this.element = $el("div.comfy-modal", { parent: document.body }, + [ $el("div.comfy-modal-content", + [...this.createButtons()]), + ]); + MaskEditorDialog.instance = this; + } + + createButtons() { + return []; + } + + clearMask(self) { + } + + createButton(name, callback) { + var button = document.createElement("button"); + button.innerText = name; + button.addEventListener("click", callback); + return button; + } + createLeftButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "left"; + button.style.marginRight = "4px"; + return button; + } + createRightButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "right"; + button.style.marginLeft = "4px"; + return button; + } + createLeftSlider(self, name, callback) { + const divElement = document.createElement('div'); + divElement.id = "maskeditor-slider"; + divElement.style.cssFloat = "left"; + divElement.style.fontFamily = "sans-serif"; + divElement.style.marginRight = "4px"; + divElement.style.color = "var(--input-text)"; + divElement.style.backgroundColor = "var(--comfy-input-bg)"; + divElement.style.borderRadius = "8px"; + divElement.style.borderColor = "var(--border-color)"; + divElement.style.borderStyle = "solid"; + divElement.style.fontSize = "15px"; + divElement.style.height = "21px"; + divElement.style.padding = "1px 6px"; + divElement.style.display = "flex"; + divElement.style.position = "relative"; + 
divElement.style.top = "2px"; + self.brush_slider_input = document.createElement('input'); + self.brush_slider_input.setAttribute('type', 'range'); + self.brush_slider_input.setAttribute('min', '1'); + self.brush_slider_input.setAttribute('max', '100'); + self.brush_slider_input.setAttribute('value', '10'); + const labelElement = document.createElement("label"); + labelElement.textContent = name; + + divElement.appendChild(labelElement); + divElement.appendChild(self.brush_slider_input); + + self.brush_slider_input.addEventListener("change", callback); + + return divElement; + } + + setlayout(imgCanvas, maskCanvas) { + const self = this; + + // If it is specified as relative, using it only as a hidden placeholder for padding is recommended + // to prevent anomalies where it exceeds a certain size and goes outside of the window. + var placeholder = document.createElement("div"); + placeholder.style.position = "relative"; + placeholder.style.height = "50px"; + + var bottom_panel = document.createElement("div"); + bottom_panel.style.position = "absolute"; + bottom_panel.style.bottom = "0px"; + bottom_panel.style.left = "20px"; + bottom_panel.style.right = "20px"; + bottom_panel.style.height = "50px"; + + var brush = document.createElement("div"); + brush.id = "brush"; + brush.style.backgroundColor = "transparent"; + brush.style.outline = "1px dashed black"; + brush.style.boxShadow = "0 0 0 1px white"; + brush.style.borderRadius = "50%"; + brush.style.MozBorderRadius = "50%"; + brush.style.WebkitBorderRadius = "50%"; + brush.style.position = "absolute"; + brush.style.zIndex = 100; + brush.style.pointerEvents = "none"; + this.brush = brush; + this.element.appendChild(imgCanvas); + this.element.appendChild(maskCanvas); + this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button + this.element.appendChild(bottom_panel); + document.body.appendChild(brush); + + var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { + self.brush_size = event.target.value; + self.updateBrushPreview(self, null, null); + }); + var clearButton = this.createLeftButton("Clear", + () => { + self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); + self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height); + }); + var cancelButton = this.createRightButton("Cancel", () => { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); + self.close(); + }); + var saveButton = this.createRightButton("Save", () => { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); + self.save(); + }); + + this.element.appendChild(imgCanvas); + this.element.appendChild(maskCanvas); + this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button + this.element.appendChild(bottom_panel); + + bottom_panel.appendChild(clearButton); + bottom_panel.appendChild(saveButton); + bottom_panel.appendChild(cancelButton); + bottom_panel.appendChild(brush_size_slider); + + this.element.style.display = "block"; + imgCanvas.style.position = "relative"; + imgCanvas.style.top = "200"; + imgCanvas.style.left = "0"; + + maskCanvas.style.position = "absolute"; + } + + show() { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = 
document.createElement('canvas'); + + imgCanvas.id = "imageCanvas"; + maskCanvas.id = "maskCanvas"; + backupCanvas.id = "backupCanvas"; + + this.setlayout(imgCanvas, maskCanvas); + + // prepare content + this.maskCanvas = maskCanvas; + this.backupCanvas = backupCanvas; + this.maskCtx = maskCanvas.getContext('2d'); + this.backupCtx = backupCanvas.getContext('2d'); + + this.setImages(imgCanvas, backupCanvas); + this.setEventHandler(maskCanvas); + } + + setImages(imgCanvas, backupCanvas) { + const imgCtx = imgCanvas.getContext('2d'); + const backupCtx = backupCanvas.getContext('2d'); + const maskCtx = this.maskCtx; + const maskCanvas = this.maskCanvas; + + // image load + const orig_image = new Image(); + window.addEventListener("resize", () => { + // repositioning + imgCanvas.width = window.innerWidth - 250; + imgCanvas.height = window.innerHeight - 200; + + // redraw image + let drawWidth = orig_image.width; + let drawHeight = orig_image.height; + if (orig_image.width > imgCanvas.width) { + drawWidth = imgCanvas.width; + drawHeight = (drawWidth / orig_image.width) * orig_image.height; + } + + if (drawHeight > imgCanvas.height) { + drawHeight = imgCanvas.height; + drawWidth = (drawHeight / orig_image.height) * orig_image.width; + } + + imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); + + // update mask + backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); + maskCanvas.width = drawWidth; + maskCanvas.height = drawHeight; + maskCanvas.style.top = imgCanvas.offsetTop + "px"; + maskCanvas.style.left = imgCanvas.offsetLeft + "px"; + maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); + }); + + const filepath = ComfyApp.clipspace.images; + + const touched_image = new Image(); + + touched_image.onload = function() { + backupCanvas.width = touched_image.width; + backupCanvas.height = touched_image.height; + + prepareRGB(touched_image, backupCanvas, backupCtx); + }; + + const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) + alpha_url.searchParams.delete('channel'); + alpha_url.searchParams.set('channel', 'a'); + touched_image.src = alpha_url; + + // original image load + orig_image.onload = function() { + window.dispatchEvent(new Event('resize')); + }; + + const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); + rgb_url.searchParams.delete('channel'); + rgb_url.searchParams.set('channel', 'rgb'); + orig_image.src = rgb_url; + this.image = orig_image; + }g + + + setEventHandler(maskCanvas) { + maskCanvas.addEventListener("contextmenu", (event) => { + event.preventDefault(); + }); + + const self = this; + maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); + maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); + document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); + maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); + maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); + document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); + } + + brush_size = 10; + drawing_mode = false; + lastx = -1; + lasty = -1; + lasttime = 0; + + 
static handleKeyDown(event) { + const self = MaskEditorDialog.instance; + if (event.key === ']') { + self.brush_size = Math.min(self.brush_size+2, 100); + } else if (event.key === '[') { + self.brush_size = Math.max(self.brush_size-2, 1); + } + + self.updateBrushPreview(self); + } + + static handlePointerUp(event) { + event.preventDefault(); + MaskEditorDialog.instance.drawing_mode = false; + } + + updateBrushPreview(self) { + const brush = self.brush; + + var centerX = self.cursorX; + var centerY = self.cursorY; + + brush.style.width = self.brush_size * 2 + "px"; + brush.style.height = self.brush_size * 2 + "px"; + brush.style.left = (centerX - self.brush_size) + "px"; + brush.style.top = (centerY - self.brush_size) + "px"; + } + + handleWheelEvent(self, event) { + if(event.deltaY < 0) + self.brush_size = Math.min(self.brush_size+2, 100); + else + self.brush_size = Math.max(self.brush_size-2, 1); + + self.brush_slider_input.value = self.brush_size; + + self.updateBrushPreview(self); + } + + draw_move(self, event) { + event.preventDefault(); + + this.cursorX = event.pageX; + this.cursorY = event.pageY; + + self.updateBrushPreview(self); + + if (event instanceof TouchEvent || event.buttons == 1) { + var diff = performance.now() - self.lasttime; + + const maskRect = self.maskCanvas.getBoundingClientRect(); + + var x = event.offsetX; + var y = event.offsetY + + if(event.offsetX == null) { + x = event.targetTouches[0].clientX - maskRect.left; + } + + if(event.offsetY == null) { + y = event.targetTouches[0].clientY - maskRect.top; + } + + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + else if(event instanceof TouchEvent && diff < 20){ + // The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents. 
+ brush_size *= this.last_pressure; + } + else { + brush_size = this.brush_size; + } + + if(diff > 20 && !this.drawing_mode) + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + }); + else + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + + var dx = x - self.lastx; + var dy = y - self.lasty; + + var distance = Math.sqrt(dx * dx + dy * dy); + var directionX = dx / distance; + var directionY = dy / distance; + + for (var i = 0; i < distance; i+=5) { + var px = self.lastx + (directionX * i); + var py = self.lasty + (directionY * i); + self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + } + self.lastx = x; + self.lasty = y; + }); + + self.lasttime = performance.now(); + } + else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { + const maskRect = self.maskCanvas.getBoundingClientRect(); + const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; + const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + else if(event instanceof TouchEvent && diff < 20){ + brush_size *= this.last_pressure; + } + else { + brush_size = this.brush_size; + } + + if(diff > 20 && !drawing_mode) // cannot tracking drawing_mode for touch event + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.globalCompositeOperation = "destination-out"; + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + }); + else + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.globalCompositeOperation = "destination-out"; + + var dx = x - self.lastx; + var dy = y - self.lasty; + + var distance = Math.sqrt(dx * dx + dy * dy); + var directionX = dx / distance; + var directionY = dy / distance; + + for (var i = 0; i < distance; i+=5) { + var px = self.lastx + (directionX * i); + var py = self.lasty + (directionY * i); + self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + } + self.lastx = x; + self.lasty = y; + }); + + self.lasttime = performance.now(); + } + } + + handlePointerDown(self, event) { + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + + if ([0, 2, 5].includes(event.button)) { + self.drawing_mode = true; + + event.preventDefault(); + const maskRect = self.maskCanvas.getBoundingClientRect(); + const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; + const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + + self.maskCtx.beginPath(); + if (event.button == 0) { + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + } else { + self.maskCtx.globalCompositeOperation = "destination-out"; + } + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + self.lasttime = performance.now(); + } + } + + save() { + const backupCtx = 
this.backupCanvas.getContext('2d', {willReadFrequently:true}); + + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + backupCtx.drawImage(this.maskCanvas, + 0, 0, this.maskCanvas.width, this.maskCanvas.height, + 0, 0, this.backupCanvas.width, this.backupCanvas.height); + + // paste mask data into alpha channel + const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); + + const formData = new FormData(); + const filename = "clipspace-mask-" + performance.now() + ".png"; + + const item = + { + "filename": filename, + "subfolder": "", + "type": "clipspace", + }; + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[0] = item; + + if(ComfyApp.clipspace.widgets) { + const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); + + if(index >= 0) + ComfyApp.clipspace.widgets[index].value = item; + } + + const dataURL = this.backupCanvas.toDataURL(); + const blob = dataURLToBlob(dataURL); + + const original_blob = loadedImageToBlob(this.image); + + formData.append('image', blob, filename); + formData.append('original_image', original_blob); + formData.append('type', "clipspace"); + + uploadMask(item, formData); + this.close(); + } +} + +app.registerExtension({ + name: "Comfy.MaskEditor", + init(app) { + const callback = + function () { + let dlg = new MaskEditorDialog(app); + dlg.show(); + }; + + const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 + ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback); + } +}); \ No newline at end of file diff --git a/web/scripts/app.js b/web/scripts/app.js index 245605484..f4f7272db 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -25,6 +25,7 @@ export class ComfyApp { * @type {serialized node object} */ static clipspace = null; + static clipspace_invalidate_handler = null; constructor() { this.ui = new ComfyUI(this); @@ -143,22 +144,34 @@ export class ComfyApp { callback: (obj) => { var widgets = null; if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); + widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); } - let img = new Image(); var imgs = undefined; + var orig_imgs = undefined; if(this.imgs != undefined) { - img.src = this.imgs[0].src; - imgs = [img]; + imgs = []; + orig_imgs = []; + + for (let i = 0; i < this.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = this.imgs[i].src; + orig_imgs[i] = imgs[i]; + } } ComfyApp.clipspace = { 'widgets': widgets, 'imgs': imgs, - 'original_imgs': imgs, - 'images': this.images + 'original_imgs': orig_imgs, + 'images': this.images, + 'selectedIndex': 0, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action }; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } } }); @@ -167,48 +180,82 @@ export class ComfyApp { { content: "Paste (Clipspace)", callback: () => { - if(ComfyApp.clipspace != null) { - if(ComfyApp.clipspace.widgets != null && this.widgets != null) { - ComfyApp.clipspace.widgets.forEach(({ type, name, 
value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop) { - prop.callback(value); - } - }); - } - + if(ComfyApp.clipspace) { // image paste - if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) { + if(ComfyApp.clipspace.imgs && this.imgs) { var filename = ""; if(this.images && ComfyApp.clipspace.images) { - this.images = ComfyApp.clipspace.images; + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + + } + else + app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images; } - if(ComfyApp.clipspace.images != undefined) { - const clip_image = ComfyApp.clipspace.images[0]; + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + this.imgs = [img]; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); if(index_in_clip >= 0) { - filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`; + const item = ComfyApp.clipspace.widgets[index_in_clip].value; + if(item.type) + filename = `${item.filename} [${item.type}]`; + else + filename = item.filename; } } - const index = this.widgets.findIndex(obj => obj.name === 'image'); - if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) { - this.imgs = ComfyApp.clipspace.imgs; + // for Load Image node. + if(this.widgets) { + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0 && filename != "") { + const postfix = ' [clipspace]'; + if(filename.endsWith(postfix) && this.widgets[index].options.base_dir == 'input') { + filename = "clipspace/" + filename.slice(0, filename.indexOf(postfix)); + } - this.widgets[index].value = filename; - if(this.widgets_values != undefined) { - this.widgets_values[index] = filename; + this.widgets[index].value = filename; + if(this.widgets_values != undefined) { + this.widgets_values[index] = filename; + } } } } - this.trigger('changed'); + + // ensure render after update widget_value + if(ComfyApp.clipspace.widgets && this.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.callback(value); + } + }); + } } + + app.graph.setDirtyCanvas(true); } } ); @@ -1275,12 +1322,17 @@ export class ComfyApp { for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] - if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { - widget.options.values = def["input"]["required"][widget.name][0]; + if(def["input"]["required"][widget.name][0] == "FILE_COMBO") { + console.log(widget.options.values = def["input"]["required"][widget.name][1].files); + widget.options.values = def["input"]["required"][widget.name][1].files; + } + else + widget.options.values = def["input"]["required"][widget.name][0]; if(!widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; + widget.callback(widget.value); } } } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 5accc9d86..77517aec1 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -581,6 +581,7 @@ export class ComfyUI { }), $el("button", { id: 
"comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }), $el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), + $el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }), $el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => { if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index cd471bc93..4a72246db 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -256,6 +256,20 @@ export const ComfyWidgets = { } return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; }, + FILE_COMBO(node, inputName, inputData) { + const base_dir = inputData[1].base_dir; + let defaultValue = inputData[1].files[0]; + + const files = [] + for(let i in inputData[1].files) { + files[i] = inputData[1].files[i]; + const postfix = ' [clipspace]'; + if(base_dir == 'input' && files[i].endsWith(postfix)) + files[i] = "clipspace/" + files[i].slice(0, files[i].indexOf(postfix)); + } + + return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { base_dir:base_dir, values: files }) }; + }, IMAGEUPLOAD(node, inputName, inputData, app) { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; From 850daf0416367ba39d10195540f5b735952f0ee7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 14:13:06 -0400 Subject: [PATCH 002/120] Masked editor changes. Add a way to upload to subfolders. Clean up code. Fix some issues. --- folder_paths.py | 9 ---- nodes.py | 8 ++-- server.py | 74 ++++++++++++------------------- web/extensions/core/maskeditor.js | 9 ++-- web/scripts/app.js | 66 ++++++++------------------- web/scripts/widgets.js | 52 +++++++++++++++------- 6 files changed, 93 insertions(+), 125 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 0acd22674..e5b89492c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -57,10 +57,6 @@ def get_input_directory(): global input_directory return input_directory -def get_clipspace_directory(): - global input_directory - return input_directory+"/clipspace" - #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): @@ -70,8 +66,6 @@ def get_directory_by_type(type_name): return get_temp_directory() if type_name == "input": return get_input_directory() - if type_name == "clipspace": - return get_clipspace_directory() return None @@ -87,9 +81,6 @@ def annotated_filepath(name): elif name.endswith("[temp]"): base_dir = get_temp_directory() name = name[:-7] - elif name.endswith("[clipspace]"): - base_dir = get_clipspace_directory() - name = name[:-12] else: return name, None diff --git a/nodes.py b/nodes.py index 1d9a5c872..699e60ae8 100644 --- a/nodes.py +++ b/nodes.py @@ -973,9 +973,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, )}, + {"image": (sorted(files), )}, } CATEGORY = "image" @@ -1015,9 +1015,9 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - input_dir = [f for f in 
os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, ), + {"image": (sorted(files), ), "channel": (s._color_channels, ), } } diff --git a/server.py b/server.py index 48644d83a..3d02b2f7a 100644 --- a/server.py +++ b/server.py @@ -118,8 +118,6 @@ class PromptServer(): type_dir = folder_paths.get_input_directory() elif dir_type == "input": type_dir = folder_paths.get_input_directory() - elif dir_type == "clipspace": - type_dir = folder_paths.get_clipspace_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": @@ -127,73 +125,63 @@ class PromptServer(): return type_dir - @routes.post("/upload/image") - async def upload_image(request): - post = await request.post() + def image_upload(post, image_save_function=None): image = post.get("image") - upload_dir = get_dir_by_type(post.get("type")) - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) + image_upload_type = post.get("type") + upload_dir = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename if not filename: return web.Response(status=400) + subfolder = post.get("subfolder", "") + full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder)) + + if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir: + return web.Response(status=400) + + if not os.path.exists(full_output_folder): + os.makedirs(full_output_folder) + split = os.path.splitext(filename) + filepath = os.path.join(full_output_folder, filename) + i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): + while os.path.exists(filepath): filename = f"{split[0]} ({i}){split[1]}" i += 1 - filepath = os.path.join(upload_dir, filename) + if image_save_function is not None: + image_save_function(image, post, filepath) + else: + with open(filepath, "wb") as f: + f.write(image.file.read()) - with open(filepath, "wb") as f: - f.write(image.file.read()) - - return web.json_response({"name" : filename}) + return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) else: return web.Response(status=400) + @routes.post("/upload/image") + async def upload_image(request): + post = await request.post() + return image_upload(post) + @routes.post("/upload/mask") async def upload_mask(request): post = await request.post() - image = post.get("image") - original_image = post.get("original_image") - upload_dir = get_dir_by_type(post.get("type")) - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) - - if image and image.file: - filename = image.filename - if not filename: - return web.Response(status=400) - - split = os.path.splitext(filename) - i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): - filename = f"{split[0]} ({i}){split[1]}" - i += 1 - - filepath = os.path.join(upload_dir, filename) - - original_pil = Image.open(original_image.file).convert('RGBA') + def image_save_function(image, post, filepath): + original_pil = Image.open(post.get("original_image").file).convert('RGBA') mask_pil = Image.open(image.file).convert('RGBA') # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath) - return web.json_response({"name": filename}) - else: - return web.Response(status=400) - + return image_upload(post, image_save_function) 
@routes.get("/view") async def view_image(request): @@ -201,10 +189,6 @@ class PromptServer(): filename = request.rel_url.query["filename"] filename,output_dir = folder_paths.annotated_filepath(filename) - if request.rel_url.query.get("type", "input") and filename.startswith("clipspace/"): - output_dir = folder_paths.get_clipspace_directory() - filename = filename[10:] - # validation for security: prevent accessing arbitrary path if filename[0] == '/' or '..' in filename: return web.Response(status=400) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index c55f841b6..0ffa50c69 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) { }); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); - ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = `view?filename=${filepath.filename}&type=${filepath.type}`; + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString(); if(ComfyApp.clipspace.images) ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; @@ -546,8 +546,8 @@ class MaskEditorDialog extends ComfyDialog { const item = { "filename": filename, - "subfolder": "", - "type": "clipspace", + "subfolder": "clipspace", + "type": "input", }; if(ComfyApp.clipspace.images) @@ -567,7 +567,8 @@ class MaskEditorDialog extends ComfyDialog { formData.append('image', blob, filename); formData.append('original_image', original_blob); - formData.append('type', "clipspace"); + formData.append('type', "input"); + formData.append('subfolder', "clipspace"); uploadMask(item, formData); this.close(); diff --git a/web/scripts/app.js b/web/scripts/app.js index f4f7272db..c6c29e45b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -183,7 +183,6 @@ export class ComfyApp { if(ComfyApp.clipspace) { // image paste if(ComfyApp.clipspace.imgs && this.imgs) { - var filename = ""; if(this.images && ComfyApp.clipspace.images) { if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; @@ -209,49 +208,25 @@ export class ComfyApp { } } } - - if(ComfyApp.clipspace.images) { - const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; - if(clip_image.subfolder != '') - filename = `${clip_image.subfolder}/`; - filename += `${clip_image.filename} [${clip_image.type}]`; - } - else if(ComfyApp.clipspace.widgets) { - const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); - if(index_in_clip >= 0) { - const item = ComfyApp.clipspace.widgets[index_in_clip].value; - if(item.type) - filename = `${item.filename} [${item.type}]`; - else - filename = item.filename; - } - } - - // for Load Image node. 
- if(this.widgets) { - const index = this.widgets.findIndex(obj => obj.name === 'image'); - if(index >= 0 && filename != "") { - const postfix = ' [clipspace]'; - if(filename.endsWith(postfix) && this.widgets[index].options.base_dir == 'input') { - filename = "clipspace/" + filename.slice(0, filename.indexOf(postfix)); - } - - this.widgets[index].value = filename; - if(this.widgets_values != undefined) { - this.widgets_values[index] = filename; - } - } - } } - // ensure render after update widget_value - if(ComfyApp.clipspace.widgets && this.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.callback(value); - } - }); + if(this.widgets) { + if(ComfyApp.clipspace.images) { + const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0) { + this.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } } } @@ -1323,12 +1298,7 @@ export class ComfyApp { for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { - if(def["input"]["required"][widget.name][0] == "FILE_COMBO") { - console.log(widget.options.values = def["input"]["required"][widget.name][1].files); - widget.options.values = def["input"]["required"][widget.name][1].files; - } - else - widget.options.values = def["input"]["required"][widget.name][0]; + widget.options.values = def["input"]["required"][widget.name][0]; if(!widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 4a72246db..65edc0392 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -256,20 +256,6 @@ export const ComfyWidgets = { } return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; }, - FILE_COMBO(node, inputName, inputData) { - const base_dir = inputData[1].base_dir; - let defaultValue = inputData[1].files[0]; - - const files = [] - for(let i in inputData[1].files) { - files[i] = inputData[1].files[i]; - const postfix = ' [clipspace]'; - if(base_dir == 'input' && files[i].endsWith(postfix)) - files[i] = "clipspace/" + files[i].slice(0, files[i].indexOf(postfix)); - } - - return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { base_dir:base_dir, values: files }) }; - }, IMAGEUPLOAD(node, inputName, inputData, app) { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; @@ -280,10 +266,46 @@ export const ComfyWidgets = { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; - img.src = `/view?filename=${name}&type=input`; + let folder_separator = name.lastIndexOf("/"); + let subfolder = ""; + if (folder_separator > -1) { + subfolder = name.substring(0, folder_separator); + name = name.substring(folder_separator + 1); + } + img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; node.setSizeForImage?.(); } + var default_value = imageWidget.value; + Object.defineProperty(imageWidget, "value", { + set 
: function(value) { + this._real_value = value; + }, + + get : function() { + let value = ""; + if (this._real_value) { + value = this._real_value; + } else { + return default_value; + } + + if (value.filename) { + let real_value = value; + value = ""; + if (real_value.subfolder) { + value = real_value.subfolder + "/"; + } + + value += real_value.filename; + + if(real_value.type && real_value.type !== "input") + value += ` [${real_value.type}]`; + } + return value; + } + }); + // Add our own callback to the combo widget to render an image when it changes const cb = node.callback; imageWidget.callback = function () { From a7ebd5aa1278a63f2f14852dce59b43834f6b9d3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 15:52:33 -0400 Subject: [PATCH 003/120] Fix masked editor issue with firefox on windows. --- web/extensions/core/maskeditor.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 0ffa50c69..552059e86 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -368,7 +368,7 @@ class MaskEditorDialog extends ComfyDialog { self.updateBrushPreview(self); - if (event instanceof TouchEvent || event.buttons == 1) { + if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) { var diff = performance.now() - self.lasttime; const maskRect = self.maskCanvas.getBoundingClientRect(); @@ -389,7 +389,7 @@ class MaskEditorDialog extends ComfyDialog { brush_size *= event.pressure; this.last_pressure = event.pressure; } - else if(event instanceof TouchEvent && diff < 20){ + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ // The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents. brush_size *= this.last_pressure; } @@ -442,7 +442,7 @@ class MaskEditorDialog extends ComfyDialog { brush_size *= event.pressure; this.last_pressure = event.pressure; } - else if(event instanceof TouchEvent && diff < 20){ + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ brush_size *= this.last_pressure; } else { From a8705dbfe20ba86eaac5a669c61453775c796441 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 17:05:28 -0400 Subject: [PATCH 004/120] Speed up the mask save and fix refresh replacing copied image. 
--- server.py | 2 +- web/scripts/app.js | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 3d02b2f7a..c1226f304 100644 --- a/server.py +++ b/server.py @@ -179,7 +179,7 @@ class PromptServer(): # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath) + original_pil.save(filepath, compress_level=4) return image_upload(post, image_save_function) diff --git a/web/scripts/app.js b/web/scripts/app.js index c6c29e45b..2da1b5581 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1300,7 +1300,7 @@ export class ComfyApp { if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { widget.options.values = def["input"]["required"][widget.name][0]; - if(!widget.options.values.includes(widget.value)) { + if(widget.name != 'image' && !widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; widget.callback(widget.value); } From c6e34963e412e1960f73ad357d10c2b7bd1464e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 18:15:19 -0400 Subject: [PATCH 005/120] Make t2i adapter work with any latent resolution. --- comfy/t2i_adapter/adapter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index 0221fff83..87e3d859e 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -56,7 +56,12 @@ class Downsample(nn.Module): def forward(self, x): assert x.shape[1] == self.channels - return self.op(x) + if not self.use_conv: + padding = [x.shape[2] % 2, x.shape[3] % 2] + self.op.padding = padding + + x = self.op(x) + return x class ResnetBlock(nn.Module): From d43e45ce624b82dadbe98646329d2b0fbc17edcf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 10:29:58 -0400 Subject: [PATCH 006/120] Remove print. --- nodes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nodes.py b/nodes.py index 699e60ae8..760db24e1 100644 --- a/nodes.py +++ b/nodes.py @@ -443,7 +443,6 @@ class ControlNetApply: def apply_controlnet(self, conditioning, control_net, image, strength): c = [] control_hint = image.movedim(-1,1) - print(control_hint.shape) for t in conditioning: n = [t[0], t[1].copy()] c_net = control_net.copy().set_cond_hint(control_hint, strength) From 314e526c5ce428a3717207c5c36a42a5c895b6a5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 12:18:18 -0400 Subject: [PATCH 007/120] Not needed anymore because sampling works with any latent size. --- comfy/samplers.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index dcf93cca2..6417f2ed4 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -362,19 +362,8 @@ def resolve_cond_masks(conditions, h, w, device): else: box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) - # Make sure the height and width are divisible by 8 - if X % 8 != 0: - newx = X // 8 * 8 - W = W + (X - newx) - X = newx - if Y % 8 != 0: - newy = Y // 8 * 8 - H = H + (Y - newy) - Y = newy - if H % 8 != 0: - H = H + (8 - (H % 8)) - if W % 8 != 0: - W = W + (8 - (W % 8)) + H = max(8, H) + W = max(8, W) area = (int(H), int(W), int(Y), int(X)) modified['area'] = area From 02ca1c67f87e46e926aba325e73b2845d5244874 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 23:51:52 -0400 Subject: [PATCH 008/120] Don't print traceback when processing interrupted. 
--- execution.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index c19c10bc6..edf884611 100644 --- a/execution.py +++ b/execution.py @@ -194,7 +194,10 @@ class PromptExecutor: if valid: recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: - print(traceback.format_exc()) + if isinstance(e, comfy.model_management.InterruptProcessingException): + print("Processing interrupted") + else: + print(traceback.format_exc()) to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): From d6dee8af1df5e7dc80463b9e45bdce76767e4119 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 00:29:31 -0400 Subject: [PATCH 009/120] Only validate each input once. --- execution.py | 40 ++++++++++++++++++---------------------- main.py | 2 +- server.py | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/execution.py b/execution.py index edf884611..3953fde3a 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}): + def execute(self, prompt, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -172,27 +172,15 @@ class PromptExecutor: executed = set() try: to_execute = [] - for x in prompt: - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - to_execute += [(0, x)] + for x in list(execute_outputs): + to_execute += [(0, x)] while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - if class_.OUTPUT_NODE == True: - valid = False - try: - m = validate_inputs(prompt, x) - valid = m[0] - except: - valid = False - if valid: - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -219,8 +207,11 @@ class PromptExecutor: comfy.model_management.soft_empty_cache() -def validate_inputs(prompt, item): +def validate_inputs(prompt, item, validated): unique_id = item + if unique_id in validated: + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -241,8 +232,9 @@ def validate_inputs(prompt, item): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id) + r = validate_inputs(prompt, o_id, validated) if r[0] == False: + validated[o_id] = r return r else: if type_input == "INT": @@ -270,7 +262,10 @@ def validate_inputs(prompt, item): if isinstance(type_input, list): if val not in type_input: return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input)) - return (True, "") + + ret = (True, "") + validated[unique_id] = ret + return ret def validate_prompt(prompt): outputs = set() @@ -284,11 +279,12 @@ def validate_prompt(prompt): good_outputs = set() errors = [] + validated = {} for o in outputs: valid = False reason = "" try: - m = validate_inputs(prompt, o) + m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] except Exception as e: @@ -297,7 +293,7 @@ def validate_prompt(prompt): reason = "Parsing error" if valid == True: - good_outputs.add(x) + good_outputs.add(o) else: print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") @@ -307,7 +303,7 @@ def validate_prompt(prompt): errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) - return (True, "") + return (True, "", list(good_outputs)) class PromptQueue: diff --git a/main.py b/main.py index eb97a2fb8..d385df70a 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() - e.execute(item[-2], item[-1]) + e.execute(item[-3], item[-2], item[-1]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): diff --git a/server.py b/server.py index c1226f304..b6ac7d483 100644 --- a/server.py +++ b/server.py @@ -312,7 +312,7 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data)) + self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) else: resp_code = 400 out_string = valid[1] From 8e3d1cbf3b8488b319675f952e1a868aa78f1161 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 01:45:27 -0400 Subject: [PATCH 010/120] Fix bug when uploading image with the same name. --- server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server.py b/server.py index b6ac7d483..911f6a614 100644 --- a/server.py +++ b/server.py @@ -151,6 +151,7 @@ class PromptServer(): i = 1 while os.path.exists(filepath): filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) i += 1 if image_save_function is not None: From 51583164ef08d2173eb93eefa36bc50429cfe7c6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 10:03:30 -0400 Subject: [PATCH 011/120] Make MaskToImage support masks with a batch size. --- comfy_extras/nodes_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 131cd6a9f..9916f3b21 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -72,7 +72,7 @@ class MaskToImage: FUNCTION = "mask_to_image" def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) + result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) return (result,) class ImageToMask: From f7c0f75d1fb1c6e3657f69247eace796882c62da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 13:58:19 -0400 Subject: [PATCH 012/120] Auto batching improvements. Try batching when cond sizes don't match with smart padding. 
--- comfy/samplers.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6417f2ed4..aa44fa82d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,10 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +import math + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers #Returns predicted noise @@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if c1.keys() != c2.keys(): return False if 'c_crossattn' in c1: - if c1['c_crossattn'].shape != c2['c_crossattn'].shape: - return False + s1 = c1['c_crossattn'].shape + s2 = c2['c_crossattn'].shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False if 'c_concat' in c1: if c1['c_concat'].shape != c2['c_concat'].shape: return False @@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c_crossattn = [] c_concat = [] c_adm = [] + crossattn_max_len = 0 for x in c_list: if 'c_crossattn' in x: - c_crossattn.append(x['c_crossattn']) + c = x['c_crossattn'] + if crossattn_max_len == 0: + crossattn_max_len = c.shape[1] + else: + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + c_crossattn.append(c) if 'c_concat' in x: c_concat.append(x['c_concat']) if 'c_adm' in x: c_adm.append(x['c_adm']) out = {} - if len(c_crossattn) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn)] + c_crossattn_out = [] + for c in c_crossattn: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + c_crossattn_out.append(c) + + if len(c_crossattn_out) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn_out)] if len(c_concat) > 0: out['c_concat'] = [torch.cat(c_concat)] if len(c_adm) > 0: From 602095f614276dd52fad718c223e0be17d12b11e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 15:49:49 -0400 Subject: [PATCH 013/120] Send execution_error message on websocket on execution exception. --- execution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 3953fde3a..7ee038975 100644 --- a/execution.py +++ b/execution.py @@ -185,7 +185,11 @@ class PromptExecutor: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") else: - print(traceback.format_exc()) + message = str(traceback.format_exc()) + print(message) + if self.server.client_id is not None: + self.server.send_sync("execution_error", { "message": message }, self.server.client_id) + to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): From 3a7c3acc72435f312a8f050d8ad3a1c902d9cff4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 15:59:24 -0400 Subject: [PATCH 014/120] Send websocket message with list of cached nodes right before execution. 
--- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 7ee038975..7d18d3b65 100644 --- a/execution.py +++ b/execution.py @@ -169,6 +169,8 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + if self.server.client_id is not None: + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) executed = set() try: to_execute = [] From 974958ff81d9af92b01490bcc99dfc93f8bb5d30 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 16:41:43 -0400 Subject: [PATCH 015/120] Make the prompt_id a uuid and return it when queueing the prompt. --- server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 911f6a614..6965ff3c1 100644 --- a/server.py +++ b/server.py @@ -81,7 +81,7 @@ class PromptServer(): # Reusing existing session, remove old self.sockets.pop(sid, None) else: - sid = uuid.uuid4().hex + sid = uuid.uuid4().hex self.sockets[sid] = ws @@ -313,7 +313,9 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) + prompt_id = str(uuid.uuid4()) + self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) + return web.json_response({"prompt_id": prompt_id}) else: resp_code = 400 out_string = valid[1] From dfc74c19d944b4a4503e22297592fa3a537d3092 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 May 2023 01:22:40 -0400 Subject: [PATCH 016/120] Add the prompt_id to some websocket messages. --- execution.py | 8 ++++---- main.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index 7d18d3b65..0ac4d462c 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}, execute_outputs=[]): + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -170,7 +170,7 @@ class PromptExecutor: current_outputs = set(self.outputs.keys()) if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() try: to_execute = [] @@ -190,7 +190,7 @@ class PromptExecutor: message = str(traceback.format_exc()) print(message) if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message }, self.server.client_id) + self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) to_delete = [] for o in self.outputs: @@ -207,7 +207,7 @@ class PromptExecutor: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None }, self.server.client_id) + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) gc.collect() comfy.model_management.soft_empty_cache() diff --git a/main.py b/main.py index d385df70a..00cbf3c4a 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: 
item, item_id = q.get() - e.execute(item[-3], item[-2], item[-1]) + e.execute(item[2], item[1], item[3], item[4]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): From 8ea165dd1ef877f58f3710f31ce43f27e0f739ab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 May 2023 14:15:13 -0400 Subject: [PATCH 017/120] Add a way to overwrite images when uploading. --- server.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server.py b/server.py index 6965ff3c1..a2bb26ad9 100644 --- a/server.py +++ b/server.py @@ -127,6 +127,7 @@ class PromptServer(): def image_upload(post, image_save_function=None): image = post.get("image") + overwrite = post.get("overwrite") image_upload_type = post.get("type") upload_dir = get_dir_by_type(image_upload_type) @@ -148,11 +149,14 @@ class PromptServer(): split = os.path.splitext(filename) filepath = os.path.join(full_output_folder, filename) - i = 1 - while os.path.exists(filepath): - filename = f"{split[0]} ({i}){split[1]}" - filepath = os.path.join(full_output_folder, filename) - i += 1 + if overwrite is not None and (overwrite == "true" or overwrite == "1"): + pass + else: + i = 1 + while os.path.exists(filepath): + filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) + i += 1 if image_save_function is not None: image_save_function(image, post, filepath) From 8a4ff5e34cc53252a9ff23e796904100d75bea55 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 12 May 2023 20:58:29 +0100 Subject: [PATCH 018/120] allow static files to be symlinks --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index a2bb26ad9..ef858a98a 100644 --- a/server.py +++ b/server.py @@ -362,7 +362,7 @@ class PromptServer(): def add_routes(self): self.app.add_routes(self.routes) self.app.add_routes([ - web.static('/', self.web_root), + web.static('/', self.web_root, follow_symlinks=True), ]) def get_queue_info(self): From d9e088ddfd97663abbb933c77f79d2a6c6127851 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 12 May 2023 23:49:09 +0200 Subject: [PATCH 019/120] minor changes for tiled sampler --- comfy/ldm/modules/tomesd.py | 2 +- comfy/sd.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py index 6a13b80c9..bb971e88f 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy/ldm/modules/tomesd.py @@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, """ B, N, _ = metric.shape - if r <= 0: + if r <= 0 or w == 1 or h == 1: return do_nothing, do_nothing gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather diff --git a/comfy/sd.py b/comfy/sd.py index 3543bdb77..0200f7742 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -581,10 +581,7 @@ class VAE: samples = samples.cpu() return samples -def resize_image_to(tensor, target_latent_tensor, batched_number): - tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center") - target_batch_size = target_latent_tensor.shape[0] - +def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] print(current_batch_size, target_batch_size) if current_batch_size == 1: @@ -623,7 +620,9 @@ class ControlNet: if 
self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_model.dtype == torch.float16: precision_scope = torch.autocast @@ -794,10 +793,14 @@ class T2IAdapter: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint + self.control_input = None self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: self.t2i_model.to(self.device) self.control_input = self.t2i_model(self.cond_hint) self.t2i_model.cpu() From 19c014f4292863444a3d677d504ad58623395a58 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 12 May 2023 23:57:40 +0200 Subject: [PATCH 020/120] comment out annoying print statement --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 0200f7742..c6be900ad 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -583,7 +583,7 @@ class VAE: def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - print(current_batch_size, target_batch_size) + #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor From c5c0ea666f8456b5a788092bad88528bbf34f559 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 12 May 2023 20:34:48 -0400 Subject: [PATCH 021/120] noise_mask in latent should be in a single format. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 760db24e1..c2201dafc 100644 --- a/nodes.py +++ b/nodes.py @@ -795,7 +795,7 @@ class SetLatentNoiseMask: def set_mask(self, samples, mask): s = samples.copy() - s["noise_mask"] = mask + s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): From 997dd1b1312a00cbedeafaf916e49f294a73a431 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 02:07:49 -0400 Subject: [PATCH 022/120] Fix queue delete. 
--- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index a2bb26ad9..8435d091b 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: - delete_func = lambda a: a[1] == int(id_to_delete) + delete_func = lambda a: a[1] == id_to_delete self.prompt_queue.delete_queue_item(delete_func) - + return web.Response(status=200) @routes.post("/interrupt") From 1201d2eae5820bb8124beb22b712d743415fd47d Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 13 May 2023 17:15:45 +0200 Subject: [PATCH 023/120] Make nodes map over input lists (#579) * allow nodes to map over lists * make work with IS_CHANGED and VALIDATE_INPUTS * give list outputs distinct socket shape * add rebatch node * add batch index logic * add repeat latent batch * deal with noise mask edge cases in latentfrombatch --- comfy/sample.py | 17 ++++-- comfy_extras/nodes_rebatch.py | 108 ++++++++++++++++++++++++++++++++++ execution.py | 90 +++++++++++++++++++++++----- nodes.py | 57 +++++++++++++++--- server.py | 1 + web/scripts/app.js | 3 +- 6 files changed, 250 insertions(+), 26 deletions(-) create mode 100644 comfy_extras/nodes_rebatch.py diff --git a/comfy/sample.py b/comfy/sample.py index bd38585ac..284efca61 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,17 +2,26 @@ import torch import comfy.model_management import comfy.samplers import math +import numpy as np -def prepare_noise(latent_image, seed, skip=0): +def prepare_noise(latent_image, seed, noise_inds=None): """ creates random noise given a latent image and a seed. optional arg skip can be used to skip and discard x number of noise generations for a given seed """ generator = torch.manual_seed(seed) - for _ in range(skip): + if noise_inds is None: + return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + noises = [] + for i in range(unique_inds[-1]+1): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - return noise + if i in unique_inds: + noises.append(noise) + noises = [noises[i] for i in inverse] + noises = torch.cat(noises, axis=0) + return noises def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py new file mode 100644 index 000000000..0a9daf272 --- /dev/null +++ b/comfy_extras/nodes_rebatch.py @@ -0,0 +1,108 @@ +import torch + +class LatentRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "latents": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "latent/batch" + + @staticmethod + def get_batch(latents, list_ind, offset): + '''prepare a batch out of the list of latents''' + samples = latents[list_ind]['samples'] + shape = samples.shape + mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') + if mask.shape[-1] != 
shape[-1] * 8 or mask.shape[-2] != shape[-2]: + torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") + if mask.shape[0] < samples.shape[0]: + mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] + if 'batch_index' in latents[list_ind]: + batch_inds = latents[list_ind]['batch_index'] + else: + batch_inds = [x+offset for x in range(shape[0])] + return samples, mask, batch_inds + + @staticmethod + def get_slices(indexable, num, batch_size): + '''divides an indexable object into num slices of length batch_size, and a remainder''' + slices = [] + for i in range(num): + slices.append(indexable[i*batch_size:(i+1)*batch_size]) + if num * batch_size < len(indexable): + return slices, indexable[num * batch_size:] + else: + return slices, None + + @staticmethod + def slice_batch(batch, num, batch_size): + result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] + return list(zip(*result)) + + @staticmethod + def cat_batch(batch1, batch2): + if batch1[0] is None: + return batch2 + result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] + return result + + def rebatch(self, latents, batch_size): + batch_size = batch_size[0] + + output_list = [] + current_batch = (None, None, None) + processed = 0 + + for i in range(len(latents)): + # fetch new entry of list + #samples, masks, indices = self.get_batch(latents, i) + next_batch = self.get_batch(latents, i, processed) + processed += len(next_batch[2]) + # set to current if current is None + if current_batch[0] is None: + current_batch = next_batch + # add previous to list if dimensions do not match + elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + current_batch = next_batch + # cat if everything checks out + else: + current_batch = self.cat_batch(current_batch, next_batch) + + # add to list if dimensions gone above target batch size + if current_batch[0].shape[0] > batch_size: + num = current_batch[0].shape[0] // batch_size + sliced, remainder = self.slice_batch(current_batch, num, batch_size) + + for i in range(num): + output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) + + current_batch = remainder + + #add remainder + if current_batch[0] is not None: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + + #get rid of empty masks + for s in output_list: + if s['noise_mask'].mean() == 1.0: + del s['noise_mask'] + + return (output_list,) + +NODE_CLASS_MAPPINGS = { + "RebatchLatents": LatentRebatch, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RebatchLatents": "Rebatch Latents", +} \ No newline at end of file diff --git a/execution.py b/execution.py index 0ac4d462c..cf2e5ea71 100644 --- a/execution.py +++ b/execution.py @@ -26,20 +26,81 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = obj else: if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = input_data + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for 
x in h: if h[x] == "PROMPT": - input_data_all[x] = prompt + input_data_all[x] = [prompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: - input_data_all[x] = extra_data['extra_pnginfo'] + input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": - input_data_all[x] = unique_id + input_data_all[x] = [unique_id] return input_data_all +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): + # check if node wants the lists + intput_is_list = False + if hasattr(obj, "INPUT_IS_LIST"): + intput_is_list = obj.INPUT_IS_LIST + + max_len_input = max([len(x) for x in input_data_all.values()]) + + # get a slice of inputs, repeat last input when list isn't long enough + def slice_dict(d, i): + d_new = dict() + for k,v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + + results = [] + if intput_is_list: + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**input_data_all)) + else: + for i in range(max_len_input): + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + return results + +def get_output_data(obj, input_data_all): + + results = [] + uis = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) + else: + results.append(r) + + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + ui = dict() + if len(uis) > 0: + ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + return output, ui + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -63,13 +124,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute server.send_sync("executing", { "node": unique_id }, server.client_id) obj = class_def() - nodes.before_node_execution() - outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) - if "ui" in outputs[unique_id]: + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) - if "result" in outputs[unique_id]: - outputs[unique_id] = outputs[unique_id]["result"] + server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -105,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: - is_changed = class_def.IS_CHANGED(**input_data_all) + #is_changed = class_def.IS_CHANGED(**input_data_all) + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: to_delete = True @@ -261,9 +321,11 @@ def validate_inputs(prompt, 
item, validated): if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) - ret = obj_class.VALIDATE_INPUTS(**input_data_all) - if ret != True: - return (False, "{}, {}".format(class_type, ret)) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for r in ret: + if r != True: + return (False, "{}, {}".format(class_type, r)) else: if isinstance(type_input, list): if val not in type_input: diff --git a/nodes.py b/nodes.py index c2201dafc..509dc0697 100644 --- a/nodes.py +++ b/nodes.py @@ -629,18 +629,57 @@ class LatentFromBatch: def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), }} RETURN_TYPES = ("LATENT",) - FUNCTION = "rotate" + FUNCTION = "frombatch" - CATEGORY = "latent" + CATEGORY = "latent/batch" - def rotate(self, samples, batch_index): + def frombatch(self, samples, batch_index, length): s = samples.copy() s_in = samples["samples"] batch_index = min(s_in.shape[0] - 1, batch_index) - s["samples"] = s_in[batch_index:batch_index + 1].clone() - s["batch_index"] = batch_index + length = min(s_in.shape[0] - batch_index, length) + s["samples"] = s_in[batch_index:batch_index + length].clone() + if "noise_mask" in samples: + masks = samples["noise_mask"] + if masks.shape[0] == 1: + s["noise_mask"] = masks.clone() + else: + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = masks[batch_index:batch_index + length].clone() + if "batch_index" not in s: + s["batch_index"] = [x for x in range(batch_index, batch_index+length)] + else: + s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] + return (s,) + +class RepeatLatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "repeat" + + CATEGORY = "latent/batch" + + def repeat(self, samples, amount): + s = samples.copy() + s_in = samples["samples"] + + s["samples"] = s_in.repeat((amount, 1,1,1)) + if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: + masks = samples["noise_mask"] + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + if "batch_index" in s: + offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 + s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) class LatentUpscale: @@ -805,8 +844,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - skip = latent["batch_index"] if "batch_index" in latent else 0 - noise = comfy.sample.prepare_noise(latent_image, seed, skip) + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) noise_mask = None if "noise_mask" in latent: @@ -1170,6 +1209,7 @@ NODE_CLASS_MAPPINGS = { "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentFromBatch": LatentFromBatch, + "RepeatLatentBatch": 
RepeatLatentBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, @@ -1244,6 +1284,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", + "LatentFromBatch" : "Latent From Batch", + "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", "PreviewImage": "Preview Image", @@ -1299,3 +1341,4 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) diff --git a/server.py b/server.py index 8435d091b..cb66cc618 100644 --- a/server.py +++ b/server.py @@ -268,6 +268,7 @@ class PromptServer(): info = {} info['input'] = obj_class.INPUT_TYPES() info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = x info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x diff --git a/web/scripts/app.js b/web/scripts/app.js index 2da1b5581..1a4a18b94 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -976,7 +976,8 @@ export class ComfyApp { for (const o in nodeData["output"]) { const output = nodeData["output"][o]; const outputName = nodeData["output_name"][o] || output; - this.addOutput(outputName, output); + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + this.addOutput(outputName, output, { shape: outputShape }); } const s = this.computeSize(); From 44f9f9baf170ddf27891b240002300d8aa09fb2a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 11:17:16 -0400 Subject: [PATCH 024/120] Add the prompt id to some websocket messages. 
--- execution.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index cf2e5ea71..b9548229c 100644 --- a/execution.py +++ b/execution.py @@ -101,7 +101,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -116,19 +116,19 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) obj = class_def() output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id) + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -215,6 +215,9 @@ class PromptExecutor: else: self.server.client_id = None + if self.server.client_id is not None: + self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + with torch.inference_mode(): #delete cached outputs if nodes don't exist for them to_delete = [] @@ -242,7 +245,7 @@ class PromptExecutor: to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") From cb4b8223981ec9e090ebf44205f5ce16d72f01cb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 11:54:45 -0400 Subject: [PATCH 025/120] Print custom nodes that take too much time to import. 
--- nodes.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nodes.py b/nodes.py index 509dc0697..bc7968308 100644 --- a/nodes.py +++ b/nodes.py @@ -6,6 +6,7 @@ import json import hashlib import traceback import math +import time from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -1325,6 +1326,7 @@ def load_custom_node(module_path): def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") + node_import_times = [] for custom_node_path in node_paths: possible_modules = os.listdir(custom_node_path) if "__pycache__" in possible_modules: @@ -1333,7 +1335,16 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + time_before = time.time() load_custom_node(module_path) + node_import_times.append((time.time() - time_before, module_path)) + + slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) + if len(slow_nodes) > 0: + print("\nDetected some custom nodes that were slow to import, if this is one of yours please improve it if you can:") + for n in sorted(slow_nodes): + print("{:6.1f} seconds to import:".format(n[0]), n[1]) + print() def init_custom_nodes(): load_custom_nodes() From cf439709b6b3ffae5ad15a9f7e59fedc214d5f1c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 12:50:21 -0400 Subject: [PATCH 026/120] Load nodes in comfy_extras before custom nodes. Change the slow import message. --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index bc7968308..956b739d9 100644 --- a/nodes.py +++ b/nodes.py @@ -1341,15 +1341,15 @@ def load_custom_nodes(): slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) if len(slow_nodes) > 0: - print("\nDetected some custom nodes that were slow to import, if this is one of yours please improve it if you can:") + print("\nDetected some custom nodes that were slow to import:") for n in sorted(slow_nodes): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() def init_custom_nodes(): - load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_nodes() From 92bf1cb61efcab45961d1119cb7ec7a076caf24e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:05:52 -0400 Subject: [PATCH 027/120] Change message. 
--- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 956b739d9..28215127c 100644 --- a/nodes.py +++ b/nodes.py @@ -1341,7 +1341,7 @@ def load_custom_nodes(): slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) if len(slow_nodes) > 0: - print("\nDetected some custom nodes that were slow to import:") + print("\nImport times for custom nodes:") for n in sorted(slow_nodes): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() From 2ac744f6628d107b3534177eeca5ef06f6668609 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:15:31 -0400 Subject: [PATCH 028/120] Print all custom node import times. --- nodes.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 28215127c..f3b7da1a9 100644 --- a/nodes.py +++ b/nodes.py @@ -1339,10 +1339,9 @@ def load_custom_nodes(): load_custom_node(module_path) node_import_times.append((time.time() - time_before, module_path)) - slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) - if len(slow_nodes) > 0: + if len(node_import_times) > 0: print("\nImport times for custom nodes:") - for n in sorted(slow_nodes): + for n in sorted(node_import_times): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() From db4d3a8494a4a7dbb6f911ae126a92abec6bf91b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:23:42 -0400 Subject: [PATCH 029/120] Print if custom nodes imported successfully or not. --- nodes.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index f3b7da1a9..63d9adc3d 100644 --- a/nodes.py +++ b/nodes.py @@ -1318,11 +1318,14 @@ def load_custom_node(module_path): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) + return True else: print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + return False except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom nodes:", e) + return False def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") @@ -1336,13 +1339,17 @@ def load_custom_nodes(): module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue time_before = time.time() - load_custom_node(module_path) - node_import_times.append((time.time() - time_before, module_path)) + success = load_custom_node(module_path) + node_import_times.append((time.time() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") for n in sorted(node_import_times): - print("{:6.1f} seconds to import:".format(n[0]), n[1]) + if n[2]: + import_message = "" + else: + import_message = " (IMPORT FAILED)" + print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) print() def init_custom_nodes(): From b0505eb7ab8af1986dabd97c23fae83a0539303d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 15:31:22 -0400 Subject: [PATCH 030/120] Return right type when none specified in upload route. Switch time.time to time.perf_counter for custom node import times. 
--- nodes.py | 4 ++-- server.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/nodes.py b/nodes.py index 63d9adc3d..c4aff1012 100644 --- a/nodes.py +++ b/nodes.py @@ -1338,9 +1338,9 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - time_before = time.time() + time_before = time.perf_counter() success = load_custom_node(module_path) - node_import_times.append((time.time() - time_before, module_path, success)) + node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") diff --git a/server.py b/server.py index d1079dd83..ba4dcba03 100644 --- a/server.py +++ b/server.py @@ -115,22 +115,23 @@ class PromptServer(): def get_dir_by_type(dir_type): if dir_type is None: - type_dir = folder_paths.get_input_directory() - elif dir_type == "input": + dir_type = "input" + + if dir_type == "input": type_dir = folder_paths.get_input_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": type_dir = folder_paths.get_output_directory() - return type_dir + return type_dir, dir_type def image_upload(post, image_save_function=None): image = post.get("image") overwrite = post.get("overwrite") image_upload_type = post.get("type") - upload_dir = get_dir_by_type(image_upload_type) + upload_dir, image_upload_type = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename From 3a1f47764d76bb9878b55e82657044b3faceda9c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 17:11:27 -0400 Subject: [PATCH 031/120] Print the torch device that is used on startup. 
--- comfy/model_management.py | 42 ++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 39df8d9a7..c15323219 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -127,6 +127,32 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.cuda.current_device() + +def get_torch_device_name(device): + if hasattr(device, 'type'): + return "{}".format(device.type) + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + +try: + print("Using device:", get_torch_device_name(get_torch_device())) +except: + print("Could not pick default device.") + current_loaded_model = None current_gpu_controlnets = [] @@ -233,22 +259,6 @@ def unload_if_low_vram(model): return model.cpu() return model -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type From e7b9d2c02cffd59fecca4ee617137ea38641078a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 01:30:58 -0400 Subject: [PATCH 032/120] /prompt endpoint error is now in json format. --- server.py | 7 +++---- web/scripts/api.js | 2 +- web/scripts/app.js | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/server.py b/server.py index ba4dcba03..f52117f10 100644 --- a/server.py +++ b/server.py @@ -323,12 +323,11 @@ class PromptServer(): self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) return web.json_response({"prompt_id": prompt_id}) else: - resp_code = 400 - out_string = valid[1] print("invalid prompt:", valid[1]) + return web.json_response({"error": valid[1]}, status=400) + else: + return web.json_response({"error": "no prompt"}, status=400) - return web.Response(body=out_string, status=resp_code) - @routes.post("/queue") async def post_queue(request): json_data = await request.json() diff --git a/web/scripts/api.js b/web/scripts/api.js index d29faa5ba..4f061c358 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -163,7 +163,7 @@ class ComfyApi extends EventTarget { if (res.status !== 200) { throw { - response: await res.text(), + response: await res.json(), }; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index 1a4a18b94..00d3c9746 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1222,7 +1222,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response || error.toString()); + this.ui.dialog.show(error.response.error || error.toString()); break; } From 9bf67c4c5a5c8b8d1efc2d4ce7e7ab1eccce1fa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 01:34:25 -0400 Subject: [PATCH 033/120] Print prompt execution time. 
--- execution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/execution.py b/execution.py index b9548229c..dd88029bc 100644 --- a/execution.py +++ b/execution.py @@ -6,6 +6,7 @@ import threading import heapq import traceback import gc +import time import torch import nodes @@ -215,6 +216,7 @@ class PromptExecutor: else: self.server.client_id = None + execution_start_time = time.perf_counter() if self.server.client_id is not None: self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) @@ -272,6 +274,7 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() comfy.model_management.soft_empty_cache() From d926f65f56217e7828ad27ec5b646c74398593c4 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sun, 14 May 2023 23:21:22 +0900 Subject: [PATCH 034/120] Feature/maskeditor context menu (#649) * add "Open in MaskEditor" to context menu * change save button name to 'Save to node' if open in node. clear clipspace_return_node after auto paste * * leak patch: prevent infinite duplication of MaskEditorDialog instance on every dialog open * prevent conflict of multiple opening of MaskEditorDialog * name of save button fix * patch: brushPreview hiding by dialog * consider close by 'esc' key on maskeditor. * bugfix about last patch * patch: invalid close detection * 'enter' key as save action * * batch support enhance - pick index based on imageIndex on copy action * paste fix on batch image node * typo --------- Co-authored-by: Lt.Dr.Data --- web/extensions/core/maskeditor.js | 120 ++++++++++++---- web/scripts/app.js | 226 +++++++++++++++++------------- 2 files changed, 221 insertions(+), 125 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 552059e86..4b0c12747 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -72,40 +72,50 @@ function prepareRGB(image, backupCanvas, backupCtx) { class MaskEditorDialog extends ComfyDialog { static instance = null; + + static getInstance() { + if(!MaskEditorDialog.instance) { + MaskEditorDialog.instance = new MaskEditorDialog(app); + } + + return MaskEditorDialog.instance; + } + + is_layout_created = false; + constructor() { super(); this.element = $el("div.comfy-modal", { parent: document.body }, [ $el("div.comfy-modal-content", [...this.createButtons()]), ]); - MaskEditorDialog.instance = this; } createButtons() { return []; } - clearMask(self) { - } - createButton(name, callback) { var button = document.createElement("button"); button.innerText = name; button.addEventListener("click", callback); return button; } + createLeftButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "left"; button.style.marginRight = "4px"; return button; } + createRightButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "right"; button.style.marginLeft = "4px"; return button; } + createLeftSlider(self, name, callback) { const divElement = document.createElement('div'); divElement.id = "maskeditor-slider"; @@ -164,7 +174,7 @@ class MaskEditorDialog extends ComfyDialog { brush.style.MozBorderRadius = "50%"; brush.style.WebkitBorderRadius = "50%"; brush.style.position = "absolute"; - brush.style.zIndex = 100; + 
brush.style.zIndex = 8889; brush.style.pointerEvents = "none"; this.brush = brush; this.element.appendChild(imgCanvas); @@ -187,7 +197,8 @@ class MaskEditorDialog extends ComfyDialog { document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.close(); }); - var saveButton = this.createRightButton("Save", () => { + + this.saveButton = this.createRightButton("Save", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.save(); @@ -199,11 +210,10 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(bottom_panel); bottom_panel.appendChild(clearButton); - bottom_panel.appendChild(saveButton); + bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); bottom_panel.appendChild(brush_size_slider); - this.element.style.display = "block"; imgCanvas.style.position = "relative"; imgCanvas.style.top = "200"; imgCanvas.style.left = "0"; @@ -212,25 +222,63 @@ class MaskEditorDialog extends ComfyDialog { } show() { - // layout - const imgCanvas = document.createElement('canvas'); - const maskCanvas = document.createElement('canvas'); - const backupCanvas = document.createElement('canvas'); + if(!this.is_layout_created) { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = document.createElement('canvas'); - imgCanvas.id = "imageCanvas"; - maskCanvas.id = "maskCanvas"; - backupCanvas.id = "backupCanvas"; + imgCanvas.id = "imageCanvas"; + maskCanvas.id = "maskCanvas"; + backupCanvas.id = "backupCanvas"; - this.setlayout(imgCanvas, maskCanvas); + this.setlayout(imgCanvas, maskCanvas); - // prepare content - this.maskCanvas = maskCanvas; - this.backupCanvas = backupCanvas; - this.maskCtx = maskCanvas.getContext('2d'); - this.backupCtx = backupCanvas.getContext('2d'); + // prepare content + this.imgCanvas = imgCanvas; + this.maskCanvas = maskCanvas; + this.backupCanvas = backupCanvas; + this.maskCtx = maskCanvas.getContext('2d'); + this.backupCtx = backupCanvas.getContext('2d'); - this.setImages(imgCanvas, backupCanvas); - this.setEventHandler(maskCanvas); + this.setEventHandler(maskCanvas); + + this.is_layout_created = true; + + // replacement of onClose hook since close is not real close + const self = this; + const observer = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if (mutation.type === 'attributes' && mutation.attributeName === 'style') { + if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + ComfyApp.onClipspaceEditorClosed(); + } + + self.last_display_style = self.element.style.display; + } + }); + }); + + const config = { attributes: true }; + observer.observe(this.element, config); + } + + this.setImages(this.imgCanvas, this.backupCanvas); + + if(ComfyApp.clipspace_return_node) { + this.saveButton.innerText = "Save to node"; + } + else { + this.saveButton.innerText = "Save"; + } + this.saveButton.disabled = false; + + this.element.style.display = "block"; + this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. 
+ } + + isOpened() { + return this.element.style.display == "block"; } setImages(imgCanvas, backupCanvas) { @@ -239,6 +287,10 @@ class MaskEditorDialog extends ComfyDialog { const maskCtx = this.maskCtx; const maskCanvas = this.maskCanvas; + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); + maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); + // image load const orig_image = new Image(); window.addEventListener("resize", () => { @@ -296,8 +348,7 @@ class MaskEditorDialog extends ComfyDialog { rgb_url.searchParams.set('channel', 'rgb'); orig_image.src = rgb_url; this.image = orig_image; - }g - + } setEventHandler(maskCanvas) { maskCanvas.addEventListener("contextmenu", (event) => { @@ -327,6 +378,8 @@ class MaskEditorDialog extends ComfyDialog { self.brush_size = Math.min(self.brush_size+2, 100); } else if (event.key === '[') { self.brush_size = Math.max(self.brush_size-2, 1); + } else if(event.key === 'Enter') { + self.save(); } self.updateBrushPreview(self); @@ -514,7 +567,7 @@ class MaskEditorDialog extends ComfyDialog { } } - save() { + async save() { const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); @@ -570,7 +623,10 @@ class MaskEditorDialog extends ComfyDialog { formData.append('type', "input"); formData.append('subfolder', "clipspace"); - uploadMask(item, formData); + this.saveButton.innerText = "Saving..."; + this.saveButton.disabled = true; + await uploadMask(item, formData); + ComfyApp.onClipspaceEditorSave(); this.close(); } } @@ -578,13 +634,15 @@ class MaskEditorDialog extends ComfyDialog { app.registerExtension({ name: "Comfy.MaskEditor", init(app) { - const callback = + ComfyApp.open_maskeditor = function () { - let dlg = new MaskEditorDialog(app); - dlg.show(); + const dlg = MaskEditorDialog.getInstance(); + if(!dlg.isOpened()) { + dlg.show(); + } }; const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 - ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback); + ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor); } }); \ No newline at end of file diff --git a/web/scripts/app.js b/web/scripts/app.js index 00d3c9746..87c5e30ca 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -26,6 +26,8 @@ export class ComfyApp { */ static clipspace = null; static clipspace_invalidate_handler = null; + static open_maskeditor = null; + static clipspace_return_node = null; constructor() { this.ui = new ComfyUI(this); @@ -49,6 +51,114 @@ export class ComfyApp { this.shiftDown = false; } + static isImageNode(node) { + return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); + } + + static onClipspaceEditorSave() { + if(ComfyApp.clipspace_return_node) { + ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node); + } + } + + static onClipspaceEditorClosed() { + ComfyApp.clipspace_return_node = null; + } + + static copyToClipspace(node) { + var widgets = null; + if(node.widgets) { + widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + var imgs = undefined; + var orig_imgs = undefined; + if(node.imgs != undefined) { + imgs = []; + orig_imgs = []; + + for (let i = 0; i < node.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = node.imgs[i].src; + orig_imgs[i] = 
imgs[i]; + } + } + + var selectedIndex = 0; + if(node.imageIndex) { + selectedIndex = node.imageIndex; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': orig_imgs, + 'images': node.images, + 'selectedIndex': selectedIndex, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action + }; + + ComfyApp.clipspace_return_node = null; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } + } + + static pasteFromClipspace(node) { + if(ComfyApp.clipspace) { + // image paste + if(ComfyApp.clipspace.imgs && node.imgs) { + if(node.images && ComfyApp.clipspace.images) { + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + } + else + app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + node.imgs = [img]; + node.imageIndex = 0; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); + if(index >= 0) { + node.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } + } + + app.graph.setDirtyCanvas(true); + } + } + /** * Invoke an extension callback * @param {keyof ComfyExtension} method The extension callback to execute @@ -138,102 +248,30 @@ export class ComfyApp { } } - options.push( - { - content: "Copy (Clipspace)", - callback: (obj) => { - var widgets = null; - if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); - } - - var imgs = undefined; - var orig_imgs = undefined; - if(this.imgs != undefined) { - imgs = []; - orig_imgs = []; + // prevent conflict of clipspace content + if(!ComfyApp.clipspace_return_node) { + options.push({ + content: "Copy (Clipspace)", + callback: (obj) => { ComfyApp.copyToClipspace(this); } + }); - for (let i = 0; i < this.imgs.length; i++) { - imgs[i] = new Image(); - imgs[i].src = this.imgs[i].src; - orig_imgs[i] = imgs[i]; + if(ComfyApp.clipspace != null) { + options.push({ + content: "Paste (Clipspace)", + callback: () => { ComfyApp.pasteFromClipspace(this); } + }); + } + + if(ComfyApp.isImageNode(this)) { + options.push({ + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); } - } - - ComfyApp.clipspace = { - 'widgets': widgets, - 'imgs': imgs, - 'original_imgs': orig_imgs, - 'images': this.images, - 'selectedIndex': 0, - 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action - }; - - if(ComfyApp.clipspace_invalidate_handler) { - ComfyApp.clipspace_invalidate_handler(); - } - } - }); - - if(ComfyApp.clipspace != null) { - options.push( - { - content: "Paste (Clipspace)", - callback: () => { - if(ComfyApp.clipspace) { - // image paste - if(ComfyApp.clipspace.imgs && this.imgs) { - if(this.images && ComfyApp.clipspace.images) { - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - 
app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; - - } - else - app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images; - } - - if(ComfyApp.clipspace.imgs) { - // deep-copy to cut link with clipspace - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - const img = new Image(); - img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; - this.imgs = [img]; - } - else { - const imgs = []; - for(let i=0; i obj.name === 'image'); - if(index >= 0) { - this.widgets[index].value = clip_image; - } - } - if(ComfyApp.clipspace.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.value = value; - prop.callback(value); - } - }); - } - } - } - - app.graph.setDirtyCanvas(true); - } - } - ); + }); + } } }; } From acff543d669dba9b03fb500a10010f2da8739ff3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 12:50:21 -0400 Subject: [PATCH 035/120] Remove useless code. --- nodes.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/nodes.py b/nodes.py index c4aff1012..bc23e5c17 100644 --- a/nodes.py +++ b/nodes.py @@ -146,9 +146,6 @@ class ConditioningSetMask: return (c, ) class VAEDecode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -161,9 +158,6 @@ class VAEDecode: return (vae.decode(samples["samples"]), ) class VAEDecodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -176,9 +170,6 @@ class VAEDecodeTiled: return (vae.decode_tiled(samples["samples"]), ) class VAEEncode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -203,9 +194,6 @@ class VAEEncode: return ({"samples":t}, ) class VAEEncodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -220,9 +208,6 @@ class VAEEncodeTiled: return ({"samples":t}, ) class VAEEncodeForInpaint: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} From 587f89fe5a8e2bcb389fb4919dc33c330320fa41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 15:10:40 -0400 Subject: [PATCH 036/120] Enable safe loading for upscale models. 
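For context, the safe-loading idea is to avoid the unrestricted pickle path when reading checkpoints: safetensors files contain only raw tensors, and for pickle-based files torch's weights_only mode limits what can be unpickled. A minimal sketch of that idea, assuming torch >= 1.13 and the safetensors package (load_state_dict_safely is a hypothetical helper, not ComfyUI's comfy.utils.load_torch_file):

    import torch
    import safetensors.torch

    def load_state_dict_safely(path, device="cpu"):
        # safetensors stores plain tensors and cannot run code on load
        if path.lower().endswith(".safetensors"):
            return safetensors.torch.load_file(path, device=device)
        # for pickle-based checkpoints, weights_only=True restricts
        # unpickling to tensors and simple containers (torch >= 1.13)
        return torch.load(path, map_location=device, weights_only=True)

The patch itself only needs to forward safe_load=True to the existing loader: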
--- comfy_extras/nodes_upscale_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ab5b0ccfc..f9252ea0b 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -17,7 +17,7 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) out = model_loading.load_state_dict(sd).eval() return (out, ) From 84ea21c815d426000c233e0c7b8c542764335cc8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 17:02:40 -0400 Subject: [PATCH 037/120] Update litegraph from upstream. --- web/lib/litegraph.core.js | 145 +++++++++++++++++++++++++++++++++++--- 1 file changed, 137 insertions(+), 8 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 2bc6af0c3..6c81c3ffd 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action) //when clicked on top of a node //and it is not interactive - if (node && this.allow_interaction && !skip_action && !this.read_only) { + if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) { if (!this.live_mode && !node.flags.pinned) { this.bringToFront(node); } //if it wasn't selected? //not dragging mouse to connect two slots - if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { + if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) @@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action) } //double clicking - if (is_double_click && this.selected_nodes[node.id]) { + if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) { //double click node if (node.onDblClick) { node.onDblClick( e, pos, this ); @@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action) this.dirty_canvas = true; } + //get node over + var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); + if (this.dragging_rectangle) { this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0]; @@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action) this.ds.offset[1] += delta[1] / this.ds.scale; this.dirty_canvas = true; this.dirty_bgcanvas = true; - } else if (this.allow_interaction && !this.read_only) { + } else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) { if (this.connecting_node) { this.dirty_canvas = true; } - //get node over - var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); - //remove mouseover flag for (var i = 0, l = this.graph._nodes.length; i < l; ++i) { if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) { @@ -9911,7 +9911,7 @@ LGraphNode.prototype.executeAction = function(action) event, active_widget ) { - if (!node.widgets || !node.widgets.length) { + if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) { return null; } @@ -10300,6 +10300,119 @@ LGraphNode.prototype.executeAction = function(action) canvas.graph.add(group); }; + /** + * Determines the furthest nodes in each 
direction + * @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.getBoundaryNodes = function(nodes) { + let top = null; + let right = null; + let bottom = null; + let left = null; + for (const nID in nodes) { + const node = nodes[nID]; + const [x, y] = node.pos; + const [width, height] = node.size; + + if (top === null || y < top.pos[1]) { + top = node; + } + if (right === null || x + width > right.pos[0] + right.size[0]) { + right = node; + } + if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) { + bottom = node; + } + if (left === null || x < left.pos[0]) { + left = node; + } + } + + return { + "top": top, + "right": right, + "bottom": bottom, + "left": left + }; + } + /** + * Determines the furthest nodes in each direction for the currently selected nodes + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.prototype.boundaryNodesForSelection = function() { + return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes)); + } + + /** + * + * @param {LGraphNode[]} nodes a list of nodes + * @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes + * @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction) + */ + LGraphCanvas.alignNodes = function (nodes, direction, align_to) { + if (!nodes) { + return; + } + + const canvas = LGraphCanvas.active_canvas; + let boundaryNodes = [] + if (align_to === undefined) { + boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes) + } else { + boundaryNodes = { + "top": align_to, + "right": align_to, + "bottom": align_to, + "left": align_to + } + } + + for (const [_, node] of Object.entries(canvas.selected_nodes)) { + switch (direction) { + case "right": + node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0]; + break; + case "left": + node.pos[0] = boundaryNodes["left"].pos[0]; + break; + case "top": + node.pos[1] = boundaryNodes["top"].pos[1]; + break; + case "bottom": + node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1]; + break; + } + } + + canvas.dirty_canvas = true; + canvas.dirty_bgcanvas = true; + }; + + LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node); + } + } + + LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase()); + } + } + LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) { var canvas = LGraphCanvas.active_canvas; @@ -12900,6 +13013,14 @@ LGraphNode.prototype.executeAction = function(action) options.push({ content: "Options", callback: that.showShowGraphOptionsPanel }); }*/ + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align", + has_submenu: true, + callback: LGraphCanvas.onGroupAlign, + }) + } + if (this._graph_stack && 
this._graph_stack.length > 0) { options.push(null, { content: "Close subgraph", @@ -13014,6 +13135,14 @@ LGraphNode.prototype.executeAction = function(action) callback: LGraphCanvas.onMenuNodeToSubgraph }); + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align Selected To", + has_submenu: true, + callback: LGraphCanvas.onNodeAlign, + }) + } + options.push(null, { content: "Remove", disabled: !(node.removable !== false && !node.block_delete ), From 1dd846a7bad8cfab679a0976e201c722871c6917 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 00:27:28 -0400 Subject: [PATCH 038/120] Fix outputs gone from history. --- execution.py | 16 +++++++++++----- main.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/execution.py b/execution.py index dd88029bc..0e2cc15c1 100644 --- a/execution.py +++ b/execution.py @@ -102,7 +102,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -117,7 +117,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: @@ -128,6 +128,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) @@ -205,6 +206,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): self.outputs = {} + self.outputs_ui = {} self.old_prompt = {} self.server = server @@ -234,6 +236,11 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + for x in list(self.outputs_ui.keys()): + if x not in current_outputs: + d = self.outputs_ui.pop(x) + del d + if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() @@ -247,7 +254,7 @@ class PromptExecutor: to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -413,8 +420,7 @@ class PromptQueue: prompt = 
self.currently_running.pop(item_id) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: - if "ui" in outputs[o]: - self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"] + self.history[prompt[1]]["outputs"][o] = outputs[o] self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 00cbf3c4a..50d3b9a62 100644 --- a/main.py +++ b/main.py @@ -34,7 +34,7 @@ def prompt_worker(q, server): while True: item, item_id = q.get() e.execute(item[2], item[1], item[3], item[4]) - q.task_done(item_id, e.outputs) + q.task_done(item_id, e.outputs_ui) async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) From ef815ba1e24eef45041adec8a55ecd628b20476f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 00:29:56 -0400 Subject: [PATCH 039/120] Switch default scheduler to normal. --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index aa44fa82d..fccf254ec 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,7 +495,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] From c02a554bcf6ef50f8e252c89dc0a56c08d4955c0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 03:25:24 -0400 Subject: [PATCH 040/120] Make DiffusersLoader work with subfolders. 
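The loader previously listed only the first level of each diffusers search path; this change walks the tree and treats any directory containing model_index.json as a model root, so nested (and symlinked) checkouts are picked up. The discovery step in isolation, with find_diffusers_models as an illustrative name for the inline loop:

    import os

    def find_diffusers_models(search_path):
        # a directory counts as a diffusers model root when it holds
        # model_index.json; results are relative to the search path
        paths = []
        for root, subdirs, files in os.walk(search_path, followlinks=True):
            if "model_index.json" in files:
                paths.append(os.path.relpath(root, start=search_path))
        return paths

load_checkpoint then joins the selected relative path back onto each search path and uses the first one that exists: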
--- nodes.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index bc23e5c17..797ad6c9c 100644 --- a/nodes.py +++ b/nodes.py @@ -282,7 +282,10 @@ class DiffusersLoader: paths = [] for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths += next(os.walk(search_path))[1] + for root, subdir, files in os.walk(search_path, followlinks=True): + if "model_index.json" in files: + paths.append(os.path.relpath(root, start=search_path)) + return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -292,9 +295,9 @@ class DiffusersLoader: def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths = next(os.walk(search_path))[1] - if model_path in paths: - model_path = os.path.join(search_path, model_path) + path = os.path.join(search_path, model_path) + if os.path.exists(path): + model_path = path break return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) From 2ec6d1c6e364ab92e3d8149a83873ac47c797248 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 03:31:03 -0400 Subject: [PATCH 041/120] Don't import custom nodes when the folder ends with .disabled --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index 797ad6c9c..e8b36c24a 100644 --- a/nodes.py +++ b/nodes.py @@ -1326,6 +1326,7 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + if module_path.endswith(".disabled"): continue time_before = time.perf_counter() success = load_custom_node(module_path) node_import_times.append((time.perf_counter() - time_before, module_path, success)) From 5f7968f1fafb2cf5d15fe049fc53265ad0fc6696 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 01:12:44 -0400 Subject: [PATCH 042/120] Print the endpoint ip for localtunnel in the colab notebook. 
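localtunnel shows first-time visitors a password prompt, and the expected value is the public IP of the machine hosting the tunnel (the notebook calls it the password/endpoint IP), so it is now printed before the tunnel starts. The lookup is a single request to a plain-text IP echo service; get_public_ipv4 is just an illustrative wrapper around the one-liner added to the notebook:

    import urllib.request

    def get_public_ipv4():
        # icanhazip returns the caller's public address followed by a newline
        return urllib.request.urlopen("https://ipv4.icanhazip.com").read().decode("utf8").strip("\n")
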
--- notebooks/comfyui_colab.ipynb | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index fecfa6707..c5a209eec 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -175,6 +175,8 @@ "import threading\n", "import time\n", "import socket\n", + "import urllib.request\n", + "\n", "def iframe_thread(port):\n", " while True:\n", " time.sleep(0.5)\n", @@ -183,7 +185,9 @@ " if result == 0:\n", " break\n", " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n", + " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n", + "\n", + " print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n", " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", " for line in p.stdout:\n", " print(line.decode(), end='')\n", From 13d94caf49b21bd129ec867b04641973e3a102da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 03:18:11 -0400 Subject: [PATCH 043/120] Add control_after_generate to combo primitive. --- web/extensions/core/widgetInputs.js | 2 +- web/scripts/widgets.js | 80 +++++++++++++++++++---------- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index df7d8f071..4fe0a6013 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -300,7 +300,7 @@ app.registerExtension({ } } - if (widget.type === "number") { + if (widget.type === "number" || widget.type === "combo") { addValueControlWidget(this, widget, "fixed"); } diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 65edc0392..3d1acc53e 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,35 +19,61 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - let min = targetWidget.options.min; - let max = targetWidget.options.max; - // limit to something that javascript can handle - max = Math.min(1125899906842624, max); - min = Math.max(-1125899906842624, min); - let range = (max - min) / (targetWidget.options.step / 10); + console.log(targetWidget); + if (targetWidget.type == "combo" && v !== "fixed") { + let current_index = targetWidget.options.values.indexOf(targetWidget.value); + let current_length = targetWidget.options.values.length; - //adjust values based on valueControl Behaviour - switch (v) { - case "fixed": - break; - case "increment": - targetWidget.value += targetWidget.options.step / 10; - break; - case "decrement": - targetWidget.value -= targetWidget.options.step / 10; - break; - case "randomize": - targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; - default: - break; + switch (v) { + case "increment": + current_index += 1; + break; + case "decrement": + current_index -= 1; + break; + case "randomize": + current_index = Math.floor(Math.random() * current_length); + default: + break; + } + current_index = Math.max(0, current_index); + current_index = Math.min(current_length - 1, current_index); + if (current_index >= 0) { + let value = targetWidget.options.values[current_index]; + targetWidget.value = value; + targetWidget.callback(value); + } + } else { //number + let 
min = targetWidget.options.min; + let max = targetWidget.options.max; + // limit to something that javascript can handle + max = Math.min(1125899906842624, max); + min = Math.max(-1125899906842624, min); + let range = (max - min) / (targetWidget.options.step / 10); + + //adjust values based on valueControl Behaviour + switch (v) { + case "fixed": + break; + case "increment": + targetWidget.value += targetWidget.options.step / 10; + break; + case "decrement": + targetWidget.value -= targetWidget.options.step / 10; + break; + case "randomize": + targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; + default: + break; + } + /*check if values are over or under their respective + * ranges and set them to min or max.*/ + if (targetWidget.value < min) + targetWidget.value = min; + + if (targetWidget.value > max) + targetWidget.value = max; } - /*check if values are over or under their respective - * ranges and set them to min or max.*/ - if (targetWidget.value < min) - targetWidget.value = min; - - if (targetWidget.value > max) - targetWidget.value = max; } return valueControl; }; From 7ada9e7d85f93495aa5006468a45220932f5e988 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Tue, 16 May 2023 22:55:00 +0900 Subject: [PATCH 044/120] allows touch drag --- web/scripts/app.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 87c5e30ca..ef3b44c83 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -902,7 +902,9 @@ export class ComfyApp { await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM - const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" })); + const mainCanvas = document.createElement("canvas") + mainCanvas.style.touchAction = "none" + const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" })); canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); From 11e7168d56e0987e52d0afb620189f08bda2b454 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 11:55:16 -0400 Subject: [PATCH 045/120] Remove print. --- web/scripts/widgets.js | 1 - 1 file changed, 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 3d1acc53e..94988d0f2 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,7 +19,6 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - console.log(targetWidget); if (targetWidget.type == "combo" && v !== "fixed") { let current_index = targetWidget.options.values.indexOf(targetWidget.value); let current_length = targetWidget.options.values.length; From 4088e61aa6b8943e28ee243c0b1265c41974ef67 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 15:35:07 -0400 Subject: [PATCH 046/120] Update litegraph from upstream. 
--- web/lib/litegraph.core.js | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 6c81c3ffd..95f4a2735 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action) if (show_text) { ctx.textAlign = "center"; ctx.fillStyle = text_color; - ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7); + ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7); } break; case "toggle": @@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); if (show_text) { ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = w.value ? text_color : secondary_text_color; ctx.textAlign = "right"; @@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.textAlign = "center"; ctx.fillStyle = text_color; ctx.fillText( - w.name + " " + Number(w.value).toFixed(3), + w.label || w.name + " " + Number(w.value).toFixed(3), widget_width * 0.5, y + H * 0.7 ); @@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); } ctx.fillStyle = secondary_text_color; - ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7); + ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7); ctx.fillStyle = text_color; ctx.textAlign = "right"; if (w.type == "number") { @@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action) //ctx.stroke(); ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = text_color; ctx.textAlign = "right"; From e7f2816c6f1da22e2018cf088bd45110ff265c79 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Thu, 18 May 2023 12:40:28 +0900 Subject: [PATCH 047/120] feat:Latent Save/Load (#662) * wip * latent dir * fix * fix * now working * mark todo * remove server.py changes to separate PRt --------- Co-authored-by: Lt.Dr.Data --- input/latents/_input_latents_will_be_put_here | 0 nodes.py | 90 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 input/latents/_input_latents_will_be_put_here diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index e8b36c24a..a2c7713aa 100644 --- a/nodes.py +++ b/nodes.py @@ -29,6 +29,8 @@ import importlib import folder_paths +import safetensors.torch as sft + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -246,6 +248,91 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class SaveLatent: + def __init__(self): + self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") + self.type = "output" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, samples, 
filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(self.output_dir, subfolder) + + if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: + print("Saving latent outside the 'input/latents' folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + + # support save metadata for latent sharing + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"workflow": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + file = f"{filename}_{counter:05}_.latent" + file = os.path.join(full_output_folder, file) + + sft.save_file(samples, file, metadata=metadata) + + return {} + + +class LoadLatent: + input_dir = os.path.join(folder_paths.get_input_directory(), "latents") + + @classmethod + def INPUT_TYPES(s): + files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + return {"required": {"latent": [sorted(files), ]}, } + + CATEGORY = "_for_testing" + + RETURN_TYPES = ("LATENT", ) + FUNCTION = "load" + + def load(self, latent): + file = folder_paths.get_annotated_filepath(latent, self.input_dir) + + latent = sft.load_file(file, device="cpu") + + return (latent, ) + + class CheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -1235,6 +1322,9 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, + + "LoadLatent": LoadLatent, + "SaveLatent": SaveLatent } NODE_DISPLAY_NAME_MAPPINGS = { From a7375103b9c80bb7607f85faa4afbf11ab5a5685 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 May 2023 23:04:40 -0400 Subject: [PATCH 048/120] Some small changes to Load/SaveLatent. 
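After this commit a .latent file is an ordinary safetensors file: the latent goes in under the key latent_tensor and the workflow JSON travels as string metadata. A round-trip sketch (the file name and tensor shape are illustrative):

    import torch
    import safetensors.torch

    samples = {"samples": torch.zeros(1, 4, 64, 64)}

    # save: one named tensor plus string-only metadata, which is why the
    # node runs prompt/extra_pnginfo through json.dumps first
    safetensors.torch.save_file({"latent_tensor": samples["samples"]},
                                "ComfyUI_00001_.latent",
                                metadata={"workflow": "{}"})

    # load: read it back and rewrap it as the LATENT dict the nodes expect
    loaded = safetensors.torch.load_file("ComfyUI_00001_.latent", device="cpu")
    restored = {"samples": loaded["latent_tensor"]}
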
--- nodes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index a2c7713aa..7255621d7 100644 --- a/nodes.py +++ b/nodes.py @@ -11,6 +11,7 @@ import time from PIL import Image from PIL.PngImagePlugin import PngInfo import numpy as np +import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -29,7 +30,6 @@ import importlib import folder_paths -import safetensors.torch as sft def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -307,7 +307,10 @@ class SaveLatent: file = f"{filename}_{counter:05}_.latent" file = os.path.join(full_output_folder, file) - sft.save_file(samples, file, metadata=metadata) + output = {} + output["latent_tensor"] = samples["samples"] + + safetensors.torch.save_file(output, file, metadata=metadata) return {} @@ -328,9 +331,10 @@ class LoadLatent: def load(self, latent): file = folder_paths.get_annotated_filepath(latent, self.input_dir) - latent = sft.load_file(file, device="cpu") + latent = safetensors.torch.load_file(file, device="cpu") + samples = {"samples": latent["latent_tensor"]} - return (latent, ) + return (samples, ) class CheckpointLoader: From faf899ad5ae32f770f0dae6a9df457e81d2b5c38 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 May 2023 23:43:59 -0400 Subject: [PATCH 049/120] LoadLatent and SaveLatent should behave like the LoadImage and SaveImage. --- folder_paths.py | 33 +++++++ input/latents/_input_latents_will_be_put_here | 0 nodes.py | 90 +++++-------------- 3 files changed, 55 insertions(+), 68 deletions(-) delete mode 100644 input/latents/_input_latents_will_be_put_here diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..28f117824 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -147,4 +147,37 @@ def get_filename_list(folder_name): output_list.update(filter_files_extensions(recursive_search(x), folders[1])) return sorted(list(output_list)) +def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + def compute_vars(input, image_width, image_height): + input = input.replace("%width%", str(image_width)) + input = input.replace("%height%", str(image_height)) + return input + + filename_prefix = compute_vars(filename_prefix, image_width, image_height) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(output_dir, subfolder) + + if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: + print("Saving image outside the output folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + return full_output_folder, filename, counter, subfolder, filename_prefix diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here deleted file mode 100644 index e69de29bb..000000000 diff --git a/nodes.py b/nodes.py index 7255621d7..7b450df38 100644 --- a/nodes.py +++ b/nodes.py @@ -251,13 
+251,12 @@ class VAEEncodeForInpaint: class SaveLatent: def __init__(self): - self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") - self.type = "output" + self.output_dir = folder_paths.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } RETURN_TYPES = () @@ -268,31 +267,7 @@ class SaveLatent: CATEGORY = "_for_testing" def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving latent outside the 'input/latents' folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) # support save metadata for latent sharing prompt_info = "" @@ -316,11 +291,10 @@ class SaveLatent: class LoadLatent: - input_dir = os.path.join(folder_paths.get_input_directory(), "latents") - @classmethod def INPUT_TYPES(s): - files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] return {"required": {"latent": [sorted(files), ]}, } CATEGORY = "_for_testing" @@ -329,13 +303,25 @@ class LoadLatent: FUNCTION = "load" def load(self, latent): - file = folder_paths.get_annotated_filepath(latent, self.input_dir) - - latent = safetensors.torch.load_file(file, device="cpu") + latent_path = folder_paths.get_annotated_filepath(latent) + latent = safetensors.torch.load_file(latent_path, device="cpu") samples = {"samples": latent["latent_tensor"]} - return (samples, ) + @classmethod + def IS_CHANGED(s, latent): + image_path = folder_paths.get_annotated_filepath(latent) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, latent): + if not folder_paths.exists_annotated_filepath(latent): + return "Invalid latent file: {}".format(latent) + return True + class CheckpointLoader: @classmethod @@ -1020,39 +1006,7 @@ class SaveImage: CATEGORY = "image" def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - def compute_vars(input): - input = 
input.replace("%width%", str(images[0].shape[1])) - input = input.replace("%height%", str(images[0].shape[0])) - return input - - filename_prefix = compute_vars(filename_prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving image outside the output folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 - + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() for image in images: i = 255. * image.cpu().numpy() From 62a371e12b4763bf6f9aeb42ff4928138df6ae26 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 18 May 2023 02:41:21 -0400 Subject: [PATCH 050/120] Load workflow from latent file. --- nodes.py | 2 +- web/scripts/app.js | 7 ++++++- web/scripts/pnginfo.js | 16 ++++++++++++++++ web/scripts/ui.js | 2 +- 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 7b450df38..3c61cd2ec 100644 --- a/nodes.py +++ b/nodes.py @@ -274,7 +274,7 @@ class SaveLatent: if prompt is not None: prompt_info = json.dumps(prompt) - metadata = {"workflow": prompt_info} + metadata = {"prompt": prompt_info} if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) diff --git a/web/scripts/app.js b/web/scripts/app.js index ef3b44c83..97b7c8d31 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata, importA1111 } from "./pnginfo.js"; +import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension @@ -1308,6 +1308,11 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); + } else if (file.name?.endsWith(".latent")) { + const info = await getLatentMetadata(file); + if (info.workflow) { + this.loadGraphData(JSON.parse(info.workflow)); + } } } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 209b562a6..8ddb7a1c5 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -47,6 +47,22 @@ export function getPngMetadata(file) { }); } +export function getLatentMetadata(file) { + return new Promise((r) => { + const reader = new FileReader(); + reader.onload = (event) => { + const safetensorsData = new Uint8Array(event.target.result); + const dataView = new DataView(safetensorsData.buffer); + let header_size = dataView.getUint32(0, true); + let offset = 8; + let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + r(header.__metadata__); + }; + + reader.readAsArrayBuffer(file); + }); +} + export async function importA1111(graph, parameters) { const p = parameters.lastIndexOf("\nSteps:"); if (p > -1) { diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 77517aec1..2c9043d00 100644 --- 
a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -465,7 +465,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png", + accept: ".json,image/png,.latent", style: { display: "none" }, parent: document.body, onchange: () => { From 8bbd9815a976ef43e2665d45c5afb4a21c06c831 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 May 2023 02:15:32 -0400 Subject: [PATCH 051/120] Support loading fp16 latent files. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 3c61cd2ec..878e0b955 100644 --- a/nodes.py +++ b/nodes.py @@ -305,7 +305,7 @@ class LoadLatent: def load(self, latent): latent_path = folder_paths.get_annotated_filepath(latent) latent = safetensors.torch.load_file(latent_path, device="cpu") - samples = {"samples": latent["latent_tensor"]} + samples = {"samples": latent["latent_tensor"].float()} return (samples, ) @classmethod From 2998e232cb26b66e7ba42a53ada3a8285fcb2c15 Mon Sep 17 00:00:00 2001 From: malern <701073+malern@users.noreply.github.com> Date: Fri, 19 May 2023 19:57:15 +0100 Subject: [PATCH 052/120] Make multiline widget work with different canvas dimensions. It now scales the textarea positioning using the canvas height/width. --- web/scripts/widgets.js | 20 +++++++++++++------- web/style.css | 2 ++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 94988d0f2..82168b08b 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -155,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) { computeSize(node.size); } const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext"; - const t = ctx.getTransform(); const margin = 10; + const elRect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + Object.assign(this.inputEl.style, { - left: `${t.a * margin + t.e}px`, - top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, - width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, - background: (!node.color)?'':node.color, - height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, + transformOrigin: "0 0", + transform: transform, + left: "0px", + top: "0px", + width: `${widgetWidth - (margin * 2)}px`, + height: `${this.parent.inputHeight - (margin * 2)}px`, position: "absolute", + background: (!node.color)?'':node.color, color: (!node.color)?'':'white', zIndex: app.graph._nodes.indexOf(node), - fontSize: `${t.d * 10.0}px`, }); this.inputEl.hidden = !visible; }, diff --git a/web/style.css b/web/style.css index df220cc02..87f096e14 100644 --- a/web/style.css +++ b/web/style.css @@ -39,6 +39,8 @@ body { padding: 2px; resize: none; border: none; + box-sizing: border-box; + font-size: 10px; } .comfy-modal { From e6e1999f96adbe0f4b041d265837e00bde9283ab Mon Sep 17 00:00:00 2001 From: malern <701073+malern@users.noreply.github.com> Date: Fri, 19 May 2023 20:04:36 +0100 Subject: [PATCH 053/120] Render UI at a higher resolution when viewing with a higher pixel ratio --- web/scripts/app.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 97b7c8d31..514ca3958 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -921,8 +921,9 @@ export class ComfyApp { this.graph.start(); function resizeCanvas() { - canvasEl.width = 
canvasEl.offsetWidth; - canvasEl.height = canvasEl.offsetHeight; + canvasEl.width = canvasEl.offsetWidth * window.devicePixelRatio; + canvasEl.height = canvasEl.offsetHeight * window.devicePixelRatio; + canvasEl.getContext("2d").scale(window.devicePixelRatio, window.devicePixelRatio); canvas.draw(true, true); } From b9daf4e30f32e76a00628add0849d84b5ec2fe76 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 May 2023 22:40:28 -0400 Subject: [PATCH 054/120] Add a /object_info/{node_class} route to get only the info of one node. --- server.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/server.py b/server.py index f52117f10..18ce54306 100644 --- a/server.py +++ b/server.py @@ -261,23 +261,34 @@ class PromptServer(): async def get_prompt(request): return web.json_response(self.get_queue_info()) + def node_info(node_class): + obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = node_class + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = '' + info['category'] = 'sd' + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + return info + @routes.get("/object_info") async def get_object_info(request): out = {} for x in nodes.NODE_CLASS_MAPPINGS: - obj_class = nodes.NODE_CLASS_MAPPINGS[x] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = x - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x - info['description'] = '' - info['category'] = 'sd' - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - out[x] = info + out[x] = node_info(x) + return web.json_response(out) + + @routes.get("/object_info/{node_class}") + async def get_object_info_node(request): + node_class = request.match_info.get("node_class", None) + out = {} + if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): + out[node_class] = node_info(node_class) return web.json_response(out) @routes.get("/history") From 36af98d75580292d4afbeafb5e7ba7f010145436 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 20 May 2023 15:23:28 +0200 Subject: [PATCH 055/120] improve sharpen and blur nodes --- comfy_extras/nodes_post_processing.py | 35 ++++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ba699e2b8..37c824bde 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -59,6 +59,12 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) +def gaussian_kernel(kernel_size: int, sigma: float): + x, y = torch.meshgrid(torch.linspace(-1, 1, 
kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + class Blur: def __init__(self): pass @@ -88,12 +94,6 @@ class Blur: CATEGORY = "image/postprocessing" - def gaussian_kernel(self, kernel_size: int, sigma: float): - x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") - d = torch.sqrt(x * x + y * y) - g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) - return g / g.sum() - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): if blur_radius == 0: return (image,) @@ -101,10 +101,11 @@ class Blur: batch_size, height, width, channels = image.shape kernel_size = blur_radius * 2 + 1 - kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) + padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') + blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) return (blurred,) @@ -167,9 +168,15 @@ class Sharpen: "max": 31, "step": 1 }), - "alpha": ("FLOAT", { + "sigma": ("FLOAT", { "default": 1.0, "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.0, "max": 5.0, "step": 0.1 }), @@ -181,21 +188,21 @@ class Sharpen: CATEGORY = "image/postprocessing" - def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): + def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): if sharpen_radius == 0: return (image,) batch_size, height, width, channels = image.shape kernel_size = sharpen_radius * 2 + 1 - kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 + kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) center = kernel_size // 2 - kernel[center, center] = kernel_size**2 - kernel *= alpha + kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) + tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect') + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] sharpened = sharpened.permute(0, 2, 3, 1) result = torch.clamp(sharpened, 0, 1) From 71666f248f769af073408a3475dd7a82a29d8247 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 10:08:47 -0400 Subject: [PATCH 056/120] Fix padding in Blur. 
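The previous commit added reflect padding but still convolved the unpadded tensor, so the result came out cropped (and zero-padded at the borders) rather than reflect-padded; this commit convolves padded_image instead. With the fix, the blur path is roughly equivalent to this standalone sketch (blur_bhwc is an illustrative name; the node works on (B, H, W, C) float images in [0, 1]):

    import torch
    import torch.nn.functional as F

    def gaussian_kernel(kernel_size, sigma):
        x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size),
                              torch.linspace(-1, 1, kernel_size), indexing="ij")
        d = torch.sqrt(x * x + y * y)
        g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
        return g / g.sum()

    def blur_bhwc(image, blur_radius, sigma):
        # depthwise gaussian blur: reflect-pad so edges are not darkened,
        # convolve each channel separately, then crop the padding back off
        channels = image.shape[-1]
        kernel_size = blur_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
        x = image.permute(0, 3, 1, 2)
        x = F.pad(x, (blur_radius,) * 4, "reflect")
        x = F.conv2d(x, kernel, padding=kernel_size // 2, groups=channels)
        x = x[:, :, blur_radius:-blur_radius, blur_radius:-blur_radius]
        return x.permute(0, 2, 3, 1)
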
--- comfy_extras/nodes_post_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 37c824bde..3be141dfe 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -105,7 +105,7 @@ class Blur: image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] + blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) return (blurred,) From 797c4e8d3b56559bb205ff8aab97d97dca424b9a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 15:07:21 -0400 Subject: [PATCH 057/120] Simplify and improve some vae attention code. --- comfy/ldm/modules/diffusionmodules/model.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5e4d2b60f..05caf7312 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -331,25 +331,13 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), + lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = out.transpose(2, 3).reshape(B, C, H, W) out = self.proj_out(out) return x+out From b8636a44aacd83ec6a9a19a6d3d3f5b76fc863c9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 15:43:39 -0400 Subject: [PATCH 058/120] Make scaled_dot_product switch to sliced attention on OOM. 
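
The idea is to try the fused PyTorch kernel first and only fall back to the manual sliced attention when it runs out of memory. A minimal sketch of the pattern (slice_attention and OOM_EXCEPTION are the helpers from the diff below; shapes assume the single-head VAE case):

import torch

def vae_attention(q, k, v, slice_attention, OOM_EXCEPTION):
    # q, k, v: (B, 1, tokens, channels), as prepared by the attention block
    B, _, T, C = q.shape
    try:
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    except OOM_EXCEPTION:
        print("scaled_dot_product_attention OOMed: switching to sliced attention")
        out = slice_attention(q.reshape(B, -1, C),
                              k.reshape(B, -1, C).transpose(1, 2),
                              v.reshape(B, -1, C).transpose(1, 2))  # (B, channels, tokens)
        return out.reshape(B, 1, C, -1).transpose(2, 3)  # back to (B, 1, tokens, channels)
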
--- comfy/ldm/modules/diffusionmodules/model.py | 79 +++++++++++---------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 05caf7312..91e7d60ec 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -146,6 +146,41 @@ class ResnetBlock(nn.Module): return x+h +def slice_attention(q, k, v): + r1 = torch.zeros_like(k, device=q.device) + scale = (int(q.shape[-1])**(-0.5)) + + mem_free_total = model_management.get_free_memory(q.device) + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + while True: + try: + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = torch.bmm(q[:, i:end], k) * scale + + s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) + del s1 + + r1[:, :, i:end] = torch.bmm(v, s2) + del s2 + break + except model_management.OOM_EXCEPTION as e: + steps *= 2 + if steps > 128: + raise e + print("out of memory error, increasing steps and trying again", steps) + + return r1 class AttnBlock(nn.Module): def __init__(self, in_channels): @@ -183,48 +218,15 @@ class AttnBlock(nn.Module): # compute attention b,c,h,w = q.shape - scale = (int(c)**(-0.5)) q = q.reshape(b,c,h*w) q = q.permute(0,2,1) # b,hw,c k = k.reshape(b,c,h*w) # b,c,hw v = v.reshape(b,c,h*w) - r1 = torch.zeros_like(k, device=q.device) - - mem_free_total = model_management.get_free_memory(q.device) - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - while True: - try: - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = torch.bmm(q[:, i:end], k) * scale - - s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) - del s1 - - r1[:, :, i:end] = torch.bmm(v, s2) - del s2 - break - except model_management.OOM_EXCEPTION as e: - steps *= 2 - if steps > 128: - raise e - print("out of memory error, increasing steps and trying again", steps) - + r1 = slice_attention(q, k, v) h_ = r1.reshape(b,c,h,w) del r1 - h_ = self.proj_out(h_) return x+h_ @@ -335,9 +337,14 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = out.transpose(2, 3).reshape(B, C, H, W) + try: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(2, 3).reshape(B, C, H, W) + except model_management.OOM_EXCEPTION as e: + print("scaled_dot_product_attention OOMed: switched to slice attention") + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + out = self.proj_out(out) return x+out From 3c76f43057f140c583962327c18c7d5257e7495c Mon Sep 17 00:00:00 2001 From: 
comfyanonymous Date: Sat, 20 May 2023 23:06:33 -0400 Subject: [PATCH 059/120] Cleaner code. --- server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 18ce54306..701c0e7a7 100644 --- a/server.py +++ b/server.py @@ -331,7 +331,8 @@ class PromptServer(): extra_data["client_id"] = json_data["client_id"] if valid[0]: prompt_id = str(uuid.uuid4()) - self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) + outputs_to_execute = valid[2] + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) return web.json_response({"prompt_id": prompt_id}) else: print("invalid prompt:", valid[1]) From 516119ad835841bd176055cbc888843b418b8004 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 00:24:28 -0400 Subject: [PATCH 060/120] Print min and max values in validation error message. --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 0e2cc15c1..35f044346 100644 --- a/execution.py +++ b/execution.py @@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value smaller than min. {}, {}".format(class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value bigger than max. {}, {}".format(class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) From 069657fbf3d8d977ead39ab206d8c917bbcc4997 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 01:35:08 -0400 Subject: [PATCH 061/120] Add DPM-Solver++(2M) SDE and exponential scheduler. exponential scheduler is the one recommended with this sampler. 
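
For reference, the exponential schedule just spaces the sigmas evenly in log space between sigma_max and sigma_min; a minimal sketch of roughly what get_sigmas_exponential in comfy/k_diffusion/sampling.py computes:

import math
import torch

def sigmas_exponential(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()
    return torch.cat([sigmas, sigmas.new_zeros([1])])  # trailing 0.0 marks the final denoise step
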
--- comfy/k_diffusion/sampling.py | 43 +++++++++++++++++++++++++++++++++++ comfy/samplers.py | 6 +++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c809d39fb..94d7a5762 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -605,3 +605,46 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d old_denoised = denoised return x + +@torch.no_grad() +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + """DPM-Solver++(2M) SDE.""" + + if solver_type not in {'heun', 'midpoint'}: + raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index fccf254ec..1fb928f8d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", - "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -532,6 +532,8 @@ class KSampler: if self.scheduler == "karras": sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) + elif self.scheduler == "exponential": + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) elif self.scheduler == "normal": sigmas = self.model_wrap.get_sigmas(steps) elif self.scheduler == "simple": From 4796e615dd7faad38429fc8e716e3a817a28c526 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 10:34:26 -0400 Subject: [PATCH 062/120] Revert DPI fix since it caused more issues than it solved. 
--- web/scripts/app.js | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 514ca3958..97b7c8d31 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -921,9 +921,8 @@ export class ComfyApp { this.graph.start(); function resizeCanvas() { - canvasEl.width = canvasEl.offsetWidth * window.devicePixelRatio; - canvasEl.height = canvasEl.offsetHeight * window.devicePixelRatio; - canvasEl.getContext("2d").scale(window.devicePixelRatio, window.devicePixelRatio); + canvasEl.width = canvasEl.offsetWidth; + canvasEl.height = canvasEl.offsetHeight; canvas.draw(true, true); } From dc198650c0d2d281c9d87b23a8917b457a94d837 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 11:34:29 -0400 Subject: [PATCH 063/120] sample_dpmpp_2m_sde no longer crashes when step == 1. --- comfy/k_diffusion/sampling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 94d7a5762..c540d7411 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -628,6 +628,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if sigmas[i + 1] == 0: # Denoising step x = denoised + h = None else: # DPM-Solver++(2M) SDE t, s = -sigmas[i].log(), -sigmas[i + 1].log() From 6cc450579b3314fe314e64af22c4be81afd1f87d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 00:22:24 -0400 Subject: [PATCH 064/120] Auto transpose images from exif data. --- comfy/k_diffusion/sampling.py | 2 +- nodes.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c540d7411..26930428f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -620,6 +620,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl old_denoised = None h_last = None + h = None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -628,7 +629,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if sigmas[i + 1] == 0: # Denoising step x = denoised - h = None else: # DPM-Solver++(2M) SDE t, s = -sigmas[i].log(), -sigmas[i + 1].log() diff --git a/nodes.py b/nodes.py index 878e0b955..bae330bc9 100644 --- a/nodes.py +++ b/nodes.py @@ -8,7 +8,7 @@ import traceback import math import time -from PIL import Image +from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch @@ -1057,6 +1057,7 @@ class LoadImage: def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] @@ -1100,6 +1101,7 @@ class LoadImageMask: def load_image(self, image, channel): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") mask = None From ffc56c53c9cccfcc21c92fe14cb095bb32ea2744 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:22:38 -0400 Subject: [PATCH 065/120] Add a node_errors to the /prompt error json response. "node_errors" contains a dict keyed by node ids. The contents are a message and a list of dependent outputs. 
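
For example, a rejected /prompt request now returns a 400 response shaped roughly like this (the node id, message and output id below are invented for illustration):

{
    "error": "Prompt has no properly connected outputs\n Required input is missing. CheckpointLoaderSimple, ckpt_name",
    "node_errors": {
        "4": {
            "message": "Required input is missing. CheckpointLoaderSimple, ckpt_name",
            "dependent_outputs": ["9"]
        }
    }
}
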
--- execution.py | 27 ++++++++++++++++----------- server.py | 4 ++-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index 35f044346..212e789ca 100644 --- a/execution.py +++ b/execution.py @@ -299,18 +299,18 @@ def validate_inputs(prompt, item, validated): required_inputs = class_inputs['required'] for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x)) + return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x)) + return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) + return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) r = validate_inputs(prompt, o_id, validated) if r[0] == False: validated[o_id] = r @@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) @@ -338,13 +338,13 @@ def validate_inputs(prompt, item, validated): ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") for r in ret: if r != True: - return (False, "{}, {}".format(class_type, r)) + return (False, "{}, {}".format(class_type, r), unique_id) else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) - ret = (True, "") + ret = (True, "", unique_id) validated[unique_id] = ret return ret @@ -356,10 +356,11 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs") + return (False, "Prompt has no outputs", [], []) good_outputs = set() errors = [] + node_errors = {} validated = {} for o in outputs: valid = False @@ -368,6 +369,7 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] + node_id = m[2] except Exception as e: print(traceback.format_exc()) valid = False @@ -379,12 +381,15 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) + return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) - return (True, "", list(good_outputs)) + return (True, "", list(good_outputs), node_errors) class PromptQueue: diff --git a/server.py b/server.py index 701c0e7a7..8429a63fb 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): return web.json_response({"prompt_id": prompt_id}) else: print("invalid prompt:", valid[1]) - return web.json_response({"error": valid[1]}, status=400) + return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: - return web.json_response({"error": "no prompt"}, status=400) + return web.json_response({"error": "no prompt", "node_errors": []}, status=400) @routes.post("/queue") async def post_queue(request): From db27b0405a31983916d6801cf84f7f1fc4503e6a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:25:50 -0400 Subject: [PATCH 066/120] object_info now returns if node is an output_node or not. 
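
With this, each entry returned by /object_info (and /object_info/{node_class}) carries the flag directly; a rough Python-style sketch of one trimmed entry, with illustrative values:

{
    "SaveImage": {
        "input": {"required": {"images": ["IMAGE"]}},
        "output": [],
        "output_is_list": [],
        "output_name": [],
        "name": "SaveImage",
        "display_name": "Save Image",
        "description": "",
        "category": "image",
        "output_node": True
    }
}
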
--- server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server.py b/server.py index 8429a63fb..c0f79cbd5 100644 --- a/server.py +++ b/server.py @@ -272,6 +272,11 @@ class PromptServer(): info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['description'] = '' info['category'] = 'sd' + if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: + info['output_node'] = True + else: + info['output_node'] = False + if hasattr(obj_class, 'CATEGORY'): info['category'] = obj_class.CATEGORY return info From bfb13f5eee48545f1c4b0b8a377de80be84bb100 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 17:05:23 -0400 Subject: [PATCH 067/120] Remove useless call to /object_info --- web/extensions/core/colorPalette.js | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 2f2238a2b..bfcd847a3 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -174,7 +174,7 @@ const els = {} // const ctxMenu = LiteGraph.ContextMenu; app.registerExtension({ name: id, - init() { + addCustomNodeDefs(node_defs) { const sortObjectKeys = (unordered) => { return Object.keys(unordered).sort().reduce((obj, key) => { obj[key] = unordered[key]; @@ -182,10 +182,10 @@ app.registerExtension({ }, {}); }; - const getSlotTypes = async () => { + function getSlotTypes() { var types = []; - const defs = await api.getNodeDefs(); + const defs = node_defs; for (const nodeId in defs) { const nodeData = defs[nodeId]; @@ -212,8 +212,8 @@ app.registerExtension({ return types; }; - const completeColorPalette = async (colorPalette) => { - var types = await getSlotTypes(); + function completeColorPalette(colorPalette) { + var types = getSlotTypes(); for (const type of types) { if (!colorPalette.colors.node_slot[type]) { From 48fcc5b777b3a1ab5d6dc5fec6adaebeb32c2c93 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 20:51:30 -0400 Subject: [PATCH 068/120] Parsing error crash. --- execution.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/execution.py b/execution.py index 212e789ca..25f2fcacd 100644 --- a/execution.py +++ b/execution.py @@ -374,6 +374,7 @@ def validate_prompt(prompt): print(traceback.format_exc()) valid = False reason = "Parsing error" + node_id = None if valid == True: good_outputs.add(o) @@ -381,9 +382,10 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + if node_id is not None: + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) From 34887b888546716b5c5507606289ca2728bf3123 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 03:12:56 -0400 Subject: [PATCH 069/120] Add experimental bislerp algorithm for latent upscaling. It's like bilinear but with slerp. 
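
The per-pixel mix is ordinary slerp between the two neighbouring latent vectors, applied first along x and then along y. A minimal sketch of the underlying operation on two single vectors (the node's batched version follows in the diff below):

import torch

def slerp(a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor:
    a_n = a / a.norm()
    b_n = b / b.norm()
    omega = torch.acos((a_n * b_n).sum().clamp(-1.0, 1.0))  # angle between the two vectors
    so = torch.sin(omega)
    if so.abs() < 1e-6:               # nearly parallel: plain lerp is fine
        return (1.0 - t) * a + t * b
    return (torch.sin((1.0 - t) * omega) / so) * a + (torch.sin(t * omega) / so) * b
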
--- comfy/utils.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++- nodes.py | 2 +- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 09e05d4ed..0f7b34503 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -46,6 +46,65 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +#slow and inefficient, should be optimized +def bislerp(samples, width, height): + shape = list(samples.shape) + width_scale = (shape[3]) / (width ) + height_scale = (shape[2]) / (height ) + + shape[3] = width + shape[2] = height + out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + + def algorithm(in1, w1, in2, w2): + dims = in1.shape + val = w2 + + #flatten to batches + low = in1.reshape(dims[0], -1) + high = in2.reshape(dims[0], -1) + + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + + # in case we divide by zero + low_norm[low_norm != low_norm] = 0.0 + high_norm[high_norm != high_norm] = 0.0 + + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res.reshape(dims) + + for x_dest in range(shape[3]): + for y_dest in range(shape[2]): + y = (y_dest) * height_scale + x = (x_dest) * width_scale + + x1 = max(math.floor(x), 0) + x2 = min(x1 + 1, samples.shape[3] - 1) + y1 = max(math.floor(y), 0) + y2 = min(y1 + 1, samples.shape[2] - 1) + + in1 = samples[:,:,y1,x1] + in2 = samples[:,:,y1,x2] + in3 = samples[:,:,y2,x1] + in4 = samples[:,:,y2,x2] + + if (x1 == x2) and (y1 == y2): + out_value = in1 + elif (x1 == x2): + out_value = algorithm(in1, (y2 - y), in3, (y - y1)) + elif (y1 == y2): + out_value = algorithm(in1, (x2 - x), in2, (x - x1)) + else: + o1 = algorithm(in1, (x2 - x), in2, (x - x1)) + o2 = algorithm(in3, (x2 - x), in4, (x - x1)) + out_value = algorithm(o1, (y2 - y), o2, (y - y1)) + + out1[:,:,y_dest,x_dest] = out_value + return out1 + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -61,7 +120,11 @@ def common_upscale(samples, width, height, upscale_method, crop): s = samples[:,:,y:old_height-y,x:old_width-x] else: s = samples - return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + + if upscale_method == "bislerp": + return bislerp(s, width, height) + else: + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) diff --git a/nodes.py b/nodes.py index bae330bc9..e5cec2632 100644 --- a/nodes.py +++ b/nodes.py @@ -749,7 +749,7 @@ class RepeatLatentBatch: return (s,) class LatentUpscale: - upscale_methods = ["nearest-exact", "bilinear", "area"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] crop_methods = ["disabled", "center"] @classmethod From 451fb4169ad900e5d33b540f039f56ced9a76157 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 11:35:32 -0400 Subject: [PATCH 070/120] Fix 'git pull' not working on the standalones. 
--- .github/workflows/windows_release_cu118_package.yml | 1 + .github/workflows/windows_release_nightly_pytorch.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/windows_release_cu118_package.yml b/.github/workflows/windows_release_cu118_package.yml index 15322c86a..2d6048a23 100644 --- a/.github/workflows/windows_release_cu118_package.yml +++ b/.github/workflows/windows_release_cu118_package.yml @@ -30,6 +30,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - shell: bash run: | cd .. diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index b6a18ec0a..767a7216b 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -17,6 +17,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - uses: actions/setup-python@v4 with: python-version: '3.11.3' From b8ccbec6d893d34dab90d2418a3fe00969251fa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 11:40:24 -0400 Subject: [PATCH 071/120] Various improvements to bislerp. --- comfy/utils.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 0f7b34503..300eda6aa 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -56,35 +56,42 @@ def bislerp(samples, width, height): shape[2] = height out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) - def algorithm(in1, w1, in2, w2): + def algorithm(in1, in2, t): dims = in1.shape - val = w2 + val = t #flatten to batches low = in1.reshape(dims[0], -1) high = in2.reshape(dims[0], -1) - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) + low_weight = torch.norm(low, dim=1, keepdim=True) + low_weight[low_weight == 0] = 0.0000000001 + low_norm = low/low_weight + high_weight = torch.norm(high, dim=1, keepdim=True) + high_weight[high_weight == 0] = 0.0000000001 + high_norm = high/high_weight - # in case we divide by zero - low_norm[low_norm != low_norm] = 0.0 - high_norm[high_norm != high_norm] = 0.0 - - omega = torch.acos((low_norm*high_norm).sum(1)) + dot_prod = (low_norm*high_norm).sum(1) + dot_prod[dot_prod > 0.9995] = 0.9995 + dot_prod[dot_prod < -0.9995] = -0.9995 + omega = torch.acos(dot_prod) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm + res *= (low_weight * (1.0-val) + high_weight * val) return res.reshape(dims) for x_dest in range(shape[3]): for y_dest in range(shape[2]): - y = (y_dest) * height_scale - x = (x_dest) * width_scale + y = (y_dest + 0.5) * height_scale - 0.5 + x = (x_dest + 0.5) * width_scale - 0.5 x1 = max(math.floor(x), 0) x2 = min(x1 + 1, samples.shape[3] - 1) + wx = x - math.floor(x) + y1 = max(math.floor(y), 0) y2 = min(y1 + 1, samples.shape[2] - 1) + wy = y - math.floor(y) in1 = samples[:,:,y1,x1] in2 = samples[:,:,y1,x2] @@ -94,13 +101,13 @@ def bislerp(samples, width, height): if (x1 == x2) and (y1 == y2): out_value = in1 elif (x1 == x2): - out_value = algorithm(in1, (y2 - y), in3, (y - y1)) + out_value = algorithm(in1, in3, wy) elif (y1 == y2): - out_value = algorithm(in1, (x2 - x), in2, (x - x1)) + out_value = algorithm(in1, in2, wx) else: - o1 = 
algorithm(in1, (x2 - x), in2, (x - x1)) - o2 = algorithm(in3, (x2 - x), in4, (x - x1)) - out_value = algorithm(o1, (y2 - y), o2, (y - y1)) + o1 = algorithm(in1, in2, wx) + o2 = algorithm(in3, in4, wx) + out_value = algorithm(o1, o2, wy) out1[:,:,y_dest,x_dest] = out_value return out1 From c00bb1a0b78f0d2cf2e4ec2dd9ae7d61cb07a637 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 12:53:38 -0400 Subject: [PATCH 072/120] Add a latent upscale by node. --- nodes.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nodes.py b/nodes.py index e5cec2632..f0a93ebd5 100644 --- a/nodes.py +++ b/nodes.py @@ -768,6 +768,25 @@ class LatentUpscale: s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) +class LatentUpscaleBy: + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "upscale" + + CATEGORY = "latent" + + def upscale(self, samples, upscale_method, scale_by): + s = samples.copy() + width = round(samples["samples"].shape[3] * scale_by) + height = round(samples["samples"].shape[2] * scale_by) + s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") + return (s,) + class LatentRotate: @classmethod def INPUT_TYPES(s): @@ -1244,6 +1263,7 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, + "LatentUpscaleBy": LatentUpscaleBy, "LatentFromBatch": LatentFromBatch, "RepeatLatentBatch": RepeatLatentBatch, "SaveImage": SaveImage, @@ -1322,6 +1342,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", + "LatentUpscaleBy": "Upscale Latent By", "LatentComposite": "Latent Composite", "LatentFromBatch" : "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", From 7310290f17aad79480edb92f22cd58f0997db964 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 22:26:50 -0400 Subject: [PATCH 073/120] Pull in latest upscale model code from chainner. 
--- .../architecture/OmniSR/ChannelAttention.py | 110 ++++ .../architecture/OmniSR/LICENSE | 201 ++++++ .../architecture/OmniSR/OSA.py | 577 ++++++++++++++++++ .../architecture/OmniSR/OSAG.py | 60 ++ .../architecture/OmniSR/OmniSR.py | 133 ++++ .../architecture/OmniSR/esa.py | 294 +++++++++ .../architecture/OmniSR/layernorm.py | 70 +++ .../architecture/OmniSR/pixelshuffle.py | 31 + .../chainner_models/architecture/RRDB.py | 17 +- .../chainner_models/architecture/block.py | 30 + comfy_extras/chainner_models/model_loading.py | 5 + comfy_extras/chainner_models/types.py | 4 +- 12 files changed, 1530 insertions(+), 2 deletions(-) create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/LICENSE create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OSA.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OSAG.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/esa.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/layernorm.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py diff --git a/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py new file mode 100644 index 000000000..f4d52aa1e --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py @@ -0,0 +1,110 @@ +import math + +import torch.nn as nn + + +class CA_layer(nn.Module): + def __init__(self, channel, reduction=16): + super(CA_layer, self).__init__() + # global average pooling + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False), + nn.GELU(), + nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False), + # nn.Sigmoid() + ) + + def forward(self, x): + y = self.fc(self.gap(x)) + return x * y.expand_as(x) + + +class Simple_CA_layer(nn.Module): + def __init__(self, channel): + super(Simple_CA_layer, self).__init__() + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=channel, + out_channels=channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return x * self.fc(self.gap(x)) + + +class ECA_layer(nn.Module): + """Constructs a ECA module. + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class ECA_MaxPool_layer(nn.Module): + """Constructs a ECA module. 
+ Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_MaxPool_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.max_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/LICENSE b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSA.py b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py new file mode 100644 index 000000000..d7a129696 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSA.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:07:42 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce +from torch import einsum, nn + +from .layernorm import LayerNorm2d + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, length=1): + return val if isinstance(val, tuple) else ((val,) * length) + + +# helper classes + + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class Conv_PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = LayerNorm2d(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Conv2d(dim, inner_dim, 1, 1, 0), + nn.GELU(), + nn.Dropout(dropout), + nn.Conv2d(inner_dim, dim, 1, 1, 0), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Gated_Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=1, bias=False, dropout=0.0): + super().__init__() + + hidden_features = int(dim * mult) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + bias=bias, + ) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +# MBConv + + +class SqueezeExcitation(nn.Module): + def __init__(self, dim, shrinkage_rate=0.25): + super().__init__() + hidden_dim = int(dim * shrinkage_rate) + + self.gate = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(dim, hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim, dim, bias=False), + nn.Sigmoid(), + Rearrange("b c -> b c 1 1"), + ) + + def forward(self, x): + return x * self.gate(x) + + +class MBConvResidual(nn.Module): + def __init__(self, fn, dropout=0.0): + super().__init__() + self.fn = fn + self.dropsample = Dropsample(dropout) + + def forward(self, x): + out = self.fn(x) + out = self.dropsample(out) + return out + x + 
+ +class Dropsample(nn.Module): + def __init__(self, prob=0): + super().__init__() + self.prob = prob + + def forward(self, x): + device = x.device + + if self.prob == 0.0 or (not self.training): + return x + + keep_mask = ( + torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_() + > self.prob + ) + return x * keep_mask / (1 - self.prob) + + +def MBConv( + dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0 +): + hidden_dim = int(expansion_rate * dim_out) + stride = 2 if downsample else 1 + + net = nn.Sequential( + nn.Conv2d(dim_in, hidden_dim, 1), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d( + hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim + ), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), + # nn.BatchNorm2d(dim_out) + ) + + if dim_in == dim_out and not downsample: + net = MBConvResidual(net, dropout=dropout) + + return net + + +# attention related classes +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.Dropout(dropout) + ) + + # relative positional bias + if self.with_pe: + self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) + + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos)) + grid = rearrange(grid, "c i j -> (i j) c") + rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange( + grid, "j ... -> 1 j ..." + ) + rel_pos += window_size - 1 + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum( + dim=-1 + ) + + self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False) + + def forward(self, x): + batch, height, width, window_height, window_width, _, device, h = ( + *x.shape, + x.device, + self.heads, + ) + + # flatten + + x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d") + + # project for queries, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # split heads + + q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v)) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # add positional bias + if self.with_pe: + bias = self.rel_pos_bias(self.rel_pos_indices) + sim = sim + rearrange(bias, "i j h -> h i j") + + # attention + + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + + out = rearrange( + out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width + ) + + # combine heads out + + out = self.to_out(out) + return rearrange(out, "(b x y) ... 
-> b x y ...", x=height, y=width) + + +class Block_Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + bias=False, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.ps = window_size + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + # project for queries, keys, values + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + # split heads + + q, k, v = map( + lambda t: rearrange( + t, + "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d", + h=self.heads, + w1=self.ps, + w2=self.ps, + ), + (q, k, v), + ) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # attention + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + out = rearrange( + out, + "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)", + x=h // self.ps, + y=w // self.ps, + head=self.heads, + w1=self.ps, + w2=self.ps, + ) + + out = self.to_out(out) + return out + + +class Channel_Attention(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class Channel_Attention_grid(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention_grid, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = 
F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class OSA_Block(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + ffn_bias=True, + window_size=8, + with_pe=False, + dropout=0.0, + ): + super(OSA_Block, self).__init__() + + w = window_size + + self.layer = nn.Sequential( + MBConv( + channel_num, + channel_num, + downsample=False, + expansion_rate=1, + shrinkage_rate=0.25, + ), + Rearrange( + "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w + ), # block-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + Rearrange( + "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w + ), # grid-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention_grid( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + ) + + def forward(self, x): + out = self.layer(x) + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py new file mode 100644 index 000000000..477e81f9d --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSAG.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:08:49 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + + +import torch.nn as nn + +from .esa import ESA +from .OSA import OSA_Block + + +class OSAG(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + block_num=4, + ffn_bias=False, + window_size=0, + pe=False, + ): + super(OSAG, self).__init__() + + # print("window_size: %d" % (window_size)) + # print("with_pe", pe) + # print("ffn_bias: %d" % (ffn_bias)) + + # block_script_name = kwargs.get("block_script_name", "OSA") + # block_class_name = kwargs.get("block_class_name", "OSA_Block") + + # script_name = "." 
+ block_script_name + # package = __import__(script_name, fromlist=True) + block_class = OSA_Block # getattr(package, block_class_name) + group_list = [] + for _ in range(block_num): + temp_res = block_class( + channel_num, + bias, + ffn_bias=ffn_bias, + window_size=window_size, + with_pe=pe, + ) + group_list.append(temp_res) + group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias)) + self.residual_layer = nn.Sequential(*group_list) + esa_channel = max(channel_num // 4, 16) + self.esa = ESA(esa_channel, channel_num) + + def forward(self, x): + out = self.residual_layer(x) + out = out + x + return self.esa(out) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py new file mode 100644 index 000000000..dec169520 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OmniSR.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:06:36 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .OSAG import OSAG +from .pixelshuffle import pixelshuffle_block + + +class OmniSR(nn.Module): + def __init__( + self, + state_dict, + **kwargs, + ): + super(OmniSR, self).__init__() + self.state = state_dict + + bias = True # Fine to assume this for now + block_num = 1 # Fine to assume this for now + ffn_bias = True + pe = True + + num_feat = state_dict["input.weight"].shape[0] or 64 + num_in_ch = state_dict["input.weight"].shape[1] or 3 + num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh + + pixelshuffle_shape = state_dict["up.0.weight"].shape[0] + up_scale = math.sqrt(pixelshuffle_shape / num_out_ch) + if up_scale - int(up_scale) > 0: + print( + "out_nc is probably different than in_nc, scale calculation might be wrong" + ) + up_scale = int(up_scale) + res_num = 0 + for key in state_dict.keys(): + if "residual_layer" in key: + temp_res_num = int(key.split(".")[1]) + if temp_res_num > res_num: + res_num = temp_res_num + res_num = res_num + 1 # zero-indexed + + residual_layer = [] + self.res_num = res_num + + self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer) + self.up_scale = up_scale + + for _ in range(res_num): + temp_res = OSAG( + channel_num=num_feat, + bias=bias, + block_num=block_num, + ffn_bias=ffn_bias, + window_size=self.window_size, + pe=pe, + ) + residual_layer.append(temp_res) + self.residual_layer = nn.Sequential(*residual_layer) + self.input = nn.Conv2d( + in_channels=num_in_ch, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.output = nn.Conv2d( + in_channels=num_feat, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias) + + # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, sqrt(2. 
/ n)) + + # chaiNNer specific stuff + self.model_arch = "OmniSR" + self.sub_type = "SR" + self.in_nc = num_in_ch + self.out_nc = num_out_ch + self.num_feat = num_feat + self.scale = up_scale + + self.supports_fp16 = True # TODO: Test this + self.supports_bfp16 = True + self.min_size_restriction = 16 + + self.load_state_dict(state_dict, strict=False) + + def check_image_size(self, x): + _, _, h, w = x.size() + # import pdb; pdb.set_trace() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0) + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + residual = self.input(x) + out = self.residual_layer(residual) + + # origin + out = torch.add(self.output(out), residual) + out = self.up(out) + + out = out[:, :, : H * self.up_scale, : W * self.up_scale] + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/esa.py b/comfy_extras/chainner_models/architecture/OmniSR/esa.py new file mode 100644 index 000000000..f9ce7f7a6 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/esa.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: esa.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:06 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .layernorm import LayerNorm2d + + +def moment(x, dim=(2, 3), k=2): + assert len(x.size()) == 4 + mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1) + mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim) + return mk + + +class ESA(nn.Module): + """ + Modification of Enhanced Spatial Attention (ESA), which is proposed by + `Residual Feature Aggregation Network for Image Super-Resolution` + Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes + are deleted. 
+ """ + + def __init__(self, esa_channels, n_feats, conv=nn.Conv2d): + super(ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0) + self.conv3 = conv(f, f, kernel_size=3, padding=1) + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + c1 = self.conv2(c1_) + v_max = F.max_pool2d(c1, kernel_size=7, stride=3) + c3 = self.conv3(v_max) + c3 = F.interpolate( + c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False + ) + cf = self.conv_f(c1_) + c4 = self.conv4(c3 + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA_LN(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA_LN, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.norm = LayerNorm2d(n_feats) + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.norm(x) + c1_ = self.conv1(c1_) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + 
self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class AdaGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaGuidedFilter, self).__init__() + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=n_feats, + out_channels=1, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + self.r = 5 + + def box_filter(self, x, r): + channel = x.shape[1] + kernel_size = 2 * r + 1 + weight = 1.0 / (kernel_size**2) + box_kernel = weight * torch.ones( + (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device + ) + output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel) + return output + + def forward(self, x): + _, _, H, W = x.shape + N = self.box_filter( + torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r + ) + + # epsilon = self.fc(self.gap(x)) + # epsilon = torch.pow(epsilon, 2) + epsilon = 1e-2 + + mean_x = self.box_filter(x, self.r) / N + var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x + + A = var_x / (var_x + epsilon) + b = (1 - A) * mean_x + m = A * x + b + + # mean_A = self.box_filter(A, self.r) / N + # mean_b = self.box_filter(b, self.r) / N + # m = mean_A * x + mean_b + return x * m + + +class AdaConvGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaConvGuidedFilter, self).__init__() + f = esa_channels + + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=f, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=f, + bias=bias, + ) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + y = self.vec_conv(x) + y = self.hor_conv(y) + + sigma = torch.pow(y, 2) + epsilon = self.fc(self.gap(y)) + + weight = sigma / (sigma + epsilon) + + m = weight * x + (1 - weight) + + return x * m diff --git a/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py new file mode 100644 index 000000000..731a25f75 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: layernorm.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:20 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn + + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, 
grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x diff --git a/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py new file mode 100644 index 000000000..4260fb7c9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: pixelshuffle.py +# Created Date: Friday July 1st 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 1st July 2022 10:18:39 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import torch.nn as nn + + +def pixelshuffle_block( + in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False +): + """ + Upsample features according to `upscale_factor`. 
+ """ + padding = kernel_size // 2 + conv = nn.Conv2d( + in_channels, + out_channels * (upscale_factor**2), + kernel_size, + padding=1, + bias=bias, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + return nn.Sequential(*[conv, pixel_shuffle]) diff --git a/comfy_extras/chainner_models/architecture/RRDB.py b/comfy_extras/chainner_models/architecture/RRDB.py index 4d52f05dd..b50db7c24 100644 --- a/comfy_extras/chainner_models/architecture/RRDB.py +++ b/comfy_extras/chainner_models/architecture/RRDB.py @@ -79,6 +79,12 @@ class RRDBNet(nn.Module): self.scale: int = self.get_scale() self.num_filters: int = self.state[self.key_arr[0]].shape[0] + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + self.supports_fp16 = True self.supports_bfp16 = True self.min_size_restriction = None @@ -105,11 +111,15 @@ class RRDBNet(nn.Module): out_nc=self.num_filters, upscale_factor=3, act_type=self.act, + c2x2=c2x2, ) else: upsample_blocks = [ upsample_block( - in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, ) for _ in range(int(math.log(self.scale, 2))) ] @@ -122,6 +132,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), B.ShortcutBlock( B.sequential( @@ -138,6 +149,7 @@ class RRDBNet(nn.Module): act_type=self.act, mode="CNA", plus=self.plus, + c2x2=c2x2, ) for _ in range(self.num_blocks) ], @@ -149,6 +161,7 @@ class RRDBNet(nn.Module): norm_type=self.norm, act_type=None, mode=self.mode, + c2x2=c2x2, ), ) ), @@ -160,6 +173,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=self.act, + c2x2=c2x2, ), # hr_conv1 B.conv_block( @@ -168,6 +182,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), ) diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 214642cc4..d7bc5d227 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -141,6 +141,19 @@ def sequential(*args): ConvMode = Literal["CNA", "NAC", "CNAC"] +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + def conv_block( in_nc: int, out_nc: int, @@ -153,12 +166,17 @@ def conv_block( norm_type: str | None = None, act_type: str | None = "relu", mode: ConvMode = "CNA", + c2x2=False, ): """ Conv layer with padding, normalization, activation mode: CNA --> Conv -> Norm -> Act NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) padding = get_valid_padding(kernel_size, dilation) p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None @@ -285,6 +303,7 @@ class RRDB(nn.Module): _convtype="Conv2D", _spectral_norm=False, plus=False, + c2x2=False, ): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C( @@ -298,6 +317,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) self.RDB2 = ResidualDenseBlock_5C( nf, @@ -310,6 +330,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) 
self.RDB3 = ResidualDenseBlock_5C( nf, @@ -322,6 +343,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) def forward(self, x): @@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module): act_type="leakyrelu", mode: ConvMode = "CNA", plus=False, + c2x2=False, ): super(ResidualDenseBlock_5C, self).__init__() @@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv2 = conv_block( nf + gc, @@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv3 = conv_block( nf + 2 * gc, @@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv4 = conv_block( nf + 3 * gc, @@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) if mode == "CNA": last_act = None @@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=last_act, mode=mode, + c2x2=c2x2, ) def forward(self, x): @@ -499,6 +527,7 @@ def upconv_block( norm_type: str | None = None, act_type="relu", mode="nearest", + c2x2=False, ): # Up conv # described in https://distill.pub/2016/deconv-checkerboard/ @@ -512,5 +541,6 @@ def upconv_block( pad_type=pad_type, norm_type=norm_type, act_type=act_type, + c2x2=c2x2, ) return sequential(upsample, conv) diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py index 8234ac5d1..2e66e6247 100644 --- a/comfy_extras/chainner_models/model_loading.py +++ b/comfy_extras/chainner_models/model_loading.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel: state_dict = state_dict["params"] state_dict_keys = list(state_dict.keys()) + # SRVGGNet Real-ESRGAN (v2) if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys: model = RealESRGANv2(state_dict) @@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel: # MAT elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys: model = MAT(state_dict) + # Omni-SR + elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys: + model = OmniSR(state_dict) # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1 else: try: diff --git a/comfy_extras/chainner_models/types.py b/comfy_extras/chainner_models/types.py index 8e2bef47a..1906c0c7f 100644 --- a/comfy_extras/chainner_models/types.py +++ b/comfy_extras/chainner_models/types.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN from .architecture.Swin2SR import Swin2SR from .architecture.SwinIR import SwinIR -PyTorchSRModels = (RealESRGANv2, SPSR, 
SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT) +PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR) PyTorchSRModel = Union[ RealESRGANv2, SPSR, @@ -22,6 +23,7 @@ PyTorchSRModel = Union[ SwinIR, Swin2SR, HAT, + OmniSR, ] From 9b1396e93a19748dd4c4bb35637638bb0f91b5f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 24 May 2023 14:01:11 -0400 Subject: [PATCH 074/120] Fix issue importing other ui prompts. --- web/scripts/pnginfo.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 8ddb7a1c5..977b5ac2f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) { const embeddings = await api.getEmbeddings(); const opts = parameters .substr(p) + .split("\n")[1] .split(",") .reduce((p, n) => { const s = n.split(":"); From 8b4b0c3188110e1faa8865570637172ab4b60ba1 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Thu, 25 May 2023 19:23:47 +0200 Subject: [PATCH 075/120] vecorized bislerp --- comfy/utils.py | 117 +++++++++++++++++++++++++++---------------------- 1 file changed, 64 insertions(+), 53 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 300eda6aa..cc0e5069a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -46,71 +47,81 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd -#slow and inefficient, should be optimized def bislerp(samples, width, height): - shape = list(samples.shape) - width_scale = (shape[3]) / (width ) - height_scale = (shape[2]) / (height ) + def slerp(b1, b2, r): + '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' + + c = b1.shape[-1] - shape[3] = width - shape[2] = height - out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + #norms + b1_norms = torch.norm(b1, dim=-1, keepdim=True) + b2_norms = torch.norm(b2, dim=-1, keepdim=True) - def algorithm(in1, in2, t): - dims = in1.shape - val = t + #normalize + b1_normalized = b1 / b1_norms + b2_normalized = b2 / b2_norms - #flatten to batches - low = in1.reshape(dims[0], -1) - high = in2.reshape(dims[0], -1) + #zero when norms are zero + b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 + b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 - low_weight = torch.norm(low, dim=1, keepdim=True) - low_weight[low_weight == 0] = 0.0000000001 - low_norm = low/low_weight - high_weight = torch.norm(high, dim=1, keepdim=True) - high_weight[high_weight == 0] = 0.0000000001 - high_norm = high/high_weight - - dot_prod = (low_norm*high_norm).sum(1) - dot_prod[dot_prod > 0.9995] = 0.9995 - dot_prod[dot_prod < -0.9995] = -0.9995 - omega = torch.acos(dot_prod) + #slerp + dot = (b1_normalized*b2_normalized).sum(1) + omega = torch.acos(dot) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm - res *= (low_weight * (1.0-val) + high_weight * val) - return res.reshape(dims) - for x_dest in range(shape[3]): - for y_dest in range(shape[2]): - y = (y_dest + 0.5) * height_scale - 0.5 - x = (x_dest + 0.5) * width_scale - 0.5 + #technically not mathematically correct, but more pleasing? 
+ res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) - x1 = max(math.floor(x), 0) - x2 = min(x1 + 1, samples.shape[3] - 1) - wx = x - math.floor(x) + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + def generate_bilinear_data(length_old, length_new): + coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") + ratios = coords_1 - coords_1.floor() + coords_1 = coords_1.to(torch.int64) + + coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2[:,:,:,-1] -= 1 + coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") + coords_2 = coords_2.to(torch.int64) + return ratios, coords_1, coords_2 + + n,c,h,w = samples.shape + h_new, w_new = (height, width) + + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) - y1 = max(math.floor(y), 0) - y2 = min(y1 + 1, samples.shape[2] - 1) - wy = y - math.floor(y) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - in1 = samples[:,:,y1,x1] - in2 = samples[:,:,y1,x2] - in3 = samples[:,:,y2,x1] - in4 = samples[:,:,y2,x2] + pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') - if (x1 == x2) and (y1 == y2): - out_value = in1 - elif (x1 == x2): - out_value = algorithm(in1, in3, wy) - elif (y1 == y2): - out_value = algorithm(in1, in2, wx) - else: - o1 = algorithm(in1, in2, wx) - o2 = algorithm(in3, in4, wx) - out_value = algorithm(o1, o2, wy) + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) - out1[:,:,y_dest,x_dest] = out_value - return out1 + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + + coords_1 = coords_1.expand((n, c, h_new, -1)) + coords_2 = coords_2.expand((n, c, h_new, -1)) + ratios = ratios.expand((n, 1, h_new, -1)) + + pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + return result def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": From e1278fa925cf59350bae76dc3d0c59a0e9564789 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 13:30:59 -0400 Subject: [PATCH 076/120] Support old pytorch versions that don't have weights_only. 
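For context, a minimal sketch of the signature probe this change performs (the wrapper name load_checkpoint_compat is illustrative and not part of the codebase): torch.load only gained the weights_only keyword in newer PyTorch releases, so the loader inspects the function's argument names before passing the flag.

    import torch

    def load_checkpoint_compat(path):
        # Probe torch.load's signature: passing weights_only blindly raises a
        # TypeError on PyTorch versions that predate the keyword.
        if 'weights_only' in torch.load.__code__.co_varnames:
            return torch.load(path, map_location="cpu", weights_only=True)
        # Older PyTorch: fall back to an unrestricted (unsafe) load.
        return torch.load(path, map_location="cpu")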
--- comfy/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 300eda6aa..d58320b4a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -6,6 +6,10 @@ def load_torch_file(ckpt, safe_load=False): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: + if safe_load: + if not 'weights_only' in torch.load.__code__.co_varnames: + print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) else: From 87ab25fac77ff1d558fea3c02733a463cb1fa013 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 18:31:27 -0400 Subject: [PATCH 077/120] Do operations in same order as the one it replaces. --- comfy/utils.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 33c1c3dd7..f139fbb27 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -98,29 +98,27 @@ def bislerp(samples, width, height): n,c,h,w = samples.shape h_new, w_new = (height, width) - #linear h - ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + coords_1 = coords_1.expand((n, c, h, -1)) + coords_2 = coords_2.expand((n, c, h, -1)) + ratios = ratios.expand((n, 1, h, -1)) - coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) - coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) - ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - - pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + pass_1 = einops.rearrange(samples.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-1,coords_2), 'n c h w -> (n h w) c') ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h, w=w_new) - #linear w - ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - coords_1 = coords_1.expand((n, c, h_new, -1)) - coords_2 = coords_2.expand((n, c, h_new, -1)) - ratios = ratios.expand((n, 1, h_new, -1)) - - pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + pass_1 = einops.rearrange(result.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-2,coords_2), 'n c h w -> (n h w) c') ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') result = slerp(pass_1, pass_2, ratios) From eb4bd7711acec9a2a2d4f1d4dcc1d32e1236c976 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 18:42:56 -0400 Subject: [PATCH 078/120] Remove einops. 
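For reference, a small self-contained sketch (shapes invented) of the einops-to-plain-torch equivalence this change relies on when replacing rearrange with movedim/reshape:

    import torch

    n, c, h, w = 2, 4, 8, 8
    t = torch.randn(n, c, h, w)

    # einops.rearrange(t, 'n c h w -> (n h w) c') expressed with builtin ops:
    flat = t.movedim(1, -1).reshape(-1, c)              # shape (n*h*w, c)

    # inverse, einops.rearrange(flat, '(n h w) c -> n c h w', n=n, h=h, w=w):
    restored = flat.reshape(n, h, w, c).movedim(-1, 1)  # shape (n, c, h, w)

    assert torch.equal(t, restored)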
--- comfy/utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index f139fbb27..5ed9aaa02 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,6 +1,5 @@ import torch import math -import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -104,12 +103,12 @@ def bislerp(samples, width, height): coords_2 = coords_2.expand((n, c, h, -1)) ratios = ratios.expand((n, 1, h, -1)) - pass_1 = einops.rearrange(samples.gather(-1,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(samples.gather(-1,coords_2), 'n c h w -> (n h w) c') - ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h, w=w_new) + result = result.reshape(n, h, w_new, c).movedim(-1, 1) #linear h ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) @@ -117,12 +116,12 @@ def bislerp(samples, width, height): coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - pass_1 = einops.rearrange(result.gather(-2,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(result.gather(-2,coords_2), 'n c h w -> (n h w) c') - ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result def common_upscale(samples, width, height, upscale_method, crop): From 4d1ed829d9a934d9a303a725e325f90934854ac8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 26 May 2023 19:33:30 -0500 Subject: [PATCH 079/120] Don't load some model types if weight is zero --- nodes.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/nodes.py b/nodes.py index f0a93ebd5..68010f040 100644 --- a/nodes.py +++ b/nodes.py @@ -426,6 +426,9 @@ class LoraLoader: CATEGORY = "loaders" def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + lora_path = folder_paths.get_full_path("loras", lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) @@ -507,6 +510,9 @@ class ControlNetApply: CATEGORY = "conditioning" def apply_controlnet(self, conditioning, control_net, image, strength): + if strength == 0: + return (conditioning, ) + c = [] control_hint = image.movedim(-1,1) for t in conditioning: @@ -613,6 +619,9 @@ class unCLIPConditioning: CATEGORY = "conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): + if strength == 0: + return (conditioning, ) + c = [] for t in conditioning: o = t[1].copy() From 679bd2845af8e22b2802cf326b99b40a26ba7811 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 26 May 2023 21:46:11 -0400 Subject: [PATCH 080/120] Safetensors isn't optional anymore. 
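As a quick illustration (filenames made up, and the helper filter_supported is hypothetical, not part of the codebase) of how the now-unconditional extension sets behave when scanning model folders:

    import os

    supported_pt_extensions = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors'}

    def filter_supported(filenames, extensions=supported_pt_extensions):
        # Keep only files whose suffix is in the supported set.
        return [f for f in filenames if os.path.splitext(f)[1].lower() in extensions]

    print(filter_supported(["model.safetensors", "notes.txt", "vae.pt"]))
    # ['model.safetensors', 'vae.pt']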
--- folder_paths.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 28f117824..20b461c94 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,7 @@ import os -supported_ckpt_extensions = set(['.ckpt', '.pth']) -supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) -try: - import safetensors.torch - supported_ckpt_extensions.add('.safetensors') - supported_pt_extensions.add('.safetensors') -except: - print("Could not import safetensors, safetensors support disabled.") - +supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) +supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} From 73e85fb3f4b104053fb1ac5d0aea456e373ea8c8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:00:47 -0500 Subject: [PATCH 081/120] Improve error output for failed nodes --- execution.py | 237 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 204 insertions(+), 33 deletions(-) diff --git a/execution.py b/execution.py index 25f2fcacd..691beb102 100644 --- a/execution.py +++ b/execution.py @@ -297,24 +297,80 @@ def validate_inputs(prompt, item, validated): class_inputs = obj_class.INPUT_TYPES() required_inputs = class_inputs['required'] + + errors = [] + valid = True + for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } + } + errors.append(error) + continue + val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) + error = { + "type": "bad_linked_input", + "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val + } + } + errors.append(error) + continue + o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. 
{}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) - r = validate_inputs(prompt, o_id, validated) - if r[0] == False: - validated[o_id] = r - return r + received_type = r[val[1]] + details = f"{x}, {received_type} != {type_input}" + error = { + "type": "return_type_mismatch", + "message": "Return type mismatch between linked nodes", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_type": received_type + } + } + errors.append(error) + continue + try: + r = validate_inputs(prompt, o_id, validated) + if r[0] is False: + # `r` will be set in `validated[o_id]` already + valid = False + continue + except Exception as ex: + typ, _, tb = sys.exc_info() + valid = False + error_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "error_type": error_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o_id] = (False, reasons, o_id) + continue else: if type_input == "INT": val = int(val) @@ -328,26 +384,97 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for r in ret: - if r != True: - return (False, "{}, {}".format(class_type, r), unique_id) + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f": {str(r)}" + else: + details += "." + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) + input_config = info + list_info = "" + + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) + + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": val, + } + } + errors.append(error) + continue + + if len(errors) > 0 or valid is not True: + ret = (False, errors, unique_id) + else: + ret = (True, [], unique_id) - ret = (True, "", unique_id) validated[unique_id] = ret return ret +def full_type_name(klass): + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ + return module + '.' + klass.__qualname__ + def validate_prompt(prompt): outputs = set() for x in prompt: @@ -356,7 +483,13 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs", [], []) + error = { + "type": "prompt_no_outputs", + "message": "Prompt has no outputs", + "details": "", + "extra_info": {} + } + return (False, error, [], []) good_outputs = set() errors = [] @@ -364,34 +497,72 @@ def validate_prompt(prompt): validated = {} for o in outputs: valid = False - reason = "" + reasons = [] try: m = validate_inputs(prompt, o, validated) valid = m[0] - reason = m[1] - node_id = m[2] - except Exception as e: - print(traceback.format_exc()) + reasons = m[1] + except Exception as ex: + typ, _, tb = sys.exc_info() valid = False - reason = "Parsing error" - node_id = None + error_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "error_type": error_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o] = (False, reasons, o) - if valid == True: + if valid is True: good_outputs.add(o) else: - print("Failed to validate prompt for output {} {}".format(o, reason)) - print("output will be ignored") - errors += [(o, reason)] - if node_id is not None: - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + print(f"Failed to validate prompt for output {o}:") + if len(reasons) > 0: + print("* (prompt):") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + errors += [(o, reasons)] + for node_id, result in validated.items(): + valid = result[0] + reasons = result[1] + # If a node upstream has errors, the nodes downstream will also + # be reported as invalid, but there will be no errors attached. + # So don't return those nodes as having errors in the response. 
+ if valid is not True and len(reasons) > 0: + if node_id not in node_errors: + class_type = prompt[node_id]['class_type'] + node_errors[node_id] = { + "errors": reasons, + "dependent_outputs": [], + "class_type": class_type + } + print(f"* {class_type} {node_id}:") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + node_errors[node_id]["dependent_outputs"].append(o) + print("Output will be ignored") if len(good_outputs) == 0: - errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) + errors_list = [] + for o, errors in errors: + for error in errors: + errors_list.append(f"{error['message']}: {error['details']}") + errors_list = "\n".join(errors_list) - return (True, "", list(good_outputs), node_errors) + error = { + "type": "prompt_no_good_outputs", + "message": "Prompt has no properly connected outputs", + "details": errors_list, + "extra_info": {} + } + + return (False, error, list(good_outputs), node_errors) + + return (True, None, list(good_outputs), node_errors) class PromptQueue: From cc4d3435d3590288e21f3adfd42f044a7e45fae4 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:48:55 -0500 Subject: [PATCH 082/120] Highlight failing nodes/inputs in frontend --- web/scripts/app.js | 74 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 7 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 97b7c8d31..21fe94802 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -771,16 +771,25 @@ export class ComfyApp { LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) { const res = origDrawNodeShape.apply(this, arguments); + const nodeErrors = self.lastPromptError?.node_errors[node.id]; + let color = null; + let lineWidth = 1; if (node.id === +self.runningNodeId) { color = "#0f0"; } else if (self.dragOverNode && node.id === self.dragOverNode.id) { color = "dodgerblue"; } + else if (self.lastPromptError != null && nodeErrors?.errors) { + color = "red"; + lineWidth = 2; + } + + self.graphTime = Date.now() if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; - ctx.lineWidth = 1; + ctx.lineWidth = lineWidth; ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) @@ -807,11 +816,28 @@ export class ComfyApp { ctx.stroke(); ctx.strokeStyle = fgcolor; ctx.globalAlpha = 1; + } - if (self.progress) { - ctx.fillStyle = "green"; - ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); - ctx.fillStyle = bgcolor; + if (self.progress && node.id === +self.runningNodeId) { + ctx.fillStyle = "green"; + ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); + ctx.fillStyle = bgcolor; + } + + // Highlight inputs that failed validation + if (nodeErrors) { + ctx.lineWidth = 2; + ctx.strokeStyle = "red"; + for (const error of nodeErrors.errors) { + if (error.extra_info && error.extra_info.input_name) { + const inputIndex = node.findInputSlot(error.extra_info.input_name) + if (inputIndex !== -1) { + let pos = node.getConnectionPos(true, inputIndex); + ctx.beginPath(); + ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false) + ctx.stroke(); + } + } } } @@ -1243,6 +1269,31 @@ export class ComfyApp { return { workflow, output }; } + #formatError(error) { + if (error == 
null) { + return "(unknown error)" + } + else if (typeof error === "string") { + return error; + } + else if (error.stack && error.message) { + return error.toString() + } + else if (error.response) { + let message = error.response.error.message; + if (error.response.error.details) + message += ": " + error.response.error.details; + for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) { + message += "\n" + nodeError.class_type + ":" + for (const errorReason of nodeError.errors) { + message += "\n - " + errorReason.message + ": " + errorReason.details + } + } + return message + } + return "(unknown error)" + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1250,8 +1301,10 @@ export class ComfyApp { if (this.#processingQueue) { return; } - + this.#processingQueue = true; + this.lastPromptError = null; + try { while (this.#queueItems.length) { ({ number, batchCount } = this.#queueItems.pop()); @@ -1262,7 +1315,12 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response.error || error.toString()); + const formattedError = this.#formatError(error) + this.ui.dialog.show(formattedError); + if (error.response) { + this.lastPromptError = error.response; + this.canvas.draw(true, true); + } break; } @@ -1360,6 +1418,8 @@ export class ComfyApp { */ clean() { this.nodeOutputs = {}; + this.lastPromptError = null; + this.graphTime = null } } From c33b7c5549b7b277011e2c3f50215ba466afb205 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:54:13 -0500 Subject: [PATCH 083/120] Improve invalid prompt error message --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 691beb102..66753ff90 100644 --- a/execution.py +++ b/execution.py @@ -554,8 +554,8 @@ def validate_prompt(prompt): errors_list = "\n".join(errors_list) error = { - "type": "prompt_no_good_outputs", - "message": "Prompt has no properly connected outputs", + "type": "prompt_outputs_failed_validation", + "message": "Prompt outputs failed validation", "details": errors_list, "extra_info": {} } From 0d834e3a2ba6272b8cee6503f574c0f06002ddc3 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:59:30 -0500 Subject: [PATCH 084/120] Add missing input name/config --- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 66753ff90..632aaa843 100644 --- a/execution.py +++ b/execution.py @@ -365,6 +365,8 @@ def validate_inputs(prompt, item, validated): "message": "Exception when validating node", "details": str(ex), "extra_info": { + "input_name": x, + "input_config": info, "error_type": error_type, "traceback": traceback.format_tb(tb) } From ffec815257ddf2371b880eafd575838210fcea07 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 12:48:06 -0500 Subject: [PATCH 085/120] Send back more information about exceptions that happen during execution --- execution.py | 173 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 50 deletions(-) diff --git a/execution.py b/execution.py index 632aaa843..5ed9ff348 100644 --- a/execution.py +++ b/execution.py @@ -102,13 +102,19 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui +def 
format_value(x): + if isinstance(x, (int, float, bool, str)): + return x + else: + return str(x) + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return + return (True, None, None) for x in inputs: input_data = inputs[x] @@ -117,22 +123,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + if result[0] is not True: + # Another node failed further upstream + return result - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) - obj = class_def() - - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + input_data_all = None + try: + input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.last_node_id = unique_id + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + obj = class_def() + + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui + if server.client_id is not None: + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + except comfy.model_management.InterruptProcessingException as iex: + print("Processing interrupted") + + # skip formatting inputs/outputs + error_details = { + "node_id": unique_id, + } + + return (False, error_details, iex) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputs in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputs] + + output_data_formatted = {} + for node_id, node_outputs in outputs.items(): + output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + + print("!!! 
Exception during processing !!!") + print(traceback.format_exc()) + + error_details = { + "node_id": unique_id, + "message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "current_inputs": input_data_formatted, + "current_outputs": output_data_formatted + } + return (False, error_details, ex) + executed.add(unique_id) + return (True, None, None) + def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -210,6 +258,44 @@ class PromptExecutor: self.old_prompt = {} self.server = server + def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + # First, send back the status to the frontend depending + # on the exception type + if isinstance(ex, comfy.model_management.InterruptProcessingException): + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "node_id": error["node_id"], + } + self.server.send_sync("execution_interrupted", mes, self.server.client_id) + else: + if self.server.client_id is not None: + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "message": error["message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "node_id": error["node_id"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.server.send_sync("execution_error", mes, self.server.client_id) + + # Next, remove the subsequent outputs since they will not be executed + to_delete = [] + for o in self.outputs: + if (o not in current_outputs) and (o not in executed): + to_delete += [o] + if o in self.old_prompt: + d = self.old_prompt.pop(o) + del d + for o in to_delete: + d = self.outputs.pop(o) + del d + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -244,42 +330,29 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() - try: - to_execute = [] - for x in list(execute_outputs): - to_execute += [(0, x)] + output_node_id = None + to_execute = [] - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - x = to_execute.pop(0)[-1] + for node_id in list(execute_outputs): + to_execute += [(0, node_id)] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) - except Exception as e: - if isinstance(e, comfy.model_management.InterruptProcessingException): - print("Processing interrupted") - else: - message = str(traceback.format_exc()) - print(message) - if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) + while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + output_node_id = to_execute.pop(0)[-1] - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - finally: - for x in 
executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) - self.server.last_node_id = None - if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + # This call shouldn't raise anything if there's an error deep in + # the actual SD code, instead it will report the node where the + # error was raised + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) + if success is not True: + self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + + for x in executed: + self.old_prompt[x] = copy.deepcopy(prompt[x]) + self.server.last_node_id = None + if self.server.client_id is not None: + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() @@ -359,7 +432,7 @@ def validate_inputs(prompt, item, validated): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", @@ -367,7 +440,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] @@ -507,13 +580,13 @@ def validate_prompt(prompt): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", "details": str(ex), "extra_info": { - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] From 6b2a8a3845972bcff02184aaa8ded6eace8300ad Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:03:41 -0500 Subject: [PATCH 086/120] Show message in the frontend if prompt execution raises an exception --- execution.py | 14 +++++++++----- web/scripts/api.js | 6 ++++++ web/scripts/app.js | 35 ++++++++++++++++++++++++++++++----- 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/execution.py b/execution.py index 5ed9ff348..f79c3d351 100644 --- a/execution.py +++ b/execution.py @@ -258,27 +258,31 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): + node_id = error["node_id"] + class_type = prompt[node_id]["class_type"] + # First, send back the status to the frontend depending # on the exception type if isinstance(ex, comfy.model_management.InterruptProcessingException): mes = { "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, "executed": list(executed), - - "node_id": error["node_id"], } self.server.send_sync("execution_interrupted", mes, self.server.client_id) else: if self.server.client_id is not None: mes = { "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, "executed": list(executed), "message": error["message"], "exception_type": error["exception_type"], "traceback": error["traceback"], - "node_id": error["node_id"], "current_inputs": error["current_inputs"], 
"current_outputs": error["current_outputs"], } @@ -346,7 +350,7 @@ class PromptExecutor: # error was raised success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) if success is not True: - self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) diff --git a/web/scripts/api.js b/web/scripts/api.js index 4f061c358..378165b3a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -88,6 +88,12 @@ class ComfyApi extends EventTarget { case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); break; + case "execution_start": + this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data })); + break; + case "execution_error": + this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + break; default: if (this.#registered.has(msg.type)) { this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index 21fe94802..e8ab32cf9 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -784,8 +784,10 @@ export class ComfyApp { color = "red"; lineWidth = 2; } - - self.graphTime = Date.now() + else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) { + color = "#f0f"; + lineWidth = 2; + } if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; @@ -895,6 +897,17 @@ export class ComfyApp { } }); + api.addEventListener("execution_start", ({ detail }) => { + this.lastExecutionError = null + }); + + api.addEventListener("execution_error", ({ detail }) => { + this.lastExecutionError = detail; + const formattedError = this.#formatExecutionError(detail); + this.ui.dialog.show(formattedError); + this.canvas.draw(true, true); + }); + api.init(); } @@ -1269,7 +1282,7 @@ export class ComfyApp { return { workflow, output }; } - #formatError(error) { + #formatPromptError(error) { if (error == null) { return "(unknown error)" } @@ -1294,6 +1307,18 @@ export class ComfyApp { return "(unknown error)" } + #formatExecutionError(error) { + if (error == null) { + return "(unknown error)" + } + + const traceback = error.traceback.join("") + const nodeId = error.node_id + const nodeType = error.node_type + + return `Error occurred when executing ${nodeType}:\n\n${error.message}\n\n${traceback}` + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1315,7 +1340,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - const formattedError = this.#formatError(error) + const formattedError = this.#formatPromptError(error) this.ui.dialog.show(formattedError); if (error.response) { this.lastPromptError = error.response; @@ -1419,7 +1444,7 @@ export class ComfyApp { clean() { this.nodeOutputs = {}; this.lastPromptError = null; - this.graphTime = null + this.lastExecutionError = null; } } From e2d080b6941783e50155f694c11ab0da1b1ae240 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:07:51 -0500 Subject: [PATCH 087/120] Return null for value format --- execution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index f79c3d351..9cebce928 100644 --- a/execution.py +++ b/execution.py @@ -103,7 +103,9 @@ 
def get_output_data(obj, input_data_all): return output, ui def format_value(x): - if isinstance(x, (int, float, bool, str)): + if x is None: + return None + elif isinstance(x, (int, float, bool, str)): return x else: return str(x) From a9e7e237248296c8fe0d79991e0f8c2c0f2cf530 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:11:34 -0500 Subject: [PATCH 088/120] Fix --- execution.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/execution.py b/execution.py index 9cebce928..ffea00a8c 100644 --- a/execution.py +++ b/execution.py @@ -499,9 +499,7 @@ def validate_inputs(prompt, item, validated): if r is not True: details = f"{x}" if r is not False: - details += f": {str(r)}" - else: - details += "." + details += f" - {str(r)}" error = { "type": "custom_validation_failed", From 62bdd9d26aba086ffbeedd118140e2806e6f4345 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 26 May 2023 16:35:54 -0500 Subject: [PATCH 089/120] Catch typecast errors --- execution.py | 43 ++++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index ffea00a8c..6af58a673 100644 --- a/execution.py +++ b/execution.py @@ -424,7 +424,8 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "received_type": received_type + "received_type": received_type, + "linked_node": val } } errors.append(error) @@ -440,28 +441,44 @@ def validate_inputs(prompt, item, validated): valid = False exception_type = full_type_name(typ) reasons = [{ - "type": "exception_during_validation", - "message": "Exception when validating node", + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", "details": str(ex), "extra_info": { "input_name": x, "input_config": info, "exception_type": exception_type, - "traceback": traceback.format_tb(tb) + "traceback": traceback.format_tb(tb), + "linked_node": val } }] validated[o_id] = (False, reasons, o_id) continue else: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val + try: + if type_input == "INT": + val = int(val) + inputs[x] = val + if type_input == "FLOAT": + val = float(val) + inputs[x] = val + if type_input == "STRING": + val = str(val) + inputs[x] = val + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + errors.append(error) + continue if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: From 52c9590b7b65dba86e8622f6ad38974bc4045f31 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 01:51:39 -0500 Subject: [PATCH 090/120] Exception message --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 6af58a673..52c264b0f 100644 --- a/execution.py +++ b/execution.py @@ -447,6 +447,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, + "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "linked_node": val 
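For context on how these error payloads are consumed downstream: the fields assembled in handle_execution_error and sent over the "execution_error" event mirror what web/scripts/app.js builds in #formatExecutionError. The helper below is a hypothetical sketch, not part of these patches; the function name is invented, and it assumes only the payload keys as they stand after the renames in the surrounding commits (node_type, node_id, exception_message, and traceback as a list of strings from traceback.format_tb).

    def format_execution_error(detail):
        # Hypothetical helper (not in the codebase): a Python-side analogue of
        # the frontend's #formatExecutionError, using fields from the
        # "execution_error" message, e.g.
        # {"prompt_id": "...", "node_id": "5", "node_type": "KSampler",
        #  "exception_message": "...", "traceback": ["frame 1\n", "frame 2\n"]}
        traceback_text = "".join(detail.get("traceback", []))
        node_type = detail.get("node_type", "(unknown node)")
        node_id = detail.get("node_id")
        message = detail.get("exception_message", "(no message)")
        return (
            f"Error occurred when executing {node_type} (node {node_id}):\n\n"
            f"{message}\n\n{traceback_text}"
        )
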
From 03f2d0a764726641e848ba4e069c8809a502afdf Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 02:02:11 -0500 Subject: [PATCH 091/120] Rename exception message field --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 52c264b0f..1a9a1ff73 100644 --- a/execution.py +++ b/execution.py @@ -171,7 +171,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute error_details = { "node_id": unique_id, - "message": str(ex), + "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "current_inputs": input_data_formatted, @@ -282,7 +282,7 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), - "message": error["message"], + "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], "current_inputs": error["current_inputs"], From 00646b0813e4f395725f3013f18b13a46f4d619d Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 21:48:49 -0500 Subject: [PATCH 092/120] Bitwise operations for masks --- comfy_extras/nodes_mask.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9916f3b21..9134c24da 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -167,7 +167,7 @@ class MaskComposite: "source": ("MASK",), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract"],), + "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), } } @@ -193,6 +193,12 @@ class MaskComposite: output[top:bottom, left:right] = destination_portion + source_portion elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion + elif operation == "and": + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.bool(), source_portion.bool()).float() + elif operation == "or": + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.bool(), source_portion.bool()).float() + elif operation == "xor": + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.bool(), source_portion.bool()).float() output = torch.clamp(output, 0.0, 1.0) From ad81fd682a5e5e7c1f258d7c11a000c0dfd07be3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 00:32:26 -0400 Subject: [PATCH 093/120] Fix issue with cancelling prompt. --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 1a9a1ff73..218a84c36 100644 --- a/execution.py +++ b/execution.py @@ -353,6 +353,7 @@ class PromptExecutor: success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) if success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + break for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) From f3ac938b4a5c031adb9ee2951f26360d6a2b36de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 00:42:53 -0400 Subject: [PATCH 094/120] Round the mask values for bitwise operations. 
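To see why the rounding matters: torch treats any non-zero float as True when a mask is cast to bool, so a partially painted pixel such as 0.4 would count as fully set in the and/or/xor operations introduced a couple of commits earlier. A standalone sketch, not taken from the patch, with tensor values chosen only for illustration:

    import torch

    a = torch.tensor([0.0, 0.4, 0.6, 1.0])  # soft mask values
    b = torch.tensor([1.0, 1.0, 0.0, 1.0])

    # Without rounding, any non-zero value (e.g. 0.4) is truthy once cast to bool:
    print(torch.bitwise_and(a.bool(), b.bool()).float())
    # tensor([0., 1., 0., 1.])

    # Rounding first snaps soft values to 0/1 before the logical op:
    print(torch.bitwise_and(a.round().bool(), b.round().bool()).float())
    # tensor([0., 0., 0., 1.])

Rounding before the bool cast keeps the bitwise blend modes binary, while the multiply/add/subtract paths continue to operate on the original soft mask values.
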
--- comfy_extras/nodes_mask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9134c24da..15377af14 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -194,11 +194,11 @@ class MaskComposite: elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion elif operation == "and": - output[top:bottom, left:right] = torch.bitwise_and(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "or": - output[top:bottom, left:right] = torch.bitwise_or(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "xor": - output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() output = torch.clamp(output, 0.0, 1.0) From 0fc483dcfdef457b50d3a67e66b4f463e6ef9d62 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 01:52:09 -0400 Subject: [PATCH 095/120] Refactor diffusers model convert code to be able to reuse it. --- comfy/diffusers_convert.py | 107 ----------------------------------- comfy/diffusers_load.py | 111 +++++++++++++++++++++++++++++++++++++ nodes.py | 4 +- 3 files changed, 113 insertions(+), 109 deletions(-) create mode 100644 comfy/diffusers_load.py diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index ceca80305..1eab54d4b 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -1,14 +1,5 @@ -import json -import os -import yaml - -import folder_paths -from comfy.ldm.util import instantiate_from_config -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE -import os.path as osp import re import torch -from safetensors.torch import load_file, save_file # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict -def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): - diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) - diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) - - # magic - v2 = diffusers_unet_conf["sample_size"] == 96 - if 'prediction_type' in diffusers_scheduler_conf: - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' - - if v2: - if v_pred: - config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') - - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) - - model_config_params = config['model']['params'] - clip_config = model_config_params['cond_stage_config'] - scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - vae_config['scale_factor'] = scale_factor - model_config_params["unet_config"]["params"]["use_fp16"] = fp16 - - unet_path = 
osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") - text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") - - # Load models from safetensors if it exists, if it doesn't pytorch - if osp.exists(unet_path): - unet_state_dict = load_file(unet_path, device="cpu") - else: - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") - unet_state_dict = torch.load(unet_path, map_location="cpu") - - if osp.exists(vae_path): - vae_state_dict = load_file(vae_path, device="cpu") - else: - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") - vae_state_dict = torch.load(vae_path, map_location="cpu") - - if osp.exists(text_enc_path): - text_enc_dict = load_file(text_enc_path, device="cpu") - else: - text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict(unet_state_dict) - unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} - - # Convert the VAE model - vae_state_dict = convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper - is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict - - if is_v20_model: - # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm - text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} - text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) - text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} - else: - text_enc_dict = convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} - - # Put together new checkpoint - sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - w = WeightsLoader() - load_state_dict_to = [] - if output_vae: - vae = VAE(scale_factor=scale_factor, config=vae_config) - w.first_stage_model = vae.first_stage_model - load_state_dict_to = [w] - - if output_clip: - clip = CLIP(config=clip_config, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_state_dict_to = [w] - - model = instantiate_from_config(config["model"]) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) - - if fp16: - model = model.half() - - return ModelPatcher(model), clip, vae diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py new file mode 100644 index 000000000..43877fb83 --- /dev/null +++ b/comfy/diffusers_load.py @@ -0,0 +1,111 @@ +import json +import os +import yaml + +import folder_paths +from comfy.ldm.util import instantiate_from_config +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +import os.path as osp +import re +import torch +from safetensors.torch import load_file, save_file +import diffusers_convert + +def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): + diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) + diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) + + # magic + v2 = diffusers_unet_conf["sample_size"] == 96 + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + + if v2: + if v_pred: + config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) + + model_config_params = config['model']['params'] + clip_config = model_config_params['cond_stage_config'] + scale_factor = model_config_params['scale_factor'] + vae_config = model_config_params['first_stage_config'] + vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict) + 
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + + clip = None + vae = None + + class WeightsLoader(torch.nn.Module): + pass + + w = WeightsLoader() + load_state_dict_to = [] + if output_vae: + vae = VAE(scale_factor=scale_factor, config=vae_config) + w.first_stage_model = vae.first_stage_model + load_state_dict_to = [w] + + if output_clip: + clip = CLIP(config=clip_config, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_state_dict_to = [w] + + model = instantiate_from_config(config["model"]) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + + if fp16: + model = model.half() + + return ModelPatcher(model), clip, vae diff --git a/nodes.py b/nodes.py index 68010f040..90444a92c 100644 --- a/nodes.py +++ b/nodes.py @@ -17,7 +17,7 @@ import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) -import comfy.diffusers_convert +import comfy.diffusers_load import comfy.samplers import comfy.sample import comfy.sd @@ -377,7 +377,7 @@ class DiffusersLoader: model_path = path break - return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: From a532888846809de7b8890e8beb10ea87edf39d7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 02:02:09 -0400 Subject: [PATCH 096/120] Support VAEs in diffusers format. --- comfy/sd.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c6be900ad..4df149fe1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ from .t2i_adapter import adapter from . import utils from . import clip_vision from . import gligen +from . 
import diffusers_convert def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) @@ -504,10 +505,16 @@ class VAE: if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") else: - self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() + if ckpt_path is not None: + sd = utils.load_torch_file(ckpt_path) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) + self.first_stage_model.load_state_dict(sd, strict=False) + self.scale_factor = scale_factor if device is None: device = model_management.get_torch_device() From 23ffafeb5d4a25bb5e41c34c9f04a0733643892c Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sun, 28 May 2023 23:31:40 +0900 Subject: [PATCH 097/120] typo fix: field name in error message --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index e8ab32cf9..26670239b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1316,7 +1316,7 @@ export class ComfyApp { const nodeId = error.node_id const nodeType = error.node_type - return `Error occurred when executing ${nodeType}:\n\n${error.message}\n\n${traceback}` + return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}` } async queuePrompt(number, batchCount = 1) { From b9818eb910b6ce683c38602c9b8fbd3979d97aaf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 02:48:50 -0400 Subject: [PATCH 098/120] Add route to get safetensors metadata: /view_metadata/loras?filename=lora.safetensors --- comfy/utils.py | 9 +++++++++ folder_paths.py | 2 ++ server.py | 25 ++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 5ed9aaa02..4e84e870b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import struct def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -50,6 +51,14 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +def safetensors_header(safetensors_path, max_size=100*1024*1024): + with open(safetensors_path, "rb") as f: + header = f.read(8) + length_of_header = struct.unpack(' max_size: + return None + return f.read(length_of_header) + def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. 
NxC''' diff --git a/folder_paths.py b/folder_paths.py index 20b461c94..19245a617 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -126,11 +126,13 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths folders = folder_names_and_paths[folder_name] + filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: full_path = os.path.join(x, filename) if os.path.isfile(full_path): return full_path + return None def get_filename_list(folder_name): global folder_names_and_paths diff --git a/server.py b/server.py index c0f79cbd5..72c565a63 100644 --- a/server.py +++ b/server.py @@ -22,7 +22,7 @@ except ImportError: import mimetypes from comfy.cli_args import args - +import comfy.utils @web.middleware async def cache_control(request: web.Request, handler): @@ -257,6 +257,29 @@ class PromptServer(): return web.Response(status=404) + @routes.get("/view_metadata/{folder_name}") + async def view_metadata(request): + folder_name = request.match_info.get("folder_name", None) + if folder_name is None: + return web.Response(status=404) + if not "filename" in request.rel_url.query: + return web.Response(status=404) + + filename = request.rel_url.query["filename"] + if not filename.endswith(".safetensors"): + return web.Response(status=404) + + safetensors_path = folder_paths.get_full_path(folder_name, filename) + if safetensors_path is None: + return web.Response(status=404) + out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024) + if out is None: + return web.Response(status=404) + dt = json.loads(out) + if not "__metadata__" in dt: + return web.Response(status=404) + return web.json_response(dt["__metadata__"]) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) From 560e9f7a43242c51da2589a33f659ecd41914b20 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 11:29:00 -0400 Subject: [PATCH 099/120] Disable repo owner validation in update.py --- .ci/update_windows/update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index c09f29a80..ef9374c44 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'): else: raise AssertionError('Unknown merge analysis result') - +pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) repo = pygit2.Repository(str(sys.argv[1])) ident = pygit2.Signature('comfyui', 'comfy@ui') try: From 08abd838b82ea8d08a7e6f1484140d1694180381 Mon Sep 17 00:00:00 2001 From: "Lt.Dr.Data" Date: Tue, 30 May 2023 15:26:45 +0900 Subject: [PATCH 100/120] HOTFIX: Patched the conflict issue between the Combo Refresh feature and PrimitiveNodes. --- web/scripts/app.js | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 26670239b..64adc3e6a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1424,6 +1424,11 @@ export class ComfyApp { const def = defs[node.type]; + // HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes, + // and additional work is needed to consider the primitive logic in the refresh logic. 
+ if(!def) + continue; + for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { From eb448dd8e18125b569bea9002f909769678a6c43 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 May 2023 12:36:41 -0400 Subject: [PATCH 101/120] Auto load model in lowvram if not enough memory. --- comfy/model_management.py | 46 ++++++++++++++++++++++++--------------- comfy/sd.py | 18 +++++++++++++-- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c15323219..10a706793 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -15,9 +15,8 @@ vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM total_vram = 0 -total_vram_available_mb = -1 -accelerate_enabled = False +lowvram_available = True xpu_available = False directml_enabled = False @@ -31,11 +30,12 @@ if args.directml is not None: directml_device = torch_directml.device(device_index) print("Using directml with device:", torch_directml.device_name(device_index)) # torch_directml.disable_tiled_resources(True) + lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. try: import torch if directml_enabled: - total_vram = 4097 #TODO + pass #TODO else: try: import intel_extension_for_pytorch as ipex @@ -46,7 +46,7 @@ try: total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) if not args.normalvram and not args.cpu: - if total_vram <= 4096: + if lowvram_available and total_vram <= 4096: print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. 
If you don't want this use: --normalvram") set_vram_to = VRAMState.LOW_VRAM elif total_vram > total_ram * 1.1 and total_vram > 14336: @@ -92,6 +92,7 @@ if ENABLE_PYTORCH_ATTENTION: if args.lowvram: set_vram_to = VRAMState.LOW_VRAM + lowvram_available = True elif args.novram: set_vram_to = VRAMState.NO_VRAM elif args.highvram: @@ -103,18 +104,18 @@ if args.force_fp32: FORCE_FP32 = True -if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + +if lowvram_available: try: import accelerate - accelerate_enabled = True - vram_state = set_vram_to + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to except Exception as e: import traceback print(traceback.format_exc()) - print("ERROR: COULD NOT ENABLE LOW VRAM MODE.") + print("ERROR: LOW VRAM MODE NEEDS accelerate.") + lowvram_available = False - total_vram_available_mb = (total_vram - 1024) // 2 - total_vram_available_mb = int(max(256, total_vram_available_mb)) try: if torch.backends.mps.is_available(): @@ -199,22 +200,33 @@ def load_model_gpu(model): model.unpatch_model() raise e - model.model_patches_to(get_torch_device()) + torch_dev = get_torch_device() + model.model_patches_to(torch_dev) + + vram_set_state = vram_state + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): + model_size = model.model_size() + current_free_mem = get_free_memory(torch_dev) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2 )) + if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary + vram_set_state = VRAMState.LOW_VRAM + current_loaded_model = model - if vram_state == VRAMState.CPU: + + if vram_set_state == VRAMState.CPU: pass - elif vram_state == VRAMState.MPS: + elif vram_set_state == VRAMState.MPS: mps_device = torch.device("mps") real_model.to(mps_device) pass - elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: + elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM: model_accelerated = False real_model.to(get_torch_device()) else: - if vram_state == VRAMState.NO_VRAM: + if vram_set_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) + elif vram_set_state == VRAMState.LOW_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) model_accelerated = True diff --git a/comfy/sd.py b/comfy/sd.py index 4df149fe1..ce17994f7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -286,15 +286,29 @@ def model_lora_keys(model, key_map={}): return key_map + class ModelPatcher: - def __init__(self, model): + def __init__(self, model, size=0): + self.size = size self.model = model self.patches = [] self.backup = {} self.model_options = {"transformer_options":{}} + self.model_size() + + def model_size(self): + if self.size > 0: + return self.size + model_sd = self.model.state_dict() + size = 0 + for k in model_sd: + t = model_sd[k] + size += t.nelement() * t.element_size() + self.size = size + return size def clone(self): - n = ModelPatcher(self.model) + n = ModelPatcher(self.model, self.size) 
n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) return n From 2260802d90c41f1475a7bf2960aa018dc25f1001 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 May 2023 16:44:09 -0400 Subject: [PATCH 102/120] Check if folder_name is valid instead of just throwing exception. --- folder_paths.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/folder_paths.py b/folder_paths.py index 19245a617..fc37e52c7 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -125,6 +125,8 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths + if folder_name not in folder_names_and_paths: + return None folders = folder_names_and_paths[folder_name] filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: From 04f4fba013da1f556fc310235d5a30c2bfe682e8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 16:01:49 -0500 Subject: [PATCH 103/120] Fix litegraph dialog CSS --- web/index.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/index.html b/web/index.html index bb79433ce..da0adb6c2 100644 --- a/web/index.html +++ b/web/index.html @@ -14,5 +14,5 @@ window.graph = app.graph; - + From 468c27afea29928d7d9fcd208e1137a36118ad13 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 16:06:17 -0500 Subject: [PATCH 104/120] Fix litegraph dialog z-index/font --- web/style.css | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/web/style.css b/web/style.css index 87f096e14..db82887c3 100644 --- a/web/style.css +++ b/web/style.css @@ -289,6 +289,11 @@ button.comfy-queue-btn { /* Context menu */ +.litegraph .dialog { + z-index: 1; + font-family: Arial; +} + .litegraph .litemenu-entry.has_submenu { position: relative; padding-right: 20px; From 8ef197f02852b65509d6ebe06df8794b96a07f2f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 11:26:57 -0400 Subject: [PATCH 105/120] Keep list of filenames and only refresh it when something changes. --- folder_paths.py | 47 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index fc37e52c7..f3d1b8773 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -31,6 +31,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +filename_list_cache = {} + if not os.path.exists(input_directory): os.makedirs(input_directory) @@ -111,12 +113,18 @@ def get_folder_paths(folder_name): return folder_names_and_paths[folder_name][0][:] def recursive_search(directory): + if not os.path.isdir(directory): + return [], {} result = [] + dirs = {directory: os.path.getmtime(directory)} for root, subdir, file in os.walk(directory, followlinks=True): for filepath in file: #we os.path,join directory with a blank string to generate a path separator at the end. 
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) - return result + for d in subdir: + path = os.path.join(root, d) + dirs[path] = os.path.getmtime(path) + return result, dirs def filter_files_extensions(files, extensions): return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) @@ -136,13 +144,44 @@ def get_full_path(folder_name, filename): return None -def get_filename_list(folder_name): +def get_filename_list_(folder_name): global folder_names_and_paths output_list = set() folders = folder_names_and_paths[folder_name] + output_folders = {} for x in folders[0]: - output_list.update(filter_files_extensions(recursive_search(x), folders[1])) - return sorted(list(output_list)) + files, folders_all = recursive_search(x) + output_list.update(filter_files_extensions(files, folders[1])) + output_folders = {**output_folders, **folders_all} + + return (sorted(list(output_list)), output_folders) + +def cached_filename_list_(folder_name): + global filename_list_cache + global folder_names_and_paths + if folder_name not in filename_list_cache: + return None + out = filename_list_cache[folder_name] + for x in out[1]: + time_modified = out[1][x] + folder = x + if os.path.getmtime(folder) != time_modified: + return None + + folders = folder_names_and_paths[folder_name] + for x in folders[0]: + if x not in out[1]: + return None + + return out + +def get_filename_list(folder_name): + out = cached_filename_list_(folder_name) + if out is None: + out = get_filename_list_(folder_name) + global filename_list_cache + filename_list_cache[folder_name] = out + return out[0] def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def map_filename(filename): From 1f34bf08f06550fb2f041188b5a01d395240be17 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Wed, 31 May 2023 22:01:25 +0900 Subject: [PATCH 106/120] To support dynamic custom loading, separate the node registration process based on the defs in the registerNodes function. 
--- web/scripts/app.js | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 64adc3e6a..9ecad8489 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1010,6 +1010,11 @@ export class ComfyApp { const app = this; // Load node definitions from the backend const defs = await api.getNodeDefs(); + this.registerNodesFromDefs(defs); + await this.#invokeExtensionsAsync("registerCustomNodes"); + } + + async registerNodesFromDefs(defs) { await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); // Generate list of known widgets @@ -1082,8 +1087,6 @@ export class ComfyApp { LiteGraph.registerNodeType(nodeId, node); node.category = nodeData.category; } - - await this.#invokeExtensionsAsync("registerCustomNodes"); } /** From 8e8d6070f2e80aff0200bb3ad0f31716a98d5739 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Wed, 31 May 2023 23:26:56 +0900 Subject: [PATCH 107/120] race condition patch --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 9ecad8489..8a9c7ca49 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1010,7 +1010,7 @@ export class ComfyApp { const app = this; // Load node definitions from the backend const defs = await api.getNodeDefs(); - this.registerNodesFromDefs(defs); + await this.registerNodesFromDefs(defs); await this.#invokeExtensionsAsync("registerCustomNodes"); } From 03da8a34265bb333d03a51d7503697b5ede9b335 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 31 May 2023 13:03:24 -0400 Subject: [PATCH 108/120] This is useless for inference. --- comfy/sd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index ce17994f7..fa7bd8d32 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -743,7 +743,7 @@ def load_controlnet(ckpt_path, model=None): use_spatial_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) else: @@ -760,7 +760,7 @@ def load_controlnet(ckpt_path, model=None): use_linear_in_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) if pth: @@ -1045,7 +1045,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o } unet_config = { - "use_checkpoint": True, + "use_checkpoint": False, "image_size": 32, "out_channels": 4, "attention_resolutions": [ From d200fa131420a8871633b7321664db419aab2712 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 31 May 2023 19:00:01 -0500 Subject: [PATCH 109/120] Prevent callers from mutating folder lists --- folder_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/folder_paths.py b/folder_paths.py index f3d1b8773..e179a28d4 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -181,7 +181,7 @@ def get_filename_list(folder_name): out = get_filename_list_(folder_name) global filename_list_cache filename_list_cache[folder_name] = out - return out[0] + return list(out[0]) def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def map_filename(filename): From 94680732d32b4b540251c122aee36df8d37266e1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 1 Jun 2023 03:52:51 -0400 Subject: [PATCH 110/120] Empty cache on mps. 
--- comfy/model_management.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 10a706793..60bcd786b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -389,7 +389,10 @@ def should_use_fp16(): def soft_empty_cache(): global xpu_available - if xpu_available: + global vram_state + if vram_state == VRAMState.MPS: + torch.mps.empty_cache() + elif xpu_available: torch.xpu.empty_cache() elif torch.cuda.is_available(): if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda From 5c38958e49efd11b5234cb5ff472d752698c5090 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 1 Jun 2023 04:04:35 -0400 Subject: [PATCH 111/120] Tweak lowvram model memory so it's closer to what it was before. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 60bcd786b..e9af7f3a7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -207,7 +207,7 @@ def load_model_gpu(model): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = model.model_size() current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2 )) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary vram_set_state = VRAMState.LOW_VRAM From 1bbd3f7fe16e6637bba232059d004a5fe7804a30 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 1 Jun 2023 22:15:06 -0500 Subject: [PATCH 112/120] Send back prompt number from prompt/ endpoint --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 72c565a63..0b64df147 100644 --- a/server.py +++ b/server.py @@ -361,7 +361,7 @@ class PromptServer(): prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) - return web.json_response({"prompt_id": prompt_id}) + return web.json_response({"prompt_id": prompt_id, "number": number}) else: print("invalid prompt:", valid[1]) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) From b5dd15c67ad3f4dbdc23811f40a4c121e318bfe9 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 1 Jun 2023 23:26:23 -0500 Subject: [PATCH 113/120] System stats endpoint --- comfy/model_management.py | 27 +++++++++++++++++++++++++++ server.py | 24 ++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e9af7f3a7..3b7b1dbf1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -308,6 +308,33 @@ def pytorch_attention_flash_attention(): return True return False +def get_total_memory(dev=None, torch_total_too=False): + global xpu_available + global directml_enabled + if dev is None: + dev = get_torch_device() + + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_total = psutil.virtual_memory().total + else: + if directml_enabled: + mem_total = 1024 * 1024 * 1024 #TODO + mem_total_torch = mem_total + elif xpu_available: + mem_total = 
torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total + else: + stats = torch.cuda.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_cuda = torch.cuda.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_cuda + mem_total_torch + + if torch_total_too: + return (mem_total, mem_total_torch) + else: + return mem_total + def get_free_memory(dev=None, torch_free_too=False): global xpu_available global directml_enabled diff --git a/server.py b/server.py index 0b64df147..acbc88f66 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,7 @@ import execution import uuid import json import glob +import torch from PIL import Image from io import BytesIO @@ -23,6 +24,7 @@ except ImportError: import mimetypes from comfy.cli_args import args import comfy.utils +import comfy.model_management @web.middleware async def cache_control(request: web.Request, handler): @@ -280,6 +282,28 @@ class PromptServer(): return web.Response(status=404) return web.json_response(dt["__metadata__"]) + @routes.get("/system_stats") + async def get_queue(request): + device_index = comfy.model_management.get_torch_device() + device = torch.device(device_index) + device_name = comfy.model_management.get_torch_device_name(device_index) + vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) + system_stats = { + "devices": [ + { + "name": device_name, + "type": device.type, + "index": device.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + } + ] + } + return web.json_response(system_stats) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) From 499641ebf1be190e20624ee352e9dc88884e3df1 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 2 Jun 2023 00:14:41 -0500 Subject: [PATCH 114/120] More accurate total --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3b7b1dbf1..0ea0c71e5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -328,7 +328,7 @@ def get_total_memory(dev=None, torch_total_too=False): mem_reserved = stats['reserved_bytes.all.current'] _, mem_total_cuda = torch.cuda.mem_get_info(dev) mem_total_torch = mem_reserved - mem_total = mem_total_cuda + mem_total_torch + mem_total = mem_total_cuda if torch_total_too: return (mem_total, mem_total_torch) From 67892b5ac584ff8def01a5852246c364f8408d95 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Jun 2023 15:05:25 -0400 Subject: [PATCH 115/120] Refactor and improve model_management code related to free memory. --- comfy/model_management.py | 131 +++++++++++++++++++------------------- server.py | 6 +- 2 files changed, 68 insertions(+), 69 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0ea0c71e5..9c3147d76 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,7 @@ import psutil from enum import Enum from comfy.cli_args import args +import torch class VRAMState(Enum): CPU = 0 @@ -33,28 +34,67 @@ if args.directml is not None: lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. 
try: - import torch - if directml_enabled: - pass #TODO - else: - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) - total_ram = psutil.virtual_memory().total / (1024 * 1024) - if not args.normalvram and not args.cpu: - if lowvram_available and total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = VRAMState.LOW_VRAM - elif total_vram > total_ram * 1.1 and total_vram > 14336: - print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = VRAMState.HIGH_VRAM + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True except: pass +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.device(torch.cuda.current_device()) + +def get_total_memory(dev=None, torch_total_too=False): + global xpu_available + global directml_enabled + if dev is None: + dev = get_torch_device() + + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_total = psutil.virtual_memory().total + mem_total_torch = mem_total + else: + if directml_enabled: + mem_total = 1024 * 1024 * 1024 #TODO + mem_total_torch = mem_total + elif xpu_available: + mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total + else: + stats = torch.cuda.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_cuda = torch.cuda.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_cuda + + if torch_total_too: + return (mem_total, mem_total_torch) + else: + return mem_total + +total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) +total_ram = psutil.virtual_memory().total / (1024 * 1024) +print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) +if not args.normalvram and not args.cpu: + if lowvram_available and total_vram <= 4096: + print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") + set_vram_to = VRAMState.LOW_VRAM + elif total_vram > total_ram * 1.1 and total_vram > 14336: + print("Enabling highvram mode because your GPU has more vram than your computer has ram. 
If you don't want this use: --normalvram") + vram_state = VRAMState.HIGH_VRAM + try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError except: @@ -128,29 +168,17 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_torch_device_name(device): if hasattr(device, 'type'): - return "{}".format(device.type) - return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + if device.type == "cuda": + return "{} {}".format(device, torch.cuda.get_device_name(device)) + else: + return "{}".format(device.type) + else: + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - print("Using device:", get_torch_device_name(get_torch_device())) + print("Device:", get_torch_device_name(get_torch_device())) except: print("Could not pick default device.") @@ -308,33 +336,6 @@ def pytorch_attention_flash_attention(): return True return False -def get_total_memory(dev=None, torch_total_too=False): - global xpu_available - global directml_enabled - if dev is None: - dev = get_torch_device() - - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - mem_total = psutil.virtual_memory().total - else: - if directml_enabled: - mem_total = 1024 * 1024 * 1024 #TODO - mem_total_torch = mem_total - elif xpu_available: - mem_total = torch.xpu.get_device_properties(dev).total_memory - mem_total_torch = mem_total - else: - stats = torch.cuda.memory_stats(dev) - mem_reserved = stats['reserved_bytes.all.current'] - _, mem_total_cuda = torch.cuda.mem_get_info(dev) - mem_total_torch = mem_reserved - mem_total = mem_total_cuda - - if torch_total_too: - return (mem_total, mem_total_torch) - else: - return mem_total - def get_free_memory(dev=None, torch_free_too=False): global xpu_available global directml_enabled diff --git a/server.py b/server.py index acbc88f66..5be822a6f 100644 --- a/server.py +++ b/server.py @@ -7,7 +7,6 @@ import execution import uuid import json import glob -import torch from PIL import Image from io import BytesIO @@ -284,9 +283,8 @@ class PromptServer(): @routes.get("/system_stats") async def get_queue(request): - device_index = comfy.model_management.get_torch_device() - device = torch.device(device_index) - device_name = comfy.model_management.get_torch_device_name(device_index) + device = comfy.model_management.get_torch_device() + device_name = comfy.model_management.get_torch_device_name(device) vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) system_stats = { From 871a86593ae7eb96518d326c83cfded5d41c6fa6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Jun 2023 16:34:47 -0400 Subject: [PATCH 116/120] Smarter filename list caching. 
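The change below stores the time the filename list was built and, when a second lookup arrives within half a second, returns the cached result outright instead of re-checking every folder's mtime. For illustration, a minimal standalone sketch of that reuse check (the helper name and the ttl parameter are simplifications, not the repository API):

import time

filename_list_cache = {}

def cached_filename_list(folder_name, ttl=0.5):
    # Reuse the cached (files, folder_mtimes, built_at) tuple outright when it
    # was built within the last `ttl` seconds; otherwise signal a rebuild.
    out = filename_list_cache.get(folder_name)
    if out is not None and time.perf_counter() < out[2] + ttl:
        return out
    return None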
--- folder_paths.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/folder_paths.py b/folder_paths.py index e179a28d4..8cee6afde 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,4 +1,5 @@ import os +import time supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) @@ -154,7 +155,7 @@ def get_filename_list_(folder_name): output_list.update(filter_files_extensions(files, folders[1])) output_folders = {**output_folders, **folders_all} - return (sorted(list(output_list)), output_folders) + return (sorted(list(output_list)), output_folders, time.perf_counter()) def cached_filename_list_(folder_name): global filename_list_cache @@ -162,6 +163,8 @@ def cached_filename_list_(folder_name): if folder_name not in filename_list_cache: return None out = filename_list_cache[folder_name] + if time.perf_counter() < (out[2] + 0.5): + return out for x in out[1]: time_modified = out[1][x] folder = x From 66e588d837275b26b428f737692357090ad41426 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Jun 2023 16:48:56 -0400 Subject: [PATCH 117/120] Ignore folder path directories that don't exist. --- folder_paths.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 8cee6afde..a1bf1444d 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -173,8 +173,9 @@ def cached_filename_list_(folder_name): folders = folder_names_and_paths[folder_name] for x in folders[0]: - if x not in out[1]: - return None + if os.path.isdir(x): + if x not in out[1]: + return None return out From 700491d81a9faf5370a0c54d869e902bbfc839ec Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 01:47:21 -0400 Subject: [PATCH 118/120] Implement global average pooling for controlnet. 
--- comfy/sd.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index fa7bd8d32..336fee4a6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): return torch.cat([tensor] * batched_number, dim=0) class ControlNet: - def __init__(self, control_model, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None): self.control_model = control_model self.cond_hint_original = None self.cond_hint = None @@ -630,6 +630,7 @@ class ControlNet: device = model_management.get_torch_device() self.device = device self.previous_controlnet = None + self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond_txt, batched_number): control_prev = None @@ -665,6 +666,9 @@ class ControlNet: key = 'output' index = i x = control[i] + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + x *= self.strength if x.dtype != output_dtype and not autocast_enabled: x = x.to(output_dtype) @@ -695,7 +699,7 @@ class ControlNet: self.cond_hint = None def copy(self): - c = ControlNet(self.control_model) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) c.cond_hint_original = self.cond_hint_original c.strength = self.strength return c @@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None): if use_fp16: control_model = control_model.half() - control = ControlNet(control_model) + global_average_pooling = False + if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling) return control class T2IAdapter: From 0a5fefd6213e3116359e0738533a9e3b733506c5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:05:37 -0400 Subject: [PATCH 119/120] Cleanups and fixes for model_management.py Hopefully fix regression on MPS and CPU. 
--- comfy/model_management.py | 63 ++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9c3147d76..a492ca6b9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -4,16 +4,22 @@ from comfy.cli_args import args import torch class VRAMState(Enum): - CPU = 0 + DISABLED = 0 NO_VRAM = 1 LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 - MPS = 5 + SHARED = 5 + +class CPUState(Enum): + GPU = 0 + CPU = 1 + MPS = 2 # Determine VRAM State vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM +cpu_state = CPUState.GPU total_vram = 0 @@ -40,15 +46,25 @@ try: except: pass +try: + if torch.backends.mps.is_available(): + cpu_state = CPUState.MPS +except: + pass + +if args.cpu: + cpu_state = CPUState.CPU + def get_torch_device(): global xpu_available global directml_enabled + global cpu_state if directml_enabled: global directml_device return directml_device - if vram_state == VRAMState.MPS: + if cpu_state == CPUState.MPS: return torch.device("mps") - if vram_state == VRAMState.CPU: + if cpu_state == CPUState.CPU: return torch.device("cpu") else: if xpu_available: @@ -143,8 +159,6 @@ if args.force_fp32: print("Forcing FP32, if this improves things please report it.") FORCE_FP32 = True - - if lowvram_available: try: import accelerate @@ -157,17 +171,15 @@ if lowvram_available: lowvram_available = False -try: - if torch.backends.mps.is_available(): - vram_state = VRAMState.MPS -except: - pass +if cpu_state != CPUState.GPU: + vram_state = VRAMState.DISABLED -if args.cpu: - vram_state = VRAMState.CPU +if cpu_state == CPUState.MPS: + vram_state = VRAMState.SHARED print(f"Set vram state to: {vram_state.name}") + def get_torch_device_name(device): if hasattr(device, 'type'): if device.type == "cuda": @@ -241,13 +253,9 @@ def load_model_gpu(model): current_loaded_model = model - if vram_set_state == VRAMState.CPU: + if vram_set_state == VRAMState.DISABLED: pass - elif vram_set_state == VRAMState.MPS: - mps_device = torch.device("mps") - real_model.to(mps_device) - pass - elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM: + elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False real_model.to(get_torch_device()) else: @@ -263,7 +271,7 @@ def load_model_gpu(model): def load_controlnet_gpu(control_models): global current_gpu_controlnets global vram_state - if vram_state == VRAMState.CPU: + if vram_state == VRAMState.DISABLED: return if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: @@ -308,7 +316,8 @@ def get_autocast_device(dev): def xformers_enabled(): global xpu_available global directml_enabled - if vram_state == VRAMState.CPU: + global cpu_state + if cpu_state != CPUState.GPU: return False if xpu_available: return False @@ -380,12 +389,12 @@ def maximum_batch_area(): return int(max(area, 0)) def cpu_mode(): - global vram_state - return vram_state == VRAMState.CPU + global cpu_state + return cpu_state == CPUState.CPU def mps_mode(): - global vram_state - return vram_state == VRAMState.MPS + global cpu_state + return cpu_state == CPUState.MPS def should_use_fp16(): global xpu_available @@ -417,8 +426,8 @@ def should_use_fp16(): def soft_empty_cache(): global xpu_available - global vram_state - if vram_state == VRAMState.MPS: + global cpu_state + if cpu_state == CPUState.MPS: torch.mps.empty_cache() elif xpu_available: 
torch.xpu.empty_cache() From 32f282c861eabcee42fdec702b96ebc8924c9834 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:19:10 -0400 Subject: [PATCH 120/120] Search box style fix. --- web/style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/web/style.css b/web/style.css index db82887c3..47571a16e 100644 --- a/web/style.css +++ b/web/style.css @@ -336,6 +336,7 @@ button.comfy-queue-btn { z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; overflow: hidden; + display: block; } .litegraph.litesearchbox input,
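For reference, querying the /system_stats endpoint added earlier in this series could look like the following sketch; the host and port are assumptions for a default local install, and the field names follow the handler above:

import json
import urllib.request

# Assumed default local address; adjust host/port to match your server.
with urllib.request.urlopen("http://127.0.0.1:8188/system_stats") as resp:
    stats = json.load(resp)

for dev in stats["devices"]:
    print(dev["name"], dev["vram_free"], "/", dev["vram_total"], "bytes free")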