From 022a9f271b677291c4a4988397695bd3a91666b5 Mon Sep 17 00:00:00 2001 From: mligaintart <> Date: Wed, 5 Apr 2023 19:52:39 -0400 Subject: [PATCH 001/190] Adds masking to Latent Composite, and provides new masking utilities to allow better compositing. --- comfy_extras/nodes_mask.py | 237 +++++++++++++++++++++++++++++++++++++ nodes.py | 87 ++++++++------ 2 files changed, 291 insertions(+), 33 deletions(-) create mode 100644 comfy_extras/nodes_mask.py diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py new file mode 100644 index 000000000..ba39680a7 --- /dev/null +++ b/comfy_extras/nodes_mask.py @@ -0,0 +1,237 @@ +import torch + +from nodes import MAX_RESOLUTION + +class LatentCompositeMasked: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "destination": ("LATENT",), + "source": ("LATENT",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + }, + "optional": { + "mask": ("MASK",), + } + } + RETURN_TYPES = ("LATENT",) + FUNCTION = "composite" + + CATEGORY = "latent" + + def composite(self, destination, source, x, y, mask = None): + output = destination.copy() + destination = destination["samples"].clone() + source = source["samples"] + + left, top = (x // 8, y // 8) + right, bottom = (left + source.shape[3], top + source.shape[2],) + + + if mask is None: + mask = torch.ones_like(source) + else: + mask = mask.clone() + mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = mask.repeat((source.shape[0], source.shape[1], 1, 1)) + + # calculate the bounds of the source that will be overlapping the destination + # this prevents the source trying to overwrite latent pixels that are out of bounds + # of the destination + visible_width, visible_height = (destination.shape[3] - left, destination.shape[2] - top,) + + mask = mask[:, :, :visible_height, :visible_width] + inverse_mask = torch.ones_like(mask) - mask + + source_portion = mask * source[:, :, :visible_height, :visible_width] + destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + + destination[:, :, top:bottom, left:right] = source_portion + destination_portion + + output["samples"] = destination + + return (output,) + +class MaskToImage: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("IMAGE",) + + FUNCTION = "convert" + + def convert(self, mask): + image = torch.cat([torch.reshape(mask.clone(), [1, mask.shape[0], mask.shape[1], 1,])] * 3, 3) + + return (image,) + +class SolidMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "solid" + + def solid(self, value, width, height): + out = torch.full((height, width), value, dtype=torch.float32, device="cpu") + return (out,) + +class InvertMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "invert" + + def invert(self, mask): + out = 1.0 - mask + return (out,) + +class CropMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "crop" + + def crop(self, mask, x, y, width, height): + out = mask[y:y + height, x:x + width] + return (out,) + +class MaskComposite: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "destination": ("MASK",), + "source": ("MASK",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "operation": (["multiply", "add", "subtract"],), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "combine" + + def combine(self, destination, source, x, y, operation): + output = destination.clone() + + left, top = (x, y,) + right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0])) + visible_width, visible_height = (right - left, bottom - top,) + + source_portion = source[:visible_height, :visible_width] + destination_portion = destination[top:bottom, left:right] + + match operation: + case "multiply": + output[top:bottom, left:right] = destination_portion * source_portion + case "add": + output[top:bottom, left:right] = destination_portion + source_portion + case "subtract": + output[top:bottom, left:right] = destination_portion - source_portion + + output = torch.clamp(output, 0.0, 1.0) + + return (output,) + +class FeatherMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "feather" + + def feather(self, mask, left, top, right, bottom): + output = mask.clone() + + left = min(left, output.shape[1]) + right = min(right, output.shape[1]) + top = min(top, output.shape[0]) + bottom = min(bottom, output.shape[0]) + + for x in range(left): + feather_rate = (x + 1.0) / left + output[:, x] *= feather_rate + + for x in range(right): + feather_rate = (x + 1) / right + output[:, -x] *= feather_rate + + for y in range(top): + feather_rate = (y + 1) / top + output[y, :] *= feather_rate + + for y in range(bottom): + feather_rate = (y + 1) / bottom + output[-y, :] *= feather_rate + + return (output,) + + + +NODE_CLASS_MAPPINGS = { + "LatentCompositeMasked": LatentCompositeMasked, + "MaskToImage": MaskToImage, + "SolidMask": SolidMask, + "InvertMask": InvertMask, + "CropMask": CropMask, + "MaskComposite": MaskComposite, + "FeatherMask": FeatherMask, +} + diff --git a/nodes.py b/nodes.py index 187d54a11..eac232d5f 100644 --- a/nodes.py +++ b/nodes.py @@ -553,44 +553,64 @@ class LatentFlip: class LatentComposite: @classmethod def INPUT_TYPES(s): - return {"required": { "samples_to": ("LATENT",), - "samples_from": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - }} + return { + "required": { + "samples_to": ("LATENT",), + "samples_from": ("LATENT",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + } + } RETURN_TYPES = ("LATENT",) FUNCTION = "composite" CATEGORY = "latent" - def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0): - x = x // 8 - y = y // 8 - feather = feather // 8 - samples_out = samples_to.copy() - s = samples_to["samples"].clone() - samples_to = samples_to["samples"] - samples_from = samples_from["samples"] - if feather == 0: - s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] - else: - samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] - mask = torch.ones_like(samples_from) - for t in range(feather): - if y != 0: - mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) + def composite(self, samples_to, samples_from, x, y, feather): + output = samples_to.copy() + destination = samples_to["samples"].clone() + source = samples_from["samples"] - if y + samples_from.shape[2] < samples_to.shape[2]: - mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1)) - if x != 0: - mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1)) - if x + samples_from.shape[3] < samples_to.shape[3]: - mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) - rev_mask = torch.ones_like(mask) - mask - s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask - samples_out["samples"] = s - return (samples_out,) + left, top = (x // 8, y // 8) + right, bottom = (left + source.shape[3], top + source.shape[2],) + feather = feather // 8 + + + + # calculate the bounds of the source that will be overlapping the destination + # this prevents the source trying to overwrite latent pixels that are out of bounds + # of the destination + visible_width, visible_height = (destination.shape[3] - left, destination.shape[2] - top,) + + mask = torch.ones_like(source) + + for f in range(feather): + feather_rate = (f + 1.0) / feather + + if left > 0: + mask[:, :, :, f] *= feather_rate + + if right < destination.shape[3] - 1: + mask[:, :, :, -f] *= feather_rate + + if top > 0: + mask[:, :, f, :] *= feather_rate + + if bottom < destination.shape[2] - 1: + mask[:, :, -f, :] *= feather_rate + + mask = mask[:, :, :visible_height, :visible_width] + inverse_mask = torch.ones_like(mask) - mask + + source_portion = mask * source[:, :, :visible_height, :visible_width] + destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + + destination[:, :, top:bottom, left:right] = source_portion + destination_portion + + output["samples"] = destination + + return (output,) class LatentCrop: @classmethod @@ -907,7 +927,7 @@ class LoadImageMask: "channel": (["alpha", "red", "green", "blue"], ),} } - CATEGORY = "image" + CATEGORY = "mask" RETURN_TYPES = ("MASK",) FUNCTION = "load_image" @@ -1114,3 +1134,4 @@ def init_custom_nodes(): load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) From 2dc7257e292cad08876e7f188e2fbb2f2abb6644 Mon Sep 17 00:00:00 2001 From: omar92 Date: Sat, 8 Apr 2023 18:58:47 +0200 Subject: [PATCH 002/190] Allow connect premitive Node to "comfyiUI-nodes that have forceInput setting" --- web/extensions/core/widgetInputs.js | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 865af7763..f4d2d22de 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -233,7 +233,9 @@ app.registerExtension({ // Fires before the link is made allowing us to reject it if it isn't valid // No widget, we cant connect - if (!input.widget) return false; + if (!input.widget) { + if (!(input.type in ComfyWidgets)) return false; + } if (this.outputs[slot].links?.length) { return this.#isValidConnection(input); @@ -252,9 +254,18 @@ app.registerExtension({ const input = theirNode.inputs[link.target_slot]; if (!input) return; - const widget = input.widget; - const { type, linkType } = getWidgetType(widget.config); + var _widget; + if (!input.widget) { + if (!(input.type in ComfyWidgets)) return; + _widget = { "name": input.name, "config": [input.type, {}] }//fake widget + } else { + _widget = input.widget; + } + + const widget = _widget; + const { type, linkType } = getWidgetType(widget.config); + console.log({ "input": input }); // Update our output to restrict to the widget type this.outputs[0].type = linkType; this.outputs[0].name = type; @@ -274,7 +285,7 @@ app.registerExtension({ if (type in ComfyWidgets) { widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget; } else { - widget = this.addWidget(type, "value", null, () => {}, {}); + widget = this.addWidget(type, "value", null, () => { }, {}); } if (node?.widgets && widget) { From 9d095c52f3d9fc65477abae380cf8ba6d8b271dd Mon Sep 17 00:00:00 2001 From: omar92 Date: Sat, 8 Apr 2023 19:05:22 +0200 Subject: [PATCH 003/190] handle double click create primitive widget --- web/extensions/core/widgetInputs.js | 43 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index f4d2d22de..28c5aee1d 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -159,27 +159,31 @@ app.registerExtension({ const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : undefined; const input = this.inputs[slot]; - if (input.widget && !input[ignoreDblClick]) { - const node = LiteGraph.createNode("PrimitiveNode"); - app.graph.add(node); - - // Calculate a position that wont directly overlap another node - const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]]; - while (isNodeAtPos(pos)) { - pos[1] += LiteGraph.NODE_TITLE_HEIGHT; - } - - node.pos = pos; - node.connect(0, this, slot); - node.title = input.name; - - // Prevent adding duplicates due to triple clicking - input[ignoreDblClick] = true; - setTimeout(() => { - delete input[ignoreDblClick]; - }, 300); + if (!input.widget || !input[ignoreDblClick])// Not a widget input or already handled input + { + if (!(input.type in ComfyWidgets)) return r;//also Not a ComfyWidgets input (do nothing) } + // Create a primitive node + const node = LiteGraph.createNode("PrimitiveNode"); + app.graph.add(node); + + // Calculate a position that wont directly overlap another node + const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]]; + while (isNodeAtPos(pos)) { + pos[1] += LiteGraph.NODE_TITLE_HEIGHT; + } + + node.pos = pos; + node.connect(0, this, slot); + node.title = input.name; + + // Prevent adding duplicates due to triple clicking + input[ignoreDblClick] = true; + setTimeout(() => { + delete input[ignoreDblClick]; + }, 300); + return r; }; }, @@ -265,7 +269,6 @@ app.registerExtension({ const widget = _widget; const { type, linkType } = getWidgetType(widget.config); - console.log({ "input": input }); // Update our output to restrict to the widget type this.outputs[0].type = linkType; this.outputs[0].name = type; From e12fb88b1b84e354872c0d761544558479bcfad2 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 11 Apr 2023 16:49:39 -0600 Subject: [PATCH 004/190] Image/mask conversion nodes --- nodes.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/nodes.py b/nodes.py index 14a73bcd7..ecd931d69 100644 --- a/nodes.py +++ b/nodes.py @@ -1059,6 +1059,43 @@ class ImagePadForOutpaint: return (new_image, mask) +class ImageToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": (["red", "green", "blue"],), + } + } + + CATEGORY = "image" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, channel): + channels = ["red", "green", "blue"] + mask = torch.select(image[0], 2, channels.index(channel)) + return (mask,) + +class MaskToImage: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + } + } + + CATEGORY = "image" + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mask_to_image" + + def mask_to_image(self, mask): + result = mask[None, :, :, None].expand(-1, -1, -1, 3) + return (result,) NODE_CLASS_MAPPINGS = { "KSampler": KSampler, @@ -1102,6 +1139,8 @@ NODE_CLASS_MAPPINGS = { "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, + "ImageToMask": ImageToMask, + "MaskToImage": MaskToImage, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1147,6 +1186,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageUpscaleWithModel": "Upscale Image (using Model)", "ImageInvert": "Invert Image", "ImagePadForOutpaint": "Pad Image for Outpainting", + "ImageToMask": "Convert Image to Mask", + "MaskToImage": "Convert Mask to Image", # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", From e1d289c1ec6894e15af0b57b6630b853341c61fa Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 11 Apr 2023 20:26:24 -0600 Subject: [PATCH 005/190] use slice instead of torch.select() --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index ecd931d69..815631f58 100644 --- a/nodes.py +++ b/nodes.py @@ -1076,7 +1076,7 @@ class ImageToMask: def image_to_mask(self, image, channel): channels = ["red", "green", "blue"] - mask = torch.select(image[0], 2, channels.index(channel)) + mask = image[0, :, :, channels.index(channel)] return (mask,) class MaskToImage: From e87aa1873f6a01148898b41dc3498ca6f82410d8 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 12 Apr 2023 19:36:35 -0600 Subject: [PATCH 006/190] Add slider setting type --- web/scripts/ui.js | 24 ++++++++++++++++++++++++ web/style.css | 8 ++++++++ 2 files changed, 32 insertions(+) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 09861c440..1c7fdc8a1 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -270,6 +270,30 @@ class ComfySettingsDialog extends ComfyDialog { ]), ]); break; + case "slider": + element = $el("div", [ + $el("label", { textContent: name }, [ + $el("input", { + type: "range", + value, + oninput: (e) => { + setter(e.target.value); + e.target.nextElementSibling.value = e.target.value; + }, + ...attrs + }), + $el("input", { + type: "number", + value, + oninput: (e) => { + setter(e.target.value); + e.target.previousElementSibling.value = e.target.value; + }, + ...attrs + }), + ]), + ]); + break; default: console.warn("Unsupported setting type, defaulting to text"); element = $el("div", [ diff --git a/web/style.css b/web/style.css index 34e31726c..e3b445762 100644 --- a/web/style.css +++ b/web/style.css @@ -217,6 +217,14 @@ button.comfy-queue-btn { z-index: 99; } +.comfy-modal.comfy-settings input[type="range"] { + vertical-align: middle; +} + +.comfy-modal.comfy-settings input[type="range"] + input[type="number"] { + width: 3.5em; +} + .comfy-modal input, .comfy-modal select { color: var(--input-text); From 8810e1b4b9e2c14860800a2bdc97d50d1aa2f904 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 12 Apr 2023 21:15:21 -0600 Subject: [PATCH 007/190] Fix indentation --- web/scripts/ui.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 1c7fdc8a1..6cbc9383e 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -270,7 +270,7 @@ class ComfySettingsDialog extends ComfyDialog { ]), ]); break; - case "slider": + case "slider": element = $el("div", [ $el("label", { textContent: name }, [ $el("input", { @@ -278,16 +278,16 @@ class ComfySettingsDialog extends ComfyDialog { value, oninput: (e) => { setter(e.target.value); - e.target.nextElementSibling.value = e.target.value; + e.target.nextElementSibling.value = e.target.value; }, ...attrs }), - $el("input", { + $el("input", { type: "number", value, oninput: (e) => { setter(e.target.value); - e.target.previousElementSibling.value = e.target.value; + e.target.previousElementSibling.value = e.target.value; }, ...attrs }), From 9371924e654128258cc82419e83c2a788a32e2be Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 13 Apr 2023 03:11:17 -0600 Subject: [PATCH 008/190] Move mask conversion to separate file --- comfy_extras/nodes_mask_conversion.py | 54 +++++++++++++++++++++++++++ nodes.py | 42 +-------------------- 2 files changed, 55 insertions(+), 41 deletions(-) create mode 100644 comfy_extras/nodes_mask_conversion.py diff --git a/comfy_extras/nodes_mask_conversion.py b/comfy_extras/nodes_mask_conversion.py new file mode 100644 index 000000000..04dcbd0d9 --- /dev/null +++ b/comfy_extras/nodes_mask_conversion.py @@ -0,0 +1,54 @@ +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + +import comfy.utils + +class ImageToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": (["red", "green", "blue"],), + } + } + + CATEGORY = "image" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, channel): + channels = ["red", "green", "blue"] + mask = image[0, :, :, channels.index(channel)] + return (mask,) + +class MaskToImage: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + } + } + + CATEGORY = "image" + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mask_to_image" + + def mask_to_image(self, mask): + result = mask[None, :, :, None].expand(-1, -1, -1, 3) + return (result,) + +NODE_CLASS_MAPPINGS = { + "ImageToMask": ImageToMask, + "MaskToImage": MaskToImage, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageToMask": "Convert Image to Mask", + "MaskToImage": "Convert Mask to Image", +} diff --git a/nodes.py b/nodes.py index 325e3ba68..3ed9cf499 100644 --- a/nodes.py +++ b/nodes.py @@ -1061,43 +1061,6 @@ class ImagePadForOutpaint: return (new_image, mask) -class ImageToMask: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue"],), - } - } - - CATEGORY = "image" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, channel): - channels = ["red", "green", "blue"] - mask = image[0, :, :, channels.index(channel)] - return (mask,) - -class MaskToImage: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - } - } - - CATEGORY = "image" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "mask_to_image" - - def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) - return (result,) NODE_CLASS_MAPPINGS = { "KSampler": KSampler, @@ -1141,8 +1104,6 @@ NODE_CLASS_MAPPINGS = { "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, - "ImageToMask": ImageToMask, - "MaskToImage": MaskToImage, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1188,8 +1149,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageUpscaleWithModel": "Upscale Image (using Model)", "ImageInvert": "Invert Image", "ImagePadForOutpaint": "Pad Image for Outpainting", - "ImageToMask": "Convert Image to Mask", - "MaskToImage": "Convert Mask to Image", # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", @@ -1233,3 +1192,4 @@ def init_custom_nodes(): load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask_conversion.py")) From ff0be60ac4c561c50fedce1ed4a0165d3ef087ce Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 13 Apr 2023 06:38:24 -0400 Subject: [PATCH 009/190] fix comfy list not styled, and light theme border --- web/extensions/core/colorPalette.js | 2 +- web/style.css | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 94bea9ab3..41541a8d8 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -107,7 +107,7 @@ const colorPalettes = { "descrip-text": "#444", "drag-text": "#555", "error-text": "#F44336", - "border-color": "#CCC" + "border-color": "#888" } }, }, diff --git a/web/style.css b/web/style.css index 34e31726c..312fc046a 100644 --- a/web/style.css +++ b/web/style.css @@ -160,9 +160,9 @@ body { .comfy-list { color: var(--descrip-text); - background-color: #333; + background-color: var(--comfy-menu-bg); margin-bottom: 10px; - border-color: #4e4e4e; + border-color: var(--border-color); border-style: solid; } From 501f200d8631b2f5d734cad5aefba8d6f4232937 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 13 Apr 2023 10:38:41 -0400 Subject: [PATCH 010/190] Fix widgets not getting converted correctly in workflows. --- web/scripts/app.js | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 2f5e73220..42addc8c6 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -949,9 +949,13 @@ class ComfyApp { widget.value = widget.value.slice(7); } } + } + if (node.type == "KSampler" || node.type == "KSamplerAdvanced" || node.type == "PrimitiveNode") { if (widget.name == "control_after_generate") { - if (widget.value == true) { + if (widget.value === true) { widget.value = "randomize"; + } else if (widget.value === false) { + widget.value = "fixed"; } } } From 601edaf6ad127c46eae417d024c24d6e9ae310c4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 13 Apr 2023 10:59:38 -0400 Subject: [PATCH 011/190] Add links to new controlnet models to colab. --- notebooks/comfyui_colab.ipynb | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index 071a89969..8b5c0badf 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -119,9 +119,20 @@ "\n", "\n", "# ControlNet\n", - "#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_depth-fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_scribble-fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_openpose-fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_normalbae_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_openpose_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_scribble_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_seg_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_softedge_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n", "\n", "\n", "# Controlnet Preprocessor nodes by Fannovel16\n", From 307ef543bf66e5ffd718b3a0b148c72287b65a89 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 13 Apr 2023 10:04:06 -0600 Subject: [PATCH 012/190] Change grid size to slider --- web/extensions/core/snapToGrid.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/snapToGrid.js b/web/extensions/core/snapToGrid.js index 20b245e18..cb5fc154b 100644 --- a/web/extensions/core/snapToGrid.js +++ b/web/extensions/core/snapToGrid.js @@ -9,7 +9,7 @@ app.registerExtension({ app.ui.settings.addSetting({ id: "Comfy.SnapToGrid.GridSize", name: "Grid Size", - type: "number", + type: "slider", attrs: { min: 1, max: 500, From 8489cba1405f222f4675c120aee4a3722affb3f8 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Thu, 13 Apr 2023 22:01:01 +0200 Subject: [PATCH 013/190] add unique ID per word/embedding for tokenizer --- comfy/sd1_clip.py | 117 ++++++++++++++++++++++++++++------------------ 1 file changed, 71 insertions(+), 46 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4f51657c3..3dd8262ac 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -224,60 +224,85 @@ class SD1Tokenizer: self.inv_vocab = {v: k for k, v in vocab.items()} self.embedding_directory = embedding_directory self.max_word_length = 8 + self.embedding_identifier = "embedding:" - def tokenize_with_weights(self, text): + def _try_get_embedding(self, name:str): + ''' + Takes a potential embedding name and tries to retrieve it. + Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. + ''' + embedding_name = name[len(self.embedding_identifier):].strip('\n') + embed = load_embed(embedding_name, self.embedding_directory) + if embed is None: + stripped = embedding_name.strip(',') + if len(stripped) < len(embedding_name): + embed = load_embed(stripped, self.embedding_directory) + return (embed, embedding_name[len(stripped):]) + return (embed, "") + + + def tokenize_with_weights(self, text:str): + ''' + Takes a prompt and converts it to a list of (token, weight, word id) elements. + Tokens can both be integer tokens and pre computed CLIP tensors. + Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. + Returned list has the dimensions NxM where M is the input size of CLIP + ''' text = escape_important(text) parsed_weights = token_weights(text, 1.0) + #tokenize words tokens = [] - for t in parsed_weights: - to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ') - while len(to_tokenize) > 0: - word = to_tokenize.pop(0) - temp_tokens = [] - embedding_identifier = "embedding:" - if word.startswith(embedding_identifier) and self.embedding_directory is not None: - embedding_name = word[len(embedding_identifier):].strip('\n') - embed = load_embed(embedding_name, self.embedding_directory) + for weighted_segment, weight in parsed_weights: + to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ') + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + #if we find an embedding, deal with the embedding + if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: + embed, leftover = self._try_get_embedding(word) if embed is None: - stripped = embedding_name.strip(',') - if len(stripped) < len(embedding_name): - embed = load_embed(stripped, self.embedding_directory) - if embed is not None: - to_tokenize.insert(0, embedding_name[len(stripped):]) - - if embed is not None: - if len(embed.shape) == 1: - temp_tokens += [(embed, t[1])] - else: - for x in range(embed.shape[0]): - temp_tokens += [(embed[x], t[1])] + print(f"warning, embedding:{word} does not exist, ignoring") else: - print("warning, embedding:{} does not exist, ignoring".format(embedding_name)) - elif len(word) > 0: - tt = self.tokenizer(word)["input_ids"][1:-1] - for x in tt: - temp_tokens += [(x, t[1])] - tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section) + if len(embed.shape) == 1: + tokens.append([(embed, weight)]) + else: + tokens.append([(embed[x], weight) for x in range(embed.shape[0])]) + #if we accidentally have leftover text, continue parsing using leftover, else move on to next word + if leftover != "": + word = leftover + else: + continue + #parse word + tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) + + #reshape token array to CLIP input size + batched_tokens = [] + batch = [] + batched_tokens.append(batch) + for i, t_group in enumerate(tokens): + #start a new batch if there is not enough room + if len(t_group) + len(batch) > self.max_tokens_per_section: + remaining_length = self.max_tokens_per_section - len(batch) + #fill remaining space depending on length of tokens + if len(t_group) > self.max_word_length: + #put part of group of tokens in the batch + batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) + t_group = t_group[remaining_length:] + else: + #filler tokens + batch.extend([(self.end_token, 1.0, 0)] * remaining_length) + batch = [] + batched_tokens.append(batch) + #put current group of tokens in the batch + batch.extend([(t,w,i+1) for t,w in t_group]) + + #fill last batch + batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch))) + + #add start and end tokens + batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens] + return batched_tokens - #try not to split words in different sections - if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length): - for x in range(tokens_left): - tokens += [(self.end_token, 1.0)] - tokens += temp_tokens - - out_tokens = [] - for x in range(0, len(tokens), self.max_tokens_per_section): - o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))] - o_token += [(self.end_token, 1.0)] - if self.pad_with_end: - o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token)) - else: - o_token +=[(0, 1.0)] * (self.max_length - len(o_token)) - - out_tokens += [o_token] - - return out_tokens def untokenize(self, token_weight_pair): return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) From 73175cf58c0371903bf9bec107f0e82c5c4363d0 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Thu, 13 Apr 2023 22:06:50 +0200 Subject: [PATCH 014/190] split tokenizer from encoder --- comfy/sd.py | 6 ++++-- nodes.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 2d7ff5ab0..6bd30daf4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -372,10 +372,12 @@ class CLIP: def clip_layer(self, layer_idx): self.layer_idx = layer_idx - def encode(self, text): + def tokenize(self, text): + return self.tokenizer.tokenize_with_weights(text) + + def encode(self, tokens): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) - tokens = self.tokenizer.tokenize_with_weights(text) try: self.patcher.patch_model() cond = self.cond_stage_model.encode_token_weights(tokens) diff --git a/nodes.py b/nodes.py index 946c66857..b81d16015 100644 --- a/nodes.py +++ b/nodes.py @@ -44,7 +44,8 @@ class CLIPTextEncode: CATEGORY = "conditioning" def encode(self, clip, text): - return ([[clip.encode(text), {}]], ) + tokens = clip.tokenize(text) + return ([[clip.encode(tokens), {}]], ) class ConditioningCombine: @classmethod From d2337a86fe6fb97ed9d818635083fcf1dc2bafc0 Mon Sep 17 00:00:00 2001 From: Gavroche CryptoRUSH <95258328+CryptoRUSHGav@users.noreply.github.com> Date: Thu, 13 Apr 2023 16:38:02 -0400 Subject: [PATCH 015/190] remove extra semi-colon --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 946c66857..f13626771 100644 --- a/nodes.py +++ b/nodes.py @@ -871,7 +871,7 @@ class SaveImage: "filename": file, "subfolder": subfolder, "type": self.type - }); + }) counter += 1 return { "ui": { "images": results } } From 35a2c790b60f836371f8955c96661e929712619e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 14 Apr 2023 00:12:15 -0400 Subject: [PATCH 016/190] Update comfy_extras/nodes_mask.py Co-authored-by: missionfloyd --- comfy_extras/nodes_mask.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index ba39680a7..ab17fc509 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -9,8 +9,8 @@ class LatentCompositeMasked: "required": { "destination": ("LATENT",), "source": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}), }, "optional": { "mask": ("MASK",), @@ -26,6 +26,9 @@ class LatentCompositeMasked: destination = destination["samples"].clone() source = source["samples"] + x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8)) + y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8)) + left, top = (x // 8, y // 8) right, bottom = (left + source.shape[3], top + source.shape[2],) @@ -40,7 +43,7 @@ class LatentCompositeMasked: # calculate the bounds of the source that will be overlapping the destination # this prevents the source trying to overwrite latent pixels that are out of bounds # of the destination - visible_width, visible_height = (destination.shape[3] - left, destination.shape[2] - top,) + visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) mask = mask[:, :, :visible_height, :visible_width] inverse_mask = torch.ones_like(mask) - mask From 1a7cda715b3c01ef89b16c5cc96784ca4efa313c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 14 Apr 2023 00:14:35 -0400 Subject: [PATCH 017/190] Revert LatentComposite. --- nodes.py | 82 +++++++++++++++++++++----------------------------------- 1 file changed, 31 insertions(+), 51 deletions(-) diff --git a/nodes.py b/nodes.py index 661f879ac..6468ac6b8 100644 --- a/nodes.py +++ b/nodes.py @@ -578,64 +578,44 @@ class LatentFlip: class LatentComposite: @classmethod def INPUT_TYPES(s): - return { - "required": { - "samples_to": ("LATENT",), - "samples_from": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - } - } + return {"required": { "samples_to": ("LATENT",), + "samples_from": ("LATENT",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + }} RETURN_TYPES = ("LATENT",) FUNCTION = "composite" CATEGORY = "latent" - def composite(self, samples_to, samples_from, x, y, feather): - output = samples_to.copy() - destination = samples_to["samples"].clone() - source = samples_from["samples"] - - left, top = (x // 8, y // 8) - right, bottom = (left + source.shape[3], top + source.shape[2],) + def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0): + x = x // 8 + y = y // 8 feather = feather // 8 + samples_out = samples_to.copy() + s = samples_to["samples"].clone() + samples_to = samples_to["samples"] + samples_from = samples_from["samples"] + if feather == 0: + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] + else: + samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] + mask = torch.ones_like(samples_from) + for t in range(feather): + if y != 0: + mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) - - - # calculate the bounds of the source that will be overlapping the destination - # this prevents the source trying to overwrite latent pixels that are out of bounds - # of the destination - visible_width, visible_height = (destination.shape[3] - left, destination.shape[2] - top,) - - mask = torch.ones_like(source) - - for f in range(feather): - feather_rate = (f + 1.0) / feather - - if left > 0: - mask[:, :, :, f] *= feather_rate - - if right < destination.shape[3] - 1: - mask[:, :, :, -f] *= feather_rate - - if top > 0: - mask[:, :, f, :] *= feather_rate - - if bottom < destination.shape[2] - 1: - mask[:, :, -f, :] *= feather_rate - - mask = mask[:, :, :visible_height, :visible_width] - inverse_mask = torch.ones_like(mask) - mask - - source_portion = mask * source[:, :, :visible_height, :visible_width] - destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] - - destination[:, :, top:bottom, left:right] = source_portion + destination_portion - - output["samples"] = destination - - return (output,) + if y + samples_from.shape[2] < samples_to.shape[2]: + mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1)) + if x != 0: + mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1)) + if x + samples_from.shape[3] < samples_to.shape[3]: + mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) + rev_mask = torch.ones_like(mask) - mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask + samples_out["samples"] = s + return (samples_out,) class LatentCrop: @classmethod From f48f0872e2310b1650f798d02e94264cc06afd69 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 14 Apr 2023 00:21:01 -0400 Subject: [PATCH 018/190] Refactor: move nodes_mask_convertion nodes to nodes_mask. --- comfy_extras/nodes_mask.py | 39 +++++++++++++++---- comfy_extras/nodes_mask_conversion.py | 54 --------------------------- nodes.py | 1 - 3 files changed, 31 insertions(+), 63 deletions(-) delete mode 100644 comfy_extras/nodes_mask_conversion.py diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index ab17fc509..60feea0db 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -59,23 +59,41 @@ class LatentCompositeMasked: class MaskToImage: @classmethod - def INPUT_TYPES(cls): + def INPUT_TYPES(s): return { - "required": { - "mask": ("MASK",), - } + "required": { + "mask": ("MASK",), + } } CATEGORY = "mask" RETURN_TYPES = ("IMAGE",) + FUNCTION = "mask_to_image" - FUNCTION = "convert" + def mask_to_image(self, mask): + result = mask[None, :, :, None].expand(-1, -1, -1, 3) + return (result,) - def convert(self, mask): - image = torch.cat([torch.reshape(mask.clone(), [1, mask.shape[0], mask.shape[1], 1,])] * 3, 3) +class ImageToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": (["red", "green", "blue"],), + } + } - return (image,) + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, channel): + channels = ["red", "green", "blue"] + mask = image[0, :, :, channels.index(channel)] + return (mask,) class SolidMask: @classmethod @@ -231,6 +249,7 @@ class FeatherMask: NODE_CLASS_MAPPINGS = { "LatentCompositeMasked": LatentCompositeMasked, "MaskToImage": MaskToImage, + "ImageToMask": ImageToMask, "SolidMask": SolidMask, "InvertMask": InvertMask, "CropMask": CropMask, @@ -238,3 +257,7 @@ NODE_CLASS_MAPPINGS = { "FeatherMask": FeatherMask, } +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageToMask": "Convert Image to Mask", + "MaskToImage": "Convert Mask to Image", +} diff --git a/comfy_extras/nodes_mask_conversion.py b/comfy_extras/nodes_mask_conversion.py deleted file mode 100644 index 04dcbd0d9..000000000 --- a/comfy_extras/nodes_mask_conversion.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import torch -import torch.nn.functional as F -from PIL import Image - -import comfy.utils - -class ImageToMask: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue"],), - } - } - - CATEGORY = "image" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, channel): - channels = ["red", "green", "blue"] - mask = image[0, :, :, channels.index(channel)] - return (mask,) - -class MaskToImage: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - } - } - - CATEGORY = "image" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "mask_to_image" - - def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) - return (result,) - -NODE_CLASS_MAPPINGS = { - "ImageToMask": ImageToMask, - "MaskToImage": MaskToImage, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageToMask": "Convert Image to Mask", - "MaskToImage": "Convert Mask to Image", -} diff --git a/nodes.py b/nodes.py index aff03dd43..6468ac6b8 100644 --- a/nodes.py +++ b/nodes.py @@ -1193,4 +1193,3 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask_conversion.py")) From d98a4de9eb6b676bfe9c172e7310934148e16dd2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 14 Apr 2023 00:49:19 -0400 Subject: [PATCH 019/190] LatentCompositeMasked: negative x, y don't work. --- comfy_extras/nodes_mask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 60feea0db..4dfb0b93e 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -9,8 +9,8 @@ class LatentCompositeMasked: "required": { "destination": ("LATENT",), "source": ("LATENT",), - "x": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), }, "optional": { "mask": ("MASK",), From 334aab05e56c1441b096333431227dd63002f786 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 14 Apr 2023 13:54:00 -0400 Subject: [PATCH 020/190] Don't stop workflow if loading embedding fails. --- comfy/sd1_clip.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4f51657c3..1f057f753 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -2,6 +2,7 @@ import os from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig import torch +import traceback class ClipTokenWeightEncoder: def encode_token_weights(self, token_weight_pairs): @@ -194,14 +195,21 @@ def load_embed(embedding_name, embedding_directory): embed_path = valid_file - if embed_path.lower().endswith(".safetensors"): - import safetensors.torch - embed = safetensors.torch.load_file(embed_path, device="cpu") - else: - if 'weights_only' in torch.load.__code__.co_varnames: - embed = torch.load(embed_path, weights_only=True, map_location="cpu") + try: + if embed_path.lower().endswith(".safetensors"): + import safetensors.torch + embed = safetensors.torch.load_file(embed_path, device="cpu") else: - embed = torch.load(embed_path, map_location="cpu") + if 'weights_only' in torch.load.__code__.co_varnames: + embed = torch.load(embed_path, weights_only=True, map_location="cpu") + else: + embed = torch.load(embed_path, map_location="cpu") + except Exception as e: + print(traceback.format_exc()) + print() + print("error loading embedding, skipping loading:", embedding_name) + return None + if 'string_to_param' in embed: values = embed['string_to_param'].values() else: From 752f7a162ba728b3ab7b9ce53be73c271da25dd5 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 14 Apr 2023 21:02:45 +0200 Subject: [PATCH 021/190] align behavior with old tokenize function --- comfy/sd1_clip.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 3dd8262ac..45bc95269 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -226,12 +226,11 @@ class SD1Tokenizer: self.max_word_length = 8 self.embedding_identifier = "embedding:" - def _try_get_embedding(self, name:str): + def _try_get_embedding(self, embedding_name:str): ''' Takes a potential embedding name and tries to retrieve it. Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. ''' - embedding_name = name[len(self.embedding_identifier):].strip('\n') embed = load_embed(embedding_name, self.embedding_directory) if embed is None: stripped = embedding_name.strip(',') @@ -259,9 +258,10 @@ class SD1Tokenizer: for word in to_tokenize: #if we find an embedding, deal with the embedding if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: - embed, leftover = self._try_get_embedding(word) + embedding_name = word[len(self.embedding_identifier):].strip('\n') + embed, leftover = self._try_get_embedding(embedding_name) if embed is None: - print(f"warning, embedding:{word} does not exist, ignoring") + print(f"warning, embedding:{embedding_name} does not exist, ignoring") else: if len(embed.shape) == 1: tokens.append([(embed, weight)]) @@ -280,21 +280,21 @@ class SD1Tokenizer: batch = [] batched_tokens.append(batch) for i, t_group in enumerate(tokens): - #start a new batch if there is not enough room - if len(t_group) + len(batch) > self.max_tokens_per_section: - remaining_length = self.max_tokens_per_section - len(batch) - #fill remaining space depending on length of tokens - if len(t_group) > self.max_word_length: - #put part of group of tokens in the batch - batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) - t_group = t_group[remaining_length:] + #determine if we're going to try and keep the tokens in a single batch + is_large = len(t_group) >= self.max_word_length + while len(t_group) > 0: + if len(t_group) + len(batch) > self.max_tokens_per_section: + remaining_length = self.max_tokens_per_section - len(batch) + if is_large: + batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) + t_group = t_group[remaining_length:] + else: + batch.extend([(self.end_token, 1.0, 0)] * remaining_length) + batch = [] + batched_tokens.append(batch) else: - #filler tokens - batch.extend([(self.end_token, 1.0, 0)] * remaining_length) - batch = [] - batched_tokens.append(batch) - #put current group of tokens in the batch - batch.extend([(t,w,i+1) for t,w in t_group]) + batch.extend([(t,w,i+1) for t,w in t_group]) + t_group = [] #fill last batch batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch))) From da115bd78d7c4571dc0747dcb17e280b5c8ff4ea Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 14 Apr 2023 21:16:55 +0200 Subject: [PATCH 022/190] ensure backwards compat with optional args --- comfy/sd.py | 10 +++++++--- comfy/sd1_clip.py | 6 +++++- nodes.py | 3 +-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6bd30daf4..6e54bc60b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -372,12 +372,16 @@ class CLIP: def clip_layer(self, layer_idx): self.layer_idx = layer_idx - def tokenize(self, text): - return self.tokenizer.tokenize_with_weights(text) + def tokenize(self, text, return_word_ids=False): + return self.tokenizer.tokenize_with_weights(text, return_word_ids) - def encode(self, tokens): + def encode(self, text, from_tokens=False): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) + if from_tokens: + tokens = text + else: + tokens = self.tokenizer.tokenize_with_weights(text) try: self.patcher.patch_model() cond = self.cond_stage_model.encode_token_weights(tokens) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 45bc95269..02e925c8e 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -240,7 +240,7 @@ class SD1Tokenizer: return (embed, "") - def tokenize_with_weights(self, text:str): + def tokenize_with_weights(self, text:str, return_word_ids=False): ''' Takes a prompt and converts it to a list of (token, weight, word id) elements. Tokens can both be integer tokens and pre computed CLIP tensors. @@ -301,6 +301,10 @@ class SD1Tokenizer: #add start and end tokens batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens] + + if not return_word_ids: + batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] + return batched_tokens diff --git a/nodes.py b/nodes.py index b68c8ef43..6468ac6b8 100644 --- a/nodes.py +++ b/nodes.py @@ -44,8 +44,7 @@ class CLIPTextEncode: CATEGORY = "conditioning" def encode(self, clip, text): - tokens = clip.tokenize(text) - return ([[clip.encode(tokens), {}]], ) + return ([[clip.encode(text), {}]], ) class ConditioningCombine: @classmethod From 04d9bc13afd684a5bd4cb637e26972bb5aee43d1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 14 Apr 2023 15:33:43 -0400 Subject: [PATCH 023/190] Safely load pickled embeds that don't load with weights_only=True. --- comfy/sd1_clip.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 1f057f753..42c9b4c25 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -3,6 +3,7 @@ import os from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig import torch import traceback +import zipfile class ClipTokenWeightEncoder: def encode_token_weights(self, token_weight_pairs): @@ -171,6 +172,26 @@ def unescape_important(text): text = text.replace("\0\2", "(") return text +def safe_load_embed_zip(embed_path): + with zipfile.ZipFile(embed_path) as myzip: + names = list(filter(lambda a: "data/" in a, myzip.namelist())) + names.reverse() + for n in names: + with myzip.open(n) as myfile: + data = myfile.read() + number = len(data) // 4 + length_embed = 1024 #sd2.x + if number < 768: + continue + if number % 768 == 0: + length_embed = 768 #sd1.x + num_embeds = number // length_embed + embed = torch.frombuffer(data, dtype=torch.float) + out = embed.reshape((num_embeds, length_embed)).clone() + del embed + return out + + def load_embed(embedding_name, embedding_directory): if isinstance(embedding_directory, str): embedding_directory = [embedding_directory] @@ -195,13 +216,18 @@ def load_embed(embedding_name, embedding_directory): embed_path = valid_file + embed_out = None + try: if embed_path.lower().endswith(".safetensors"): import safetensors.torch embed = safetensors.torch.load_file(embed_path, device="cpu") else: if 'weights_only' in torch.load.__code__.co_varnames: - embed = torch.load(embed_path, weights_only=True, map_location="cpu") + try: + embed = torch.load(embed_path, weights_only=True, map_location="cpu") + except: + embed_out = safe_load_embed_zip(embed_path) else: embed = torch.load(embed_path, map_location="cpu") except Exception as e: @@ -210,11 +236,13 @@ def load_embed(embedding_name, embedding_directory): print("error loading embedding, skipping loading:", embedding_name) return None - if 'string_to_param' in embed: - values = embed['string_to_param'].values() - else: - values = embed.values() - return next(iter(values)) + if embed_out is None: + if 'string_to_param' in embed: + values = embed['string_to_param'].values() + else: + values = embed.values() + embed_out = next(iter(values)) + return embed_out class SD1Tokenizer: def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None): From 2525fcd342170c2ba7c624fb25745dce1a60e320 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Apr 2023 02:22:37 -0400 Subject: [PATCH 024/190] Colab update. --- notebooks/comfyui_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index 8b5c0badf..c088de89c 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -122,7 +122,7 @@ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n", "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n", "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n", + "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n", "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n", "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n", "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n", From d63705d9199b6905a2a94b2a6180795d34427f65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Sat, 15 Apr 2023 15:50:51 +0800 Subject: [PATCH 025/190] Support releases all unoccupied cached memory from XPU --- execution.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/execution.py b/execution.py index 79c9a3ac0..9d9ca5f68 100644 --- a/execution.py +++ b/execution.py @@ -10,6 +10,8 @@ import gc import torch import nodes +from model_management import xpu_available + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -206,6 +208,8 @@ class PromptExecutor: if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda torch.cuda.empty_cache() torch.cuda.ipc_collect() + elif xpu_available: + torch.xpu.empty_cache() def validate_inputs(prompt, item): From 5186b3266a8cd8958a5c77e05f3bcfbb24e5bde0 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 15 Apr 2023 10:29:32 +0100 Subject: [PATCH 026/190] Prevent generating bad replacement regex --- web/scripts/pnginfo.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 31f470739..209b562a6 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -131,6 +131,7 @@ export async function importA1111(graph, parameters) { } function replaceEmbeddings(text) { + if(!embeddings.length) return text; return text.replaceAll( new RegExp( "\\b(" + embeddings.map((e) => e.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")).join("\\b|\\b") + ")\\b", From 901a8901998cd789c5c03d0e57c9f2110632748c Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 15 Apr 2023 10:53:30 +0100 Subject: [PATCH 027/190] Allow combo primitive to connect to multiple inputs --- web/extensions/core/widgetInputs.js | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 2b3603419..cb6ebf09c 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -258,7 +258,7 @@ app.registerExtension({ const input = theirNode.inputs[link.target_slot]; if (!input) return; - + var _widget; if (!input.widget) { if (!(input.type in ComfyWidgets)) return; @@ -333,7 +333,20 @@ app.registerExtension({ const config1 = this.outputs[0].widget.config; const config2 = input.widget.config; - if (config1[0] !== config2[0]) return false; + if (config1[0] instanceof Array) { + // These checks shouldnt actually be necessary as the types should match + // but double checking doesn't hurt + + // New input isnt a combo + if (!(config2[0] instanceof Array)) return false; + // New imput combo has a different size + if (config1[0].length !== config2[0].length) return false; + // New input combo has different elements + if (config1[0].find((v, i) => config2[0][i] !== v)) return false; + } else if (config1[0] !== config2[0]) { + // Configs dont match + return false; + } for (const k in config1[1]) { if (k !== "default") { From 887ea0ba83efdb2cfdb506fdaef10481abf85643 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 15 Apr 2023 10:55:19 +0100 Subject: [PATCH 028/190] style --- web/extensions/core/widgetInputs.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index cb6ebf09c..67a59fb32 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -258,7 +258,7 @@ app.registerExtension({ const input = theirNode.inputs[link.target_slot]; if (!input) return; - + var _widget; if (!input.widget) { if (!(input.type in ComfyWidgets)) return; From 476d543fe80b55b696ab87535c28bcccab667bf9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Apr 2023 10:56:15 -0400 Subject: [PATCH 029/190] Fix for older python. from: https://github.com/comfyanonymous/ComfyUI/discussions/476 --- comfy_extras/nodes_mask.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 4dfb0b93e..131cd6a9f 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -187,13 +187,12 @@ class MaskComposite: source_portion = source[:visible_height, :visible_width] destination_portion = destination[top:bottom, left:right] - match operation: - case "multiply": - output[top:bottom, left:right] = destination_portion * source_portion - case "add": - output[top:bottom, left:right] = destination_portion + source_portion - case "subtract": - output[top:bottom, left:right] = destination_portion - source_portion + if operation == "multiply": + output[top:bottom, left:right] = destination_portion * source_portion + elif operation == "add": + output[top:bottom, left:right] = destination_portion + source_portion + elif operation == "subtract": + output[top:bottom, left:right] = destination_portion - source_portion output = torch.clamp(output, 0.0, 1.0) From deb2b93e797cb345d18e5fd54dff20837fd5ba02 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Apr 2023 11:19:07 -0400 Subject: [PATCH 030/190] Move code to empty gpu cache to model_management.py --- comfy/model_management.py | 9 +++++++++ execution.py | 9 ++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 8303cb437..76455e4a2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -307,6 +307,15 @@ def should_use_fp16(): return True +def soft_empty_cache(): + global xpu_available + if xpu_available: + torch.xpu.empty_cache() + elif torch.cuda.is_available(): + if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + #TODO: might be cleaner to put this somewhere else import threading diff --git a/execution.py b/execution.py index 9d9ca5f68..73be6db03 100644 --- a/execution.py +++ b/execution.py @@ -10,7 +10,7 @@ import gc import torch import nodes -from model_management import xpu_available +import comfy.model_management def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() @@ -204,12 +204,7 @@ class PromptExecutor: self.server.send_sync("executing", { "node": None }, self.server.client_id) gc.collect() - if torch.cuda.is_available(): - if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - elif xpu_available: - torch.xpu.empty_cache() + comfy.model_management.soft_empty_cache() def validate_inputs(prompt, item): From f5a78658b7ba8b3c278f8f6d79c249c73582df87 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 15 Apr 2023 17:34:46 +0100 Subject: [PATCH 031/190] Fix double click on converted combo widget link --- web/extensions/core/widgetInputs.js | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 2b3603419..88e5fc6f0 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -159,9 +159,11 @@ app.registerExtension({ const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : undefined; const input = this.inputs[slot]; - if (!input.widget || !input[ignoreDblClick])// Not a widget input or already handled input - { - if (!(input.type in ComfyWidgets)) return r;//also Not a ComfyWidgets input (do nothing) + if (!input.widget || !input[ignoreDblClick]) { + // Not a widget input or already handled input + if (!(input.type in ComfyWidgets) && !(input.widget.config?.[0] instanceof Array)) { + return r; //also Not a ComfyWidgets input or combo (do nothing) + } } // Create a primitive node From d0b1b6c6bf60a6f85e742a3340e9fcd9b06d0bde Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 15 Apr 2023 19:38:21 +0200 Subject: [PATCH 032/190] fixed improper padding --- comfy/sd1_clip.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 02e925c8e..32612cf31 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -247,6 +247,11 @@ class SD1Tokenizer: Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. Returned list has the dimensions NxM where M is the input size of CLIP ''' + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + text = escape_important(text) parsed_weights = token_weights(text, 1.0) @@ -277,30 +282,33 @@ class SD1Tokenizer: #reshape token array to CLIP input size batched_tokens = [] - batch = [] + batch = [(self.start_token, 1.0, 0)] batched_tokens.append(batch) for i, t_group in enumerate(tokens): #determine if we're going to try and keep the tokens in a single batch is_large = len(t_group) >= self.max_word_length + while len(t_group) > 0: - if len(t_group) + len(batch) > self.max_tokens_per_section: - remaining_length = self.max_tokens_per_section - len(batch) + if len(t_group) + len(batch) > self.max_length - 1: + remaining_length = self.max_length - len(batch) - 1 + #break word in two and add end token if is_large: batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) + batch.append((self.end_token, 1.0, 0)) t_group = t_group[remaining_length:] + #add end token and pad else: - batch.extend([(self.end_token, 1.0, 0)] * remaining_length) - batch = [] + batch.append((self.end_token, 1.0, 0)) + batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) + #start new batch + batch = [(self.start_token, 1.0, 0)] batched_tokens.append(batch) else: batch.extend([(t,w,i+1) for t,w in t_group]) t_group = [] #fill last batch - batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch))) - - #add start and end tokens - batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens] + batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] From eb4035c8bd8504531b5b11dac05303d19b42ee05 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 15 Apr 2023 21:40:39 +0100 Subject: [PATCH 033/190] Adds jsdoc for better dev experience --- web/scripts/app.js | 34 +- web/types/comfy.d.ts | 78 ++ web/types/litegraph.d.ts | 1506 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 1614 insertions(+), 4 deletions(-) create mode 100644 web/types/comfy.d.ts create mode 100644 web/types/litegraph.d.ts diff --git a/web/scripts/app.js b/web/scripts/app.js index 42addc8c6..940c5ecf1 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -4,27 +4,49 @@ import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, importA1111 } from "./pnginfo.js"; -class ComfyApp { /** - * List of {number, batchCount} entries to queue + * @typedef {import("types/comfy").ComfyExtension} ComfyExtension + */ + +export class ComfyApp { + /** + * List of entries to queue + * @type {{number: number, batchCount: number}[]} */ #queueItems = []; /** * If the queue is currently being processed + * @type {boolean} */ #processingQueue = false; constructor() { this.ui = new ComfyUI(this); + + /** + * List of extensions that are registered with the app + * @type {ComfyExtension[]} + */ this.extensions = []; + + /** + * Stores the execution output data for each node + * @type {Record} + */ this.nodeOutputs = {}; + + + /** + * If the shift key on the keyboard is pressed + * @type {boolean} + */ this.shiftDown = false; } /** * Invoke an extension callback - * @param {string} method The extension callback to execute - * @param {...any} args Any arguments to pass to the callback + * @param {keyof ComfyExtension} method The extension callback to execute + * @param {any[]} args Any arguments to pass to the callback * @returns */ #invokeExtensions(method, ...args) { @@ -1120,6 +1142,10 @@ class ComfyApp { } } + /** + * Registers a Comfy web extension with the app + * @param {ComfyExtension} extension + */ registerExtension(extension) { if (!extension.name) { throw new Error("Extensions must have a 'name' property."); diff --git a/web/types/comfy.d.ts b/web/types/comfy.d.ts new file mode 100644 index 000000000..3bb924543 --- /dev/null +++ b/web/types/comfy.d.ts @@ -0,0 +1,78 @@ +import { LGraphNode, IWidget } from "./litegraph"; +import { ComfyApp } from "/scripts/app"; + +export interface ComfyExtension { + /** + * The name of the extension + */ + name: string; + /** + * Allows any initialisation, e.g. loading resources. Called after the canvas is created but before nodes are added + * @param app The ComfyUI app instance + */ + init(app: ComfyApp): Promise; + /** + * Allows any additonal setup, called after the application is fully set up and running + * @param app The ComfyUI app instance + */ + setup(app: ComfyApp): Promise; + /** + * Called before nodes are registered with the graph + * @param defs The collection of node definitions, add custom ones or edit existing ones + * @param app The ComfyUI app instance + */ + addCustomNodeDefs(defs: Record, app: ComfyApp): Promise; + /** + * Allows the extension to add custom widgets + * @param app The ComfyUI app instance + * @returns An array of {[widget name]: widget data} + */ + getCustomWidgets( + app: ComfyApp + ): Promise< + Array< + Record { widget?: IWidget; minWidth?: number; minHeight?: number }> + > + >; + /** + * Allows the extension to add additional handling to the node before it is registered with LGraph + * @param nodeType The node class (not an instance) + * @param nodeData The original node object info config object + * @param app The ComfyUI app instance + */ + beforeRegisterNodeDef(nodeType: typeof LGraphNode, nodeData: ComfyObjectInfo, app: ComfyApp): Promise; + /** + * Allows the extension to register additional nodes with LGraph after standard nodes are added + * @param app The ComfyUI app instance + */ + registerCustomNodes(app: ComfyApp): Promise; + /** + * Allows the extension to modify a node that has been reloaded onto the graph. + * If you break something in the backend and want to patch workflows in the frontend + * This is the place to do this + * @param node The node that has been loaded + * @param app The ComfyUI app instance + */ + loadedGraphNode(node: LGraphNode, app: ComfyApp); + /** + * Allows the extension to run code after the constructor of the node + * @param node The node that has been created + * @param app The ComfyUI app instance + */ + nodeCreated(node: LGraphNode, app: ComfyApp); +} + +export type ComfyObjectInfo = { + name: string; + display_name?: string; + description?: string; + category: string; + input?: { + required?: Record; + optional?: Record; + }; + output?: string[]; + output_name: string[]; +}; + +export type ComfyObjectInfoConfig = [string | any[]] | [string | any[], any]; diff --git a/web/types/litegraph.d.ts b/web/types/litegraph.d.ts new file mode 100644 index 000000000..6629e779f --- /dev/null +++ b/web/types/litegraph.d.ts @@ -0,0 +1,1506 @@ +// Type definitions for litegraph.js 0.7.0 +// Project: litegraph.js +// Definitions by: NateScarlet + +export type Vector2 = [number, number]; +export type Vector4 = [number, number, number, number]; +export type widgetTypes = + | "number" + | "slider" + | "combo" + | "text" + | "toggle" + | "button"; +export type SlotShape = + | typeof LiteGraph.BOX_SHAPE + | typeof LiteGraph.CIRCLE_SHAPE + | typeof LiteGraph.ARROW_SHAPE + | typeof LiteGraph.SQUARE_SHAPE + | number; // For custom shapes + +/** https://github.com/jagenjo/litegraph.js/tree/master/guides#node-slots */ +export interface INodeSlot { + name: string; + type: string | -1; + label?: string; + dir?: + | typeof LiteGraph.UP + | typeof LiteGraph.RIGHT + | typeof LiteGraph.DOWN + | typeof LiteGraph.LEFT; + color_on?: string; + color_off?: string; + shape?: SlotShape; + locked?: boolean; + nameLocked?: boolean; +} + +export interface INodeInputSlot extends INodeSlot { + link: LLink["id"] | null; +} +export interface INodeOutputSlot extends INodeSlot { + links: LLink["id"][] | null; +} + +export type WidgetCallback = ( + this: T, + value: T["value"], + graphCanvas: LGraphCanvas, + node: LGraphNode, + pos: Vector2, + event?: MouseEvent +) => void; + +export interface IWidget { + name: string | null; + value: TValue; + options?: TOptions; + type?: widgetTypes; + y?: number; + property?: string; + last_y?: number; + clicked?: boolean; + marker?: boolean; + callback?: WidgetCallback; + /** Called by `LGraphCanvas.drawNodeWidgets` */ + draw?( + ctx: CanvasRenderingContext2D, + node: LGraphNode, + width: number, + posY: number, + height: number + ): void; + /** + * Called by `LGraphCanvas.processNodeWidgets` + * https://github.com/jagenjo/litegraph.js/issues/76 + */ + mouse?( + event: MouseEvent, + pos: Vector2, + node: LGraphNode + ): boolean; + /** Called by `LGraphNode.computeSize` */ + computeSize?(width: number): [number, number]; +} +export interface IButtonWidget extends IWidget { + type: "button"; +} +export interface IToggleWidget + extends IWidget { + type: "toggle"; +} +export interface ISliderWidget + extends IWidget { + type: "slider"; +} +export interface INumberWidget extends IWidget { + type: "number"; +} +export interface IComboWidget + extends IWidget< + string[], + { + values: + | string[] + | ((widget: IComboWidget, node: LGraphNode) => string[]); + } + > { + type: "combo"; +} + +export interface ITextWidget extends IWidget { + type: "text"; +} + +export interface IContextMenuItem { + content: string; + callback?: ContextMenuEventListener; + /** Used as innerHTML for extra child element */ + title?: string; + disabled?: boolean; + has_submenu?: boolean; + submenu?: { + options: ContextMenuItem[]; + } & IContextMenuOptions; + className?: string; +} +export interface IContextMenuOptions { + callback?: ContextMenuEventListener; + ignore_item_callbacks?: Boolean; + event?: MouseEvent | CustomEvent; + parentMenu?: ContextMenu; + autoopen?: boolean; + title?: string; + extra?: any; +} + +export type ContextMenuItem = IContextMenuItem | null; +export type ContextMenuEventListener = ( + value: ContextMenuItem, + options: IContextMenuOptions, + event: MouseEvent, + parentMenu: ContextMenu | undefined, + node: LGraphNode +) => boolean | void; + +export const LiteGraph: { + VERSION: number; + + CANVAS_GRID_SIZE: number; + + NODE_TITLE_HEIGHT: number; + NODE_TITLE_TEXT_Y: number; + NODE_SLOT_HEIGHT: number; + NODE_WIDGET_HEIGHT: number; + NODE_WIDTH: number; + NODE_MIN_WIDTH: number; + NODE_COLLAPSED_RADIUS: number; + NODE_COLLAPSED_WIDTH: number; + NODE_TITLE_COLOR: string; + NODE_TEXT_SIZE: number; + NODE_TEXT_COLOR: string; + NODE_SUBTEXT_SIZE: number; + NODE_DEFAULT_COLOR: string; + NODE_DEFAULT_BGCOLOR: string; + NODE_DEFAULT_BOXCOLOR: string; + NODE_DEFAULT_SHAPE: string; + DEFAULT_SHADOW_COLOR: string; + DEFAULT_GROUP_FONT: number; + + LINK_COLOR: string; + EVENT_LINK_COLOR: string; + CONNECTING_LINK_COLOR: string; + + MAX_NUMBER_OF_NODES: number; //avoid infinite loops + DEFAULT_POSITION: Vector2; //default node position + VALID_SHAPES: ["default", "box", "round", "card"]; //,"circle" + + //shapes are used for nodes but also for slots + BOX_SHAPE: 1; + ROUND_SHAPE: 2; + CIRCLE_SHAPE: 3; + CARD_SHAPE: 4; + ARROW_SHAPE: 5; + SQUARE_SHAPE: 6; + + //enums + INPUT: 1; + OUTPUT: 2; + + EVENT: -1; //for outputs + ACTION: -1; //for inputs + + ALWAYS: 0; + ON_EVENT: 1; + NEVER: 2; + ON_TRIGGER: 3; + + UP: 1; + DOWN: 2; + LEFT: 3; + RIGHT: 4; + CENTER: 5; + + STRAIGHT_LINK: 0; + LINEAR_LINK: 1; + SPLINE_LINK: 2; + + NORMAL_TITLE: 0; + NO_TITLE: 1; + TRANSPARENT_TITLE: 2; + AUTOHIDE_TITLE: 3; + + node_images_path: string; + + debug: boolean; + catch_exceptions: boolean; + throw_errors: boolean; + /** if set to true some nodes like Formula would be allowed to evaluate code that comes from unsafe sources (like node configuration), which could lead to exploits */ + allow_scripts: boolean; + /** node types by string */ + registered_node_types: Record; + /** used for dropping files in the canvas */ + node_types_by_file_extension: Record; + /** node types by class name */ + Nodes: Record; + + /** used to add extra features to the search box */ + searchbox_extras: Record< + string, + { + data: { outputs: string[][]; title: string }; + desc: string; + type: string; + } + >; + + createNode(type: string): T; + /** Register a node class so it can be listed when the user wants to create a new one */ + registerNodeType(type: string, base: { new (): LGraphNode }): void; + /** removes a node type from the system */ + unregisterNodeType(type: string): void; + /** Removes all previously registered node's types. */ + clearRegisteredTypes(): void; + /** + * Create a new node type by passing a function, it wraps it with a proper class and generates inputs according to the parameters of the function. + * Useful to wrap simple methods that do not require properties, and that only process some input to generate an output. + * @param name node name with namespace (p.e.: 'math/sum') + * @param func + * @param param_types an array containing the type of every parameter, otherwise parameters will accept any type + * @param return_type string with the return type, otherwise it will be generic + * @param properties properties to be configurable + */ + wrapFunctionAsNode( + name: string, + func: (...args: any[]) => any, + param_types?: string[], + return_type?: string, + properties?: object + ): void; + + /** + * Adds this method to all node types, existing and to be created + * (You can add it to LGraphNode.prototype but then existing node types wont have it) + */ + addNodeMethod(name: string, func: (...args: any[]) => any): void; + + /** + * Create a node of a given type with a name. The node is not attached to any graph yet. + * @param type full name of the node class. p.e. "math/sin" + * @param name a name to distinguish from other nodes + * @param options to set options + */ + createNode( + type: string, + title: string, + options: object + ): T; + + /** + * Returns a registered node type with a given name + * @param type full name of the node class. p.e. "math/sin" + */ + getNodeType(type: string): LGraphNodeConstructor; + + /** + * Returns a list of node types matching one category + * @method getNodeTypesInCategory + * @param {String} category category name + * @param {String} filter only nodes with ctor.filter equal can be shown + * @return {Array} array with all the node classes + */ + getNodeTypesInCategory( + category: string, + filter: string + ): LGraphNodeConstructor[]; + + /** + * Returns a list with all the node type categories + * @method getNodeTypesCategories + * @param {String} filter only nodes with ctor.filter equal can be shown + * @return {Array} array with all the names of the categories + */ + getNodeTypesCategories(filter: string): string[]; + + /** debug purposes: reloads all the js scripts that matches a wildcard */ + reloadNodes(folder_wildcard: string): void; + + getTime(): number; + LLink: typeof LLink; + LGraph: typeof LGraph; + DragAndScale: typeof DragAndScale; + compareObjects(a: object, b: object): boolean; + distance(a: Vector2, b: Vector2): number; + colorToString(c: string): string; + isInsideRectangle( + x: number, + y: number, + left: number, + top: number, + width: number, + height: number + ): boolean; + growBounding(bounding: Vector4, x: number, y: number): Vector4; + isInsideBounding(p: Vector2, bb: Vector4): boolean; + hex2num(hex: string): [number, number, number]; + num2hex(triplet: [number, number, number]): string; + ContextMenu: typeof ContextMenu; + extendClass(target: A, origin: B): A & B; + getParameterNames(func: string): string[]; +}; + +export type serializedLGraph< + TNode = ReturnType, + // https://github.com/jagenjo/litegraph.js/issues/74 + TLink = [number, number, number, number, number, string], + TGroup = ReturnType +> = { + last_node_id: LGraph["last_node_id"]; + last_link_id: LGraph["last_link_id"]; + nodes: TNode[]; + links: TLink[]; + groups: TGroup[]; + config: LGraph["config"]; + version: typeof LiteGraph.VERSION; +}; + +export declare class LGraph { + static supported_types: string[]; + static STATUS_STOPPED: 1; + static STATUS_RUNNING: 2; + + constructor(o?: object); + + filter: string; + catch_errors: boolean; + /** custom data */ + config: object; + elapsed_time: number; + fixedtime: number; + fixedtime_lapse: number; + globaltime: number; + inputs: any; + iteration: number; + last_link_id: number; + last_node_id: number; + last_update_time: number; + links: Record; + list_of_graphcanvas: LGraphCanvas[]; + outputs: any; + runningtime: number; + starttime: number; + status: typeof LGraph.STATUS_RUNNING | typeof LGraph.STATUS_STOPPED; + + private _nodes: LGraphNode[]; + private _groups: LGraphGroup[]; + private _nodes_by_id: Record; + /** nodes that are executable sorted in execution order */ + private _nodes_executable: + | (LGraphNode & { onExecute: NonNullable }[]) + | null; + /** nodes that contain onExecute */ + private _nodes_in_order: LGraphNode[]; + private _version: number; + + getSupportedTypes(): string[]; + /** Removes all nodes from this graph */ + clear(): void; + /** Attach Canvas to this graph */ + attachCanvas(graphCanvas: LGraphCanvas): void; + /** Detach Canvas to this graph */ + detachCanvas(graphCanvas: LGraphCanvas): void; + /** + * Starts running this graph every interval milliseconds. + * @param interval amount of milliseconds between executions, if 0 then it renders to the monitor refresh rate + */ + start(interval?: number): void; + /** Stops the execution loop of the graph */ + stop(): void; + /** + * Run N steps (cycles) of the graph + * @param num number of steps to run, default is 1 + */ + runStep(num?: number, do_not_catch_errors?: boolean): void; + /** + * Updates the graph execution order according to relevance of the nodes (nodes with only outputs have more relevance than + * nodes with only inputs. + */ + updateExecutionOrder(): void; + /** This is more internal, it computes the executable nodes in order and returns it */ + computeExecutionOrder(only_onExecute: boolean, set_level: any): T; + /** + * Returns all the nodes that could affect this one (ancestors) by crawling all the inputs recursively. + * It doesn't include the node itself + * @return an array with all the LGraphNodes that affect this node, in order of execution + */ + getAncestors(node: LGraphNode): LGraphNode[]; + /** + * Positions every node in a more readable manner + */ + arrange(margin?: number,layout?: string): void; + /** + * Returns the amount of time the graph has been running in milliseconds + * @return number of milliseconds the graph has been running + */ + getTime(): number; + + /** + * Returns the amount of time accumulated using the fixedtime_lapse var. This is used in context where the time increments should be constant + * @return number of milliseconds the graph has been running + */ + getFixedTime(): number; + + /** + * Returns the amount of time it took to compute the latest iteration. Take into account that this number could be not correct + * if the nodes are using graphical actions + * @return number of milliseconds it took the last cycle + */ + getElapsedTime(): number; + /** + * Sends an event to all the nodes, useful to trigger stuff + * @param eventName the name of the event (function to be called) + * @param params parameters in array format + */ + sendEventToAllNodes(eventName: string, params: any[], mode?: any): void; + + sendActionToCanvas(action: any, params: any[]): void; + /** + * Adds a new node instance to this graph + * @param node the instance of the node + */ + add(node: LGraphNode, skip_compute_order?: boolean): void; + /** + * Called when a new node is added + * @param node the instance of the node + */ + onNodeAdded(node: LGraphNode): void; + /** Removes a node from the graph */ + remove(node: LGraphNode): void; + /** Returns a node by its id. */ + getNodeById(id: number): LGraphNode | undefined; + /** + * Returns a list of nodes that matches a class + * @param classObject the class itself (not an string) + * @return a list with all the nodes of this type + */ + findNodesByClass( + classObject: LGraphNodeConstructor + ): T[]; + /** + * Returns a list of nodes that matches a type + * @param type the name of the node type + * @return a list with all the nodes of this type + */ + findNodesByType(type: string): T[]; + /** + * Returns the first node that matches a name in its title + * @param title the name of the node to search + * @return the node or null + */ + findNodeByTitle(title: string): T | null; + /** + * Returns a list of nodes that matches a name + * @param title the name of the node to search + * @return a list with all the nodes with this name + */ + findNodesByTitle(title: string): T[]; + /** + * Returns the top-most node in this position of the canvas + * @param x the x coordinate in canvas space + * @param y the y coordinate in canvas space + * @param nodes_list a list with all the nodes to search from, by default is all the nodes in the graph + * @return the node at this position or null + */ + getNodeOnPos( + x: number, + y: number, + node_list?: LGraphNode[], + margin?: number + ): T | null; + /** + * Returns the top-most group in that position + * @param x the x coordinate in canvas space + * @param y the y coordinate in canvas space + * @return the group or null + */ + getGroupOnPos(x: number, y: number): LGraphGroup | null; + + onAction(action: any, param: any): void; + trigger(action: any, param: any): void; + /** Tell this graph it has a global graph input of this type */ + addInput(name: string, type: string, value?: any): void; + /** Assign a data to the global graph input */ + setInputData(name: string, data: any): void; + /** Returns the current value of a global graph input */ + getInputData(name: string): T; + /** Changes the name of a global graph input */ + renameInput(old_name: string, name: string): false | undefined; + /** Changes the type of a global graph input */ + changeInputType(name: string, type: string): false | undefined; + /** Removes a global graph input */ + removeInput(name: string): boolean; + /** Creates a global graph output */ + addOutput(name: string, type: string, value: any): void; + /** Assign a data to the global output */ + setOutputData(name: string, value: string): void; + /** Returns the current value of a global graph output */ + getOutputData(name: string): T; + + /** Renames a global graph output */ + renameOutput(old_name: string, name: string): false | undefined; + /** Changes the type of a global graph output */ + changeOutputType(name: string, type: string): false | undefined; + /** Removes a global graph output */ + removeOutput(name: string): boolean; + triggerInput(name: string, value: any): void; + setCallback(name: string, func: (...args: any[]) => any): void; + beforeChange(info?: LGraphNode): void; + afterChange(info?: LGraphNode): void; + connectionChange(node: LGraphNode): void; + /** returns if the graph is in live mode */ + isLive(): boolean; + /** clears the triggered slot animation in all links (stop visual animation) */ + clearTriggeredSlots(): void; + /* Called when something visually changed (not the graph!) */ + change(): void; + setDirtyCanvas(fg: boolean, bg: boolean): void; + /** Destroys a link */ + removeLink(link_id: number): void; + /** Creates a Object containing all the info about this graph, it can be serialized */ + serialize(): T; + /** + * Configure a graph from a JSON string + * @param data configure a graph from a JSON string + * @returns if there was any error parsing + */ + configure(data: object, keep_old?: boolean): boolean | undefined; + load(url: string): void; +} + +export type SerializedLLink = [number, string, number, number, number, number]; +export declare class LLink { + id: number; + type: string; + origin_id: number; + origin_slot: number; + target_id: number; + target_slot: number; + constructor( + id: number, + type: string, + origin_id: number, + origin_slot: number, + target_id: number, + target_slot: number + ); + configure(o: LLink | SerializedLLink): void; + serialize(): SerializedLLink; +} + +export type SerializedLGraphNode = { + id: T["id"]; + type: T["type"]; + pos: T["pos"]; + size: T["size"]; + flags: T["flags"]; + mode: T["mode"]; + inputs: T["inputs"]; + outputs: T["outputs"]; + title: T["title"]; + properties: T["properties"]; + widgets_values?: IWidget["value"][]; +}; + +/** https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#lgraphnode */ +export declare class LGraphNode { + static title_color: string; + static title: string; + static type: null | string; + static widgets_up: boolean; + constructor(title?: string); + + title: string; + type: null | string; + size: Vector2; + graph: null | LGraph; + graph_version: number; + pos: Vector2; + is_selected: boolean; + mouseOver: boolean; + + id: number; + + //inputs available: array of inputs + inputs: INodeInputSlot[]; + outputs: INodeOutputSlot[]; + connections: any[]; + + //local data + properties: Record; + properties_info: any[]; + + flags: Partial<{ + collapsed: boolean + }>; + + color: string; + bgcolor: string; + boxcolor: string; + shape: + | typeof LiteGraph.BOX_SHAPE + | typeof LiteGraph.ROUND_SHAPE + | typeof LiteGraph.CIRCLE_SHAPE + | typeof LiteGraph.CARD_SHAPE + | typeof LiteGraph.ARROW_SHAPE; + + serialize_widgets: boolean; + skip_list: boolean; + + /** Used in `LGraphCanvas.onMenuNodeMode` */ + mode?: + | typeof LiteGraph.ON_EVENT + | typeof LiteGraph.ON_TRIGGER + | typeof LiteGraph.NEVER + | typeof LiteGraph.ALWAYS; + + /** If set to true widgets do not start after the slots */ + widgets_up: boolean; + /** widgets start at y distance from the top of the node */ + widgets_start_y: number; + /** if you render outside the node, it will be clipped */ + clip_area: boolean; + /** if set to false it wont be resizable with the mouse */ + resizable: boolean; + /** slots are distributed horizontally */ + horizontal: boolean; + /** if true, the node will show the bgcolor as 'red' */ + has_errors?: boolean; + + /** configure a node from an object containing the serialized info */ + configure(info: SerializedLGraphNode): void; + /** serialize the content */ + serialize(): SerializedLGraphNode; + /** Creates a clone of this node */ + clone(): this; + /** serialize and stringify */ + toString(): string; + /** get the title string */ + getTitle(): string; + /** sets the value of a property */ + setProperty(name: string, value: any): void; + /** sets the output data */ + setOutputData(slot: number, data: any): void; + /** sets the output data */ + setOutputDataType(slot: number, type: string): void; + /** + * Retrieves the input data (data traveling through the connection) from one slot + * @param slot + * @param force_update if set to true it will force the connected node of this slot to output data into this link + * @return data or if it is not connected returns undefined + */ + getInputData(slot: number, force_update?: boolean): T; + /** + * Retrieves the input data type (in case this supports multiple input types) + * @param slot + * @return datatype in string format + */ + getInputDataType(slot: number): string; + /** + * Retrieves the input data from one slot using its name instead of slot number + * @param slot_name + * @param force_update if set to true it will force the connected node of this slot to output data into this link + * @return data or if it is not connected returns null + */ + getInputDataByName(slot_name: string, force_update?: boolean): T; + /** tells you if there is a connection in one input slot */ + isInputConnected(slot: number): boolean; + /** tells you info about an input connection (which node, type, etc) */ + getInputInfo( + slot: number + ): { link: number; name: string; type: string | 0 } | null; + /** returns the node connected in the input slot */ + getInputNode(slot: number): LGraphNode | null; + /** returns the value of an input with this name, otherwise checks if there is a property with that name */ + getInputOrProperty(name: string): T; + /** tells you the last output data that went in that slot */ + getOutputData(slot: number): T | null; + /** tells you info about an output connection (which node, type, etc) */ + getOutputInfo( + slot: number + ): { name: string; type: string; links: number[] } | null; + /** tells you if there is a connection in one output slot */ + isOutputConnected(slot: number): boolean; + /** tells you if there is any connection in the output slots */ + isAnyOutputConnected(): boolean; + /** retrieves all the nodes connected to this output slot */ + getOutputNodes(slot: number): LGraphNode[]; + /** Triggers an event in this node, this will trigger any output with the same name */ + trigger(action: string, param: any): void; + /** + * Triggers an slot event in this node + * @param slot the index of the output slot + * @param param + * @param link_id in case you want to trigger and specific output link in a slot + */ + triggerSlot(slot: number, param: any, link_id?: number): void; + /** + * clears the trigger slot animation + * @param slot the index of the output slot + * @param link_id in case you want to trigger and specific output link in a slot + */ + clearTriggeredSlot(slot: number, link_id?: number): void; + /** + * add a new property to this node + * @param name + * @param default_value + * @param type string defining the output type ("vec3","number",...) + * @param extra_info this can be used to have special properties of the property (like values, etc) + */ + addProperty( + name: string, + default_value: any, + type: string, + extra_info?: object + ): T; + /** + * add a new output slot to use in this node + * @param name + * @param type string defining the output type ("vec3","number",...) + * @param extra_info this can be used to have special properties of an output (label, special color, position, etc) + */ + addOutput( + name: string, + type: string | -1, + extra_info?: Partial + ): INodeOutputSlot; + /** + * add a new output slot to use in this node + * @param array of triplets like [[name,type,extra_info],[...]] + */ + addOutputs( + array: [string, string | -1, Partial | undefined][] + ): void; + /** remove an existing output slot */ + removeOutput(slot: number): void; + /** + * add a new input slot to use in this node + * @param name + * @param type string defining the input type ("vec3","number",...), it its a generic one use 0 + * @param extra_info this can be used to have special properties of an input (label, color, position, etc) + */ + addInput( + name: string, + type: string | -1, + extra_info?: Partial + ): INodeInputSlot; + /** + * add several new input slots in this node + * @param array of triplets like [[name,type,extra_info],[...]] + */ + addInputs( + array: [string, string | -1, Partial | undefined][] + ): void; + /** remove an existing input slot */ + removeInput(slot: number): void; + /** + * add an special connection to this node (used for special kinds of graphs) + * @param name + * @param type string defining the input type ("vec3","number",...) + * @param pos position of the connection inside the node + * @param direction if is input or output + */ + addConnection( + name: string, + type: string, + pos: Vector2, + direction: string + ): { + name: string; + type: string; + pos: Vector2; + direction: string; + links: null; + }; + setValue(v: any): void; + /** computes the size of a node according to its inputs and output slots */ + computeSize(): [number, number]; + /** + * https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#node-widgets + * @return created widget + */ + addWidget( + type: T["type"], + name: string, + value: T["value"], + callback?: WidgetCallback | string, + options?: T["options"] + ): T; + + addCustomWidget(customWidget: T): T; + + /** + * returns the bounding of the object, used for rendering purposes + * @return [x, y, width, height] + */ + getBounding(): Vector4; + /** checks if a point is inside the shape of a node */ + isPointInside( + x: number, + y: number, + margin?: number, + skipTitle?: boolean + ): boolean; + /** checks if a point is inside a node slot, and returns info about which slot */ + getSlotInPosition( + x: number, + y: number + ): { + input?: INodeInputSlot; + output?: INodeOutputSlot; + slot: number; + link_pos: Vector2; + }; + /** + * returns the input slot with a given name (used for dynamic slots), -1 if not found + * @param name the name of the slot + * @return the slot (-1 if not found) + */ + findInputSlot(name: string): number; + /** + * returns the output slot with a given name (used for dynamic slots), -1 if not found + * @param name the name of the slot + * @return the slot (-1 if not found) + */ + findOutputSlot(name: string): number; + /** + * connect this node output to the input of another node + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param targetNode the target node + * @param targetSlot the input slot of the target node (could be the number of the slot or the string with the name of the slot, or -1 to connect a trigger) + * @return {Object} the link_info is created, otherwise null + */ + connect( + slot: number | string, + targetNode: LGraphNode, + targetSlot: number | string + ): T | null; + /** + * disconnect one output to an specific node + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param target_node the target node to which this slot is connected [Optional, if not target_node is specified all nodes will be disconnected] + * @return if it was disconnected successfully + */ + disconnectOutput(slot: number | string, targetNode?: LGraphNode): boolean; + /** + * disconnect one input + * @param slot (could be the number of the slot or the string with the name of the slot) + * @return if it was disconnected successfully + */ + disconnectInput(slot: number | string): boolean; + /** + * returns the center of a connection point in canvas coords + * @param is_input true if if a input slot, false if it is an output + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param out a place to store the output, to free garbage + * @return the position + **/ + getConnectionPos( + is_input: boolean, + slot: number | string, + out?: Vector2 + ): Vector2; + /** Force align to grid */ + alignToGrid(): void; + /** Console output */ + trace(msg: string): void; + /** Forces to redraw or the main canvas (LGraphNode) or the bg canvas (links) */ + setDirtyCanvas(fg: boolean, bg: boolean): void; + loadImage(url: string): void; + /** Allows to get onMouseMove and onMouseUp events even if the mouse is out of focus */ + captureInput(v: any): void; + /** Collapse the node to make it smaller on the canvas */ + collapse(force: boolean): void; + /** Forces the node to do not move or realign on Z */ + pin(v?: boolean): void; + localToScreen(x: number, y: number, graphCanvas: LGraphCanvas): Vector2; + + // https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#custom-node-appearance + onDrawBackground?( + ctx: CanvasRenderingContext2D, + canvas: HTMLCanvasElement + ): void; + onDrawForeground?( + ctx: CanvasRenderingContext2D, + canvas: HTMLCanvasElement + ): void; + + // https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#custom-node-behaviour + onMouseDown?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseMove?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseUp?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseEnter?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseLeave?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onKey?(event: KeyboardEvent, pos: Vector2, graphCanvas: LGraphCanvas): void; + + /** Called by `LGraphCanvas.selectNodes` */ + onSelected?(): void; + /** Called by `LGraphCanvas.deselectNode` */ + onDeselected?(): void; + /** Called by `LGraph.runStep` `LGraphNode.getInputData` */ + onExecute?(): void; + /** Called by `LGraph.serialize` */ + onSerialize?(o: SerializedLGraphNode): void; + /** Called by `LGraph.configure` */ + onConfigure?(o: SerializedLGraphNode): void; + /** + * when added to graph (warning: this is called BEFORE the node is configured when loading) + * Called by `LGraph.add` + */ + onAdded?(graph: LGraph): void; + /** + * when removed from graph + * Called by `LGraph.remove` `LGraph.clear` + */ + onRemoved?(): void; + /** + * if returns false the incoming connection will be canceled + * Called by `LGraph.connect` + * @param inputIndex target input slot number + * @param outputType type of output slot + * @param outputSlot output slot object + * @param outputNode node containing the output + * @param outputIndex index of output slot + */ + onConnectInput?( + inputIndex: number, + outputType: INodeOutputSlot["type"], + outputSlot: INodeOutputSlot, + outputNode: LGraphNode, + outputIndex: number + ): boolean; + /** + * if returns false the incoming connection will be canceled + * Called by `LGraph.connect` + * @param outputIndex target output slot number + * @param inputType type of input slot + * @param inputSlot input slot object + * @param inputNode node containing the input + * @param inputIndex index of input slot + */ + onConnectOutput?( + outputIndex: number, + inputType: INodeInputSlot["type"], + inputSlot: INodeInputSlot, + inputNode: LGraphNode, + inputIndex: number + ): boolean; + + /** + * Called just before connection (or disconnect - if input is linked). + * A convenient place to switch to another input, or create new one. + * This allow for ability to automatically add slots if needed + * @param inputIndex + * @return selected input slot index, can differ from parameter value + */ + onBeforeConnectInput?( + inputIndex: number + ): number; + + /** a connection changed (new one or removed) (LiteGraph.INPUT or LiteGraph.OUTPUT, slot, true if connected, link_info, input_info or output_info ) */ + onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + link: LLink, + ioSlot: (INodeOutputSlot | INodeInputSlot) + ): void; + + /** + * if returns false, will abort the `LGraphNode.setProperty` + * Called when a property is changed + * @param property + * @param value + * @param prevValue + */ + onPropertyChanged?(property: string, value: any, prevValue: any): void | boolean; + + /** Called by `LGraphCanvas.processContextMenu` */ + getMenuOptions?(graphCanvas: LGraphCanvas): ContextMenuItem[]; + getSlotMenuOptions?(slot: INodeSlot): ContextMenuItem[]; +} + +export type LGraphNodeConstructor = { + new (): T; +}; + +export type SerializedLGraphGroup = { + title: LGraphGroup["title"]; + bounding: LGraphGroup["_bounding"]; + color: LGraphGroup["color"]; + font: LGraphGroup["font"]; +}; +export declare class LGraphGroup { + title: string; + private _bounding: Vector4; + color: string; + font: string; + + configure(o: SerializedLGraphGroup): void; + serialize(): SerializedLGraphGroup; + move(deltaX: number, deltaY: number, ignoreNodes?: boolean): void; + recomputeInsideNodes(): void; + isPointInside: LGraphNode["isPointInside"]; + setDirtyCanvas: LGraphNode["setDirtyCanvas"]; +} + +export declare class DragAndScale { + constructor(element?: HTMLElement, skipEvents?: boolean); + offset: [number, number]; + scale: number; + max_scale: number; + min_scale: number; + onredraw: Function | null; + enabled: boolean; + last_mouse: Vector2; + element: HTMLElement | null; + visible_area: Vector4; + bindEvents(element: HTMLElement): void; + computeVisibleArea(): void; + onMouse(e: MouseEvent): void; + toCanvasContext(ctx: CanvasRenderingContext2D): void; + convertOffsetToCanvas(pos: Vector2): Vector2; + convertCanvasToOffset(pos: Vector2): Vector2; + mouseDrag(x: number, y: number): void; + changeScale(value: number, zooming_center?: Vector2): void; + changeDeltaScale(value: number, zooming_center?: Vector2): void; + reset(): void; +} + +/** + * This class is in charge of rendering one graph inside a canvas. And provides all the interaction required. + * Valid callbacks are: onNodeSelected, onNodeDeselected, onShowNodePanel, onNodeDblClicked + * + * @param canvas the canvas where you want to render (it accepts a selector in string format or the canvas element itself) + * @param graph + * @param options { skip_rendering, autoresize } + */ +export declare class LGraphCanvas { + static node_colors: Record< + string, + { + color: string; + bgcolor: string; + groupcolor: string; + } + >; + static link_type_colors: Record; + static gradients: object; + static search_limit: number; + + static getFileExtension(url: string): string; + static decodeHTML(str: string): string; + + static onMenuCollapseAll(): void; + static onMenuNodeEdit(): void; + static onShowPropertyEditor( + item: any, + options: any, + e: any, + menu: any, + node: any + ): void; + /** Create menu for `Add Group` */ + static onGroupAdd: ContextMenuEventListener; + /** Create menu for `Add Node` */ + static onMenuAdd: ContextMenuEventListener; + static showMenuNodeOptionalInputs: ContextMenuEventListener; + static showMenuNodeOptionalOutputs: ContextMenuEventListener; + static onShowMenuNodeProperties: ContextMenuEventListener; + static onResizeNode: ContextMenuEventListener; + static onMenuNodeCollapse: ContextMenuEventListener; + static onMenuNodePin: ContextMenuEventListener; + static onMenuNodeMode: ContextMenuEventListener; + static onMenuNodeColors: ContextMenuEventListener; + static onMenuNodeShapes: ContextMenuEventListener; + static onMenuNodeRemove: ContextMenuEventListener; + static onMenuNodeClone: ContextMenuEventListener; + + constructor( + canvas: HTMLCanvasElement | string, + graph?: LGraph, + options?: { + skip_render?: boolean; + autoresize?: boolean; + } + ); + + static active_canvas: HTMLCanvasElement; + + allow_dragcanvas: boolean; + allow_dragnodes: boolean; + /** allow to control widgets, buttons, collapse, etc */ + allow_interaction: boolean; + /** allows to change a connection with having to redo it again */ + allow_reconnect_links: boolean; + /** allow selecting multi nodes without pressing extra keys */ + multi_select: boolean; + /** No effect */ + allow_searchbox: boolean; + always_render_background: boolean; + autoresize?: boolean; + background_image: string; + bgcanvas: HTMLCanvasElement; + bgctx: CanvasRenderingContext2D; + canvas: HTMLCanvasElement; + canvas_mouse: Vector2; + clear_background: boolean; + connecting_node: LGraphNode | null; + connections_width: number; + ctx: CanvasRenderingContext2D; + current_node: LGraphNode | null; + default_connection_color: { + input_off: string; + input_on: string; + output_off: string; + output_on: string; + }; + default_link_color: string; + dirty_area: Vector4 | null; + dirty_bgcanvas?: boolean; + dirty_canvas?: boolean; + drag_mode: boolean; + dragging_canvas: boolean; + dragging_rectangle: Vector4 | null; + ds: DragAndScale; + /** used for transition */ + editor_alpha: number; + filter: any; + fps: number; + frame: number; + graph: LGraph; + highlighted_links: Record; + highquality_render: boolean; + inner_text_font: string; + is_rendering: boolean; + last_draw_time: number; + last_mouse: Vector2; + /** + * Possible duplicated with `last_mouse` + * https://github.com/jagenjo/litegraph.js/issues/70 + */ + last_mouse_position: Vector2; + /** Timestamp of last mouse click, defaults to 0 */ + last_mouseclick: number; + links_render_mode: + | typeof LiteGraph.STRAIGHT_LINK + | typeof LiteGraph.LINEAR_LINK + | typeof LiteGraph.SPLINE_LINK; + live_mode: boolean; + node_capturing_input: LGraphNode | null; + node_dragged: LGraphNode | null; + node_in_panel: LGraphNode | null; + node_over: LGraphNode | null; + node_title_color: string; + node_widget: [LGraphNode, IWidget] | null; + /** Called by `LGraphCanvas.drawBackCanvas` */ + onDrawBackground: + | ((ctx: CanvasRenderingContext2D, visibleArea: Vector4) => void) + | null; + /** Called by `LGraphCanvas.drawFrontCanvas` */ + onDrawForeground: + | ((ctx: CanvasRenderingContext2D, visibleArea: Vector4) => void) + | null; + onDrawOverlay: ((ctx: CanvasRenderingContext2D) => void) | null; + /** Called by `LGraphCanvas.processMouseDown` */ + onMouse: ((event: MouseEvent) => boolean) | null; + /** Called by `LGraphCanvas.drawFrontCanvas` and `LGraphCanvas.drawLinkTooltip` */ + onDrawLinkTooltip: ((ctx: CanvasRenderingContext2D, link: LLink, _this: this) => void) | null; + /** Called by `LGraphCanvas.selectNodes` */ + onNodeMoved: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeSelected` */ + onNodeSelected: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.deselectNode` */ + onNodeDeselected: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeDblClicked` */ + onShowNodePanel: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeDblClicked` */ + onNodeDblClicked: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.selectNodes` */ + onSelectionChange: ((nodes: Record) => void) | null; + /** Called by `LGraphCanvas.showSearchBox` */ + onSearchBox: + | (( + helper: Element, + value: string, + graphCanvas: LGraphCanvas + ) => string[]) + | null; + onSearchBoxSelection: + | ((name: string, event: MouseEvent, graphCanvas: LGraphCanvas) => void) + | null; + pause_rendering: boolean; + render_canvas_border: boolean; + render_collapsed_slots: boolean; + render_connection_arrows: boolean; + render_connections_border: boolean; + render_connections_shadows: boolean; + render_curved_connections: boolean; + render_execution_order: boolean; + render_only_selected: boolean; + render_shadows: boolean; + render_title_colored: boolean; + round_radius: number; + selected_group: null | LGraphGroup; + selected_group_resizing: boolean; + selected_nodes: Record; + show_info: boolean; + title_text_font: string; + /** set to true to render title bar with gradients */ + use_gradients: boolean; + visible_area: DragAndScale["visible_area"]; + visible_links: LLink[]; + visible_nodes: LGraphNode[]; + zoom_modify_alpha: boolean; + + /** clears all the data inside */ + clear(): void; + /** assigns a graph, you can reassign graphs to the same canvas */ + setGraph(graph: LGraph, skipClear?: boolean): void; + /** opens a graph contained inside a node in the current graph */ + openSubgraph(graph: LGraph): void; + /** closes a subgraph contained inside a node */ + closeSubgraph(): void; + /** assigns a canvas */ + setCanvas(canvas: HTMLCanvasElement, skipEvents?: boolean): void; + /** binds mouse, keyboard, touch and drag events to the canvas */ + bindEvents(): void; + /** unbinds mouse events from the canvas */ + unbindEvents(): void; + + /** + * this function allows to render the canvas using WebGL instead of Canvas2D + * this is useful if you plant to render 3D objects inside your nodes, it uses litegl.js for webgl and canvas2DtoWebGL to emulate the Canvas2D calls in webGL + **/ + enableWebGL(): void; + + /** + * marks as dirty the canvas, this way it will be rendered again + * @param fg if the foreground canvas is dirty (the one containing the nodes) + * @param bg if the background canvas is dirty (the one containing the wires) + */ + setDirty(fg: boolean, bg: boolean): void; + + /** + * Used to attach the canvas in a popup + * @return the window where the canvas is attached (the DOM root node) + */ + getCanvasWindow(): Window; + /** starts rendering the content of the canvas when needed */ + startRendering(): void; + /** stops rendering the content of the canvas (to save resources) */ + stopRendering(): void; + + processMouseDown(e: MouseEvent): boolean | undefined; + processMouseMove(e: MouseEvent): boolean | undefined; + processMouseUp(e: MouseEvent): boolean | undefined; + processMouseWheel(e: MouseEvent): boolean | undefined; + + /** returns true if a position (in graph space) is on top of a node little corner box */ + isOverNodeBox(node: LGraphNode, canvasX: number, canvasY: number): boolean; + /** returns true if a position (in graph space) is on top of a node input slot */ + isOverNodeInput( + node: LGraphNode, + canvasX: number, + canvasY: number, + slotPos: Vector2 + ): boolean; + + /** process a key event */ + processKey(e: KeyboardEvent): boolean | undefined; + + copyToClipboard(): void; + pasteFromClipboard(): void; + processDrop(e: DragEvent): void; + checkDropItem(e: DragEvent): void; + processNodeDblClicked(n: LGraphNode): void; + processNodeSelected(n: LGraphNode, e: MouseEvent): void; + processNodeDeselected(node: LGraphNode): void; + + /** selects a given node (or adds it to the current selection) */ + selectNode(node: LGraphNode, add?: boolean): void; + /** selects several nodes (or adds them to the current selection) */ + selectNodes(nodes?: LGraphNode[], add?: boolean): void; + /** removes a node from the current selection */ + deselectNode(node: LGraphNode): void; + /** removes all nodes from the current selection */ + deselectAllNodes(): void; + /** deletes all nodes in the current selection from the graph */ + deleteSelectedNodes(): void; + + /** centers the camera on a given node */ + centerOnNode(node: LGraphNode): void; + /** changes the zoom level of the graph (default is 1), you can pass also a place used to pivot the zoom */ + setZoom(value: number, center: Vector2): void; + /** brings a node to front (above all other nodes) */ + bringToFront(node: LGraphNode): void; + /** sends a node to the back (below all other nodes) */ + sendToBack(node: LGraphNode): void; + /** checks which nodes are visible (inside the camera area) */ + computeVisibleNodes(nodes: LGraphNode[]): LGraphNode[]; + /** renders the whole canvas content, by rendering in two separated canvas, one containing the background grid and the connections, and one containing the nodes) */ + draw(forceFG?: boolean, forceBG?: boolean): void; + /** draws the front canvas (the one containing all the nodes) */ + drawFrontCanvas(): void; + /** draws some useful stats in the corner of the canvas */ + renderInfo(ctx: CanvasRenderingContext2D, x: number, y: number): void; + /** draws the back canvas (the one containing the background and the connections) */ + drawBackCanvas(): void; + /** draws the given node inside the canvas */ + drawNode(node: LGraphNode, ctx: CanvasRenderingContext2D): void; + /** draws graphic for node's slot */ + drawSlotGraphic(ctx: CanvasRenderingContext2D, pos: number[], shape: SlotShape, horizontal: boolean): void; + /** draws the shape of the given node in the canvas */ + drawNodeShape( + node: LGraphNode, + ctx: CanvasRenderingContext2D, + size: [number, number], + fgColor: string, + bgColor: string, + selected: boolean, + mouseOver: boolean + ): void; + /** draws every connection visible in the canvas */ + drawConnections(ctx: CanvasRenderingContext2D): void; + /** + * draws a link between two points + * @param a start pos + * @param b end pos + * @param link the link object with all the link info + * @param skipBorder ignore the shadow of the link + * @param flow show flow animation (for events) + * @param color the color for the link + * @param startDir the direction enum + * @param endDir the direction enum + * @param numSublines number of sublines (useful to represent vec3 or rgb) + **/ + renderLink( + a: Vector2, + b: Vector2, + link: object, + skipBorder: boolean, + flow: boolean, + color?: string, + startDir?: number, + endDir?: number, + numSublines?: number + ): void; + + computeConnectionPoint( + a: Vector2, + b: Vector2, + t: number, + startDir?: number, + endDir?: number + ): void; + + drawExecutionOrder(ctx: CanvasRenderingContext2D): void; + /** draws the widgets stored inside a node */ + drawNodeWidgets( + node: LGraphNode, + posY: number, + ctx: CanvasRenderingContext2D, + activeWidget: object + ): void; + /** process an event on widgets */ + processNodeWidgets( + node: LGraphNode, + pos: Vector2, + event: Event, + activeWidget: object + ): void; + /** draws every group area in the background */ + drawGroups(canvas: any, ctx: CanvasRenderingContext2D): void; + adjustNodesSize(): void; + /** resizes the canvas to a given size, if no size is passed, then it tries to fill the parentNode */ + resize(width?: number, height?: number): void; + /** + * switches to live mode (node shapes are not rendered, only the content) + * this feature was designed when graphs where meant to create user interfaces + **/ + switchLiveMode(transition?: boolean): void; + onNodeSelectionChange(): void; + touchHandler(event: TouchEvent): void; + + showLinkMenu(link: LLink, e: any): false; + prompt( + title: string, + value: any, + callback: Function, + event: any + ): HTMLDivElement; + showSearchBox(event?: MouseEvent): void; + showEditPropertyValue(node: LGraphNode, property: any, options: any): void; + createDialog( + html: string, + options?: { position?: Vector2; event?: MouseEvent } + ): void; + + convertOffsetToCanvas: DragAndScale["convertOffsetToCanvas"]; + convertCanvasToOffset: DragAndScale["convertCanvasToOffset"]; + /** converts event coordinates from canvas2D to graph coordinates */ + convertEventToCanvasOffset(e: MouseEvent): Vector2; + /** adds some useful properties to a mouse event, like the position in graph coordinates */ + adjustMouseEvent(e: MouseEvent): void; + + getCanvasMenuOptions(): ContextMenuItem[]; + getNodeMenuOptions(node: LGraphNode): ContextMenuItem[]; + getGroupMenuOptions(): ContextMenuItem[]; + /** Called by `getCanvasMenuOptions`, replace default options */ + getMenuOptions?(): ContextMenuItem[]; + /** Called by `getCanvasMenuOptions`, append to default options */ + getExtraMenuOptions?(): ContextMenuItem[]; + /** Called when mouse right click */ + processContextMenu(node: LGraphNode, event: Event): void; +} + +declare class ContextMenu { + static trigger( + element: HTMLElement, + event_name: string, + params: any, + origin: any + ): void; + static isCursorOverElement(event: MouseEvent, element: HTMLElement): void; + static closeAllContextMenus(window: Window): void; + constructor(values: ContextMenuItem[], options?: IContextMenuOptions, window?: Window); + options: IContextMenuOptions; + parentMenu?: ContextMenu; + lock: boolean; + current_submenu?: ContextMenu; + addItem( + name: string, + value: ContextMenuItem, + options?: IContextMenuOptions + ): void; + close(e?: MouseEvent, ignore_parent_menu?: boolean): void; + getTopMenu(): void; + getFirstEvent(): void; +} + +declare global { + interface CanvasRenderingContext2D { + /** like rect but rounded corners */ + roundRect( + x: number, + y: number, + width: number, + height: number, + radius: number, + radiusLow: number + ): void; + } + + interface Math { + clamp(v: number, min: number, max: number): number; + } +} From 2dd28d4d20ee4f272db5d674bb229a2fe37dadb5 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 15 Apr 2023 21:41:21 +0100 Subject: [PATCH 034/190] style --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 940c5ecf1..1695dcaef 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -4,7 +4,7 @@ import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, importA1111 } from "./pnginfo.js"; - /** +/** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension */ From a908e12d23c820b916900f9d9ce2d5ecd507f3a2 Mon Sep 17 00:00:00 2001 From: Jake D <122334950+jwd-dev@users.noreply.github.com> Date: Sat, 15 Apr 2023 18:18:19 -0400 Subject: [PATCH 035/190] Update nodes.py with new Note node --- nodes.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nodes.py b/nodes.py index 6468ac6b8..d49c830ef 100644 --- a/nodes.py +++ b/nodes.py @@ -510,6 +510,14 @@ class EmptyLatentImage: return ({"samples":latent}, ) +class Note: + @classmethod + def INPUT_TYPES(s): + return {"required": {"text": ("STRING", {"multiline": True})}} + + CATEGORY = "other" + RETURN_TYPES = () + class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -1072,6 +1080,7 @@ NODE_CLASS_MAPPINGS = { "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, + "Note": Note, "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, "PreviewImage": PreviewImage, @@ -1138,6 +1147,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentFlip": "Flip Latent", "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", + "Note": "Note", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", # Image From 81d1f00df32e64053343e863c9c71a5d97761675 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Apr 2023 18:46:58 -0400 Subject: [PATCH 036/190] Some refactoring: from_tokens -> encode_from_tokens --- comfy/sd.py | 10 +++++----- comfy/sd1_clip.py | 6 +++--- comfy/sd2_clip.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6e54bc60b..d6d45fef6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -375,13 +375,9 @@ class CLIP: def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) - def encode(self, text, from_tokens=False): + def encode_from_tokens(self, tokens): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) - if from_tokens: - tokens = text - else: - tokens = self.tokenizer.tokenize_with_weights(text) try: self.patcher.patch_model() cond = self.cond_stage_model.encode_token_weights(tokens) @@ -391,6 +387,10 @@ class CLIP: raise e return cond + def encode(self, text): + tokens = self.tokenizer.tokenize_with_weights(text) + return self.encode_from_tokens(tokens) + class VAE: def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None): if config is None: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 97b96953a..7f1217c3d 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -315,7 +315,7 @@ class SD1Tokenizer: continue #parse word tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) - + #reshape token array to CLIP input size batched_tokens = [] batch = [(self.start_token, 1.0, 0)] @@ -338,11 +338,11 @@ class SD1Tokenizer: batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) #start new batch batch = [(self.start_token, 1.0, 0)] - batched_tokens.append(batch) + batched_tokens.append(batch) else: batch.extend([(t,w,i+1) for t,w in t_group]) t_group = [] - + #fill last batch batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index fda793eb8..32f202aea 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -1,4 +1,4 @@ -import sd1_clip +from comfy import sd1_clip import torch import os From 73c3e11e83f6fcf1a47b4965fe60b03075e1a762 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Apr 2023 18:55:17 -0400 Subject: [PATCH 037/190] Fix model_management import so it doesn't get executed twice. --- comfy/ldm/modules/attention.py | 2 +- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ldm/modules/sub_quadratic_attention.py | 2 +- comfy/samplers.py | 2 +- comfy/sd.py | 4 ++-- comfy_extras/nodes_upscale_model.py | 2 +- nodes.py | 14 +++++++------- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 92b3eca7c..c83387348 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -9,7 +9,7 @@ from typing import Optional, Any from ldm.modules.diffusionmodules.util import checkpoint from .sub_quadratic_attention import efficient_dot_product_attention -import model_management +from comfy import model_management from . import tomesd diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 788a6fc4f..1599d386e 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -7,7 +7,7 @@ from einops import rearrange from typing import Optional, Any from ldm.modules.attention import MemoryEfficientCrossAttention -import model_management +from comfy import model_management if model_management.xformers_enabled_vae(): import xformers diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index f3c83f387..573cce74f 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -24,7 +24,7 @@ except ImportError: from torch import Tensor from typing import List -import model_management +from comfy import model_management def dynamic_slice( x: Tensor, diff --git a/comfy/samplers.py b/comfy/samplers.py index 93f5d361b..ed36442a9 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -3,7 +3,7 @@ from .k_diffusion import external as k_diffusion_external from .extra_samplers import uni_pc import torch import contextlib -import model_management +from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps diff --git a/comfy/sd.py b/comfy/sd.py index d6d45fef6..9c632e240 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -4,7 +4,7 @@ import copy import sd1_clip import sd2_clip -import model_management +from comfy import model_management from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL import yaml @@ -388,7 +388,7 @@ class CLIP: return cond def encode(self, text): - tokens = self.tokenizer.tokenize_with_weights(text) + tokens = self.tokenize(text) return self.encode_from_tokens(tokens) class VAE: diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 6a7d0e516..d8754698c 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -1,6 +1,6 @@ import os from comfy_extras.chainner_models import model_loading -import model_management +from comfy import model_management import torch import comfy.utils import folder_paths diff --git a/nodes.py b/nodes.py index 6468ac6b8..e6ad9434f 100644 --- a/nodes.py +++ b/nodes.py @@ -21,16 +21,16 @@ import comfy.utils import comfy.clip_vision -import model_management +import comfy.model_management import importlib import folder_paths def before_node_execution(): - model_management.throw_exception_if_processing_interrupted() + comfy.model_management.throw_exception_if_processing_interrupted() def interrupt_processing(value=True): - model_management.interrupt_current_processing(value) + comfy.model_management.interrupt_current_processing(value) MAX_RESOLUTION=8192 @@ -241,7 +241,7 @@ class DiffusersLoader: model_path = os.path.join(search_path, model_path) break - return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: @@ -680,7 +680,7 @@ class SetLatentNoiseMask: def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): latent_image = latent["samples"] noise_mask = None - device = model_management.get_torch_device() + device = comfy.model_management.get_torch_device() if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") @@ -696,7 +696,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = noise_mask.to(device) real_model = None - model_management.load_model_gpu(model) + comfy.model_management.load_model_gpu(model) real_model = model.model noise = noise.to(device) @@ -726,7 +726,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, control_net_models = [] for x in control_nets: control_net_models += x.get_control_models() - model_management.load_controlnet_gpu(control_net_models) + comfy.model_management.load_controlnet_gpu(control_net_models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) From 6c35ea505efbeb78ffe3c6bfc6b63e68f4290561 Mon Sep 17 00:00:00 2001 From: Jake D <122334950+jwd-dev@users.noreply.github.com> Date: Sat, 15 Apr 2023 19:48:24 -0400 Subject: [PATCH 038/190] reverting changes --- nodes.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/nodes.py b/nodes.py index d49c830ef..6468ac6b8 100644 --- a/nodes.py +++ b/nodes.py @@ -510,14 +510,6 @@ class EmptyLatentImage: return ({"samples":latent}, ) -class Note: - @classmethod - def INPUT_TYPES(s): - return {"required": {"text": ("STRING", {"multiline": True})}} - - CATEGORY = "other" - RETURN_TYPES = () - class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -1080,7 +1072,6 @@ NODE_CLASS_MAPPINGS = { "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, - "Note": Note, "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, "PreviewImage": PreviewImage, @@ -1147,7 +1138,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentFlip": "Flip Latent", "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", - "Note": "Note", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", # Image From 9587ea90c82998abc73387aab594ec7217f6d50a Mon Sep 17 00:00:00 2001 From: Jake D <122334950+jwd-dev@users.noreply.github.com> Date: Sat, 15 Apr 2023 19:50:05 -0400 Subject: [PATCH 039/190] Create noteNode.js --- web/extensions/core/noteNode.js | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 web/extensions/core/noteNode.js diff --git a/web/extensions/core/noteNode.js b/web/extensions/core/noteNode.js new file mode 100644 index 000000000..12428773c --- /dev/null +++ b/web/extensions/core/noteNode.js @@ -0,0 +1,38 @@ +import {app} from "../../scripts/app.js"; +import {ComfyWidgets} from "../../scripts/widgets.js"; +// Node that add notes to your project + +app.registerExtension({ + name: "Comfy.NoteNode", + registerCustomNodes() { + class NoteNode { + color=LGraphCanvas.node_colors.yellow.color; + bgcolor=LGraphCanvas.node_colors.yellow.bgcolor; + groupcolor = LGraphCanvas.node_colors.yellow.groupcolor; + constructor() { + if (!this.properties) { + this.properties = {}; + } + + ComfyWidgets.STRING(this, "", ["", {multiline: true}], app) + // This node is purely frontend and does not impact the resulting prompt so should not be serialized + this.isVirtualNode = true; + } + + + } + + // Load default visibility + + LiteGraph.registerNodeType( + "Note", + Object.assign(NoteNode, { + title_mode: LiteGraph.NORMAL_TITLE, + title: "Note", + collapsable: true, + }) + ); + + NoteNode.category = "utils"; + }, +}); From fb61c75e392ae0a3813955d56fb5aceecacff2e4 Mon Sep 17 00:00:00 2001 From: jwd-dev Date: Sat, 15 Apr 2023 19:58:46 -0400 Subject: [PATCH 040/190] default text property incase we need one. --- web/extensions/core/noteNode.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/noteNode.js b/web/extensions/core/noteNode.js index 12428773c..1412d4373 100644 --- a/web/extensions/core/noteNode.js +++ b/web/extensions/core/noteNode.js @@ -12,9 +12,10 @@ app.registerExtension({ constructor() { if (!this.properties) { this.properties = {}; + this.properties.text=""; } - ComfyWidgets.STRING(this, "", ["", {multiline: true}], app) + ComfyWidgets.STRING(this, "", ["", {default:this.properties.text, multiline: true}], app) // This node is purely frontend and does not impact the resulting prompt so should not be serialized this.isVirtualNode = true; } From 8cd170f40daa635ad17c29fab12296cb5936df69 Mon Sep 17 00:00:00 2001 From: jwd-dev Date: Sat, 15 Apr 2023 20:16:28 -0400 Subject: [PATCH 041/190] node serialization --- web/extensions/core/noteNode.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/noteNode.js b/web/extensions/core/noteNode.js index 1412d4373..8d89054e9 100644 --- a/web/extensions/core/noteNode.js +++ b/web/extensions/core/noteNode.js @@ -16,8 +16,10 @@ app.registerExtension({ } ComfyWidgets.STRING(this, "", ["", {default:this.properties.text, multiline: true}], app) - // This node is purely frontend and does not impact the resulting prompt so should not be serialized + + this.serialize_widgets = true; this.isVirtualNode = true; + } From bc16b70bdef76d118f055c023279a4b0d4ce16a7 Mon Sep 17 00:00:00 2001 From: Karun Date: Sun, 16 Apr 2023 01:25:11 -0400 Subject: [PATCH 042/190] Adds several keybinds that interact with ComfyUI (#491) * adds keybinds that interact w/ comfy menu * adds remaining keybinds * adds keybinds to readme and converts to table * ctrl s and o save and open workflow * moves keybinds to sep file, update readme * remap load default, support keycodes * update keybinds table, prepends comfy to ids * escape exits out of modals * modifier keys also use map * adds setting for filename prompt * better handle filename prompt Co-authored-by: missionfloyd --- README.md | 30 +++++++++---- web/extensions/core/keybinds.js | 76 +++++++++++++++++++++++++++++++++ web/scripts/app.js | 6 --- web/scripts/ui.js | 32 +++++++++++--- 4 files changed, 124 insertions(+), 20 deletions(-) create mode 100644 web/extensions/core/keybinds.js diff --git a/README.md b/README.md index 77d979ac3..f610f9497 100644 --- a/README.md +++ b/README.md @@ -32,14 +32,28 @@ This ui will let you design and execute advanced stable diffusion pipelines usin Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) ## Shortcuts -- **Ctrl + A** select all nodes -- **Ctrl + M** mute/unmute selected nodes -- **Delete** or **Backspace** delete selected nodes -- **Space** Holding space key while moving the cursor moves the canvas around. It works when holding the mouse button down so it is easier to connect different nodes when the canvas gets too large. -- **Ctrl/Shift + Click** Add clicked node to selection. -- **Ctrl + C/Ctrl + V** - Copy and paste selected nodes, without maintaining the connection to the outputs of unselected nodes. -- **Ctrl + C/Ctrl + Shift + V** - Copy and paste selected nodes, and maintaining the connection from the outputs of unselected nodes to the inputs of the newly pasted nodes. -- Holding **Shift** and drag selected nodes - Move multiple selected nodes at the same time. + +| Keybind | Explanation | +| - | - | +| Ctrl + Enter | Queue up current graph for generation | +| Ctrl + Shift + Enter | Queue up current graph as first for generation | +| Ctrl + S | Save workflow | +| Ctrl + O | Load workflow | +| Ctrl + A | Select all nodes | +| Ctrl + M | Mute/unmute selected nodes | +| Delete/Backspace | Delete selected nodes | +| Ctrl + Delete/Backspace | Delete the current graph | +| Space | Move the canvas around when held and moving the cursor | +| Ctrl/Shift + Click | Add clicked node to selection | +| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) | +| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) | +| Shift + Drag | Move multiple selected nodes at the same time | +| Ctrl + D | Load default graph | +| Q | Toggle visibility of the queue | +| H | Toggle visibility of history | +| R | Refresh graph | + +Ctrl can also be replaced with Cmd instead for MacOS users # Installing diff --git a/web/extensions/core/keybinds.js b/web/extensions/core/keybinds.js new file mode 100644 index 000000000..1825007a6 --- /dev/null +++ b/web/extensions/core/keybinds.js @@ -0,0 +1,76 @@ +import { app } from "/scripts/app.js"; + +const id = "Comfy.Keybinds"; +app.registerExtension({ + name: id, + init() { + const keybindListener = function(event) { + const target = event.composedPath()[0]; + + if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") { + return; + } + + const modifierPressed = event.ctrlKey || event.metaKey; + + // Queue prompt using ctrl or command + enter + if (modifierPressed && (event.key === "Enter" || event.keyCode === 13 || event.keyCode === 10)) { + app.queuePrompt(event.shiftKey ? -1 : 0); + return; + } + + const modifierKeyIdMap = { + "s": "#comfy-save-button", + 83: "#comfy-save-button", + "o": "#comfy-file-input", + 79: "#comfy-file-input", + "Backspace": "#comfy-clear-button", + 8: "#comfy-clear-button", + "Delete": "#comfy-clear-button", + 46: "#comfy-clear-button", + "d": "#comfy-load-default-button", + 68: "#comfy-load-default-button", + }; + + const modifierKeybindId = modifierKeyIdMap[event.key] || modifierKeyIdMap[event.keyCode]; + if (modifierPressed && modifierKeybindId) { + event.preventDefault(); + + const elem = document.querySelector(modifierKeybindId); + elem.click(); + return; + } + + // Finished Handling all modifier keybinds, now handle the rest + if (event.ctrlKey || event.altKey || event.metaKey) { + return; + } + + // Close out of modals using escape + if (event.key === "Escape" || event.keyCode === 27) { + const modals = document.querySelectorAll(".comfy-modal"); + const modal = Array.from(modals).find(modal => window.getComputedStyle(modal).getPropertyValue("display") !== "none"); + if (modal) { + modal.style.display = "none"; + } + } + + const keyIdMap = { + "q": "#comfy-view-queue-button", + 81: "#comfy-view-queue-button", + "h": "#comfy-view-history-button", + 72: "#comfy-view-history-button", + "r": "#comfy-refresh-button", + 82: "#comfy-refresh-button", + }; + + const buttonId = keyIdMap[event.key] || keyIdMap[event.keyCode]; + if (buttonId) { + const button = document.querySelector(buttonId); + button.click(); + } + } + + window.addEventListener("keydown", keybindListener, true); + } +}); diff --git a/web/scripts/app.js b/web/scripts/app.js index 1695dcaef..f158f3457 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -35,7 +35,6 @@ export class ComfyApp { */ this.nodeOutputs = {}; - /** * If the shift key on the keyboard is pressed * @type {boolean} @@ -713,11 +712,6 @@ export class ComfyApp { #addKeyboardHandler() { window.addEventListener("keydown", (e) => { this.shiftDown = e.shiftKey; - - // Queue prompt using ctrl or command + enter - if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) { - this.queuePrompt(e.shiftKey ? -1 : 0); - } }); window.addEventListener("keyup", (e) => { this.shiftDown = e.shiftKey; diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 09861c440..f320f8401 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -431,7 +431,15 @@ export class ComfyUI { defaultValue: true, }); + const promptFilename = this.settings.addSetting({ + id: "Comfy.PromptFilename", + name: "Prompt for filename when saving workflow", + type: "boolean", + defaultValue: true, + }); + const fileInput = $el("input", { + id: "comfy-file-input", type: "file", accept: ".json,image/png", style: { display: "none" }, @@ -448,6 +456,7 @@ export class ComfyUI { $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), ]), $el("button.comfy-queue-btn", { + id: "queue-button", textContent: "Queue Prompt", onclick: () => app.queuePrompt(0, this.batchCount), }), @@ -496,9 +505,10 @@ export class ComfyUI { ]), ]), $el("div.comfy-menu-btns", [ - $el("button", { textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }), + $el("button", { id: "queue-front-button", textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }), $el("button", { $: (b) => (this.queue.button = b), + id: "comfy-view-queue-button", textContent: "View Queue", onclick: () => { this.history.hide(); @@ -507,6 +517,7 @@ export class ComfyUI { }), $el("button", { $: (b) => (this.history.button = b), + id: "comfy-view-history-button", textContent: "View History", onclick: () => { this.queue.hide(); @@ -517,14 +528,23 @@ export class ComfyUI { this.queue.element, this.history.element, $el("button", { + id: "comfy-save-button", textContent: "Save", onclick: () => { + let filename = "workflow.json"; + if (promptFilename.value) { + filename = prompt("Save workflow as:", filename); + if (!filename) return; + if (!filename.toLowerCase().endsWith(".json")) { + filename += ".json"; + } + } const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string const blob = new Blob([json], { type: "application/json" }); const url = URL.createObjectURL(blob); const a = $el("a", { href: url, - download: "workflow.json", + download: filename, style: { display: "none" }, parent: document.body, }); @@ -535,15 +555,15 @@ export class ComfyUI { }, 0); }, }), - $el("button", { textContent: "Load", onclick: () => fileInput.click() }), - $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), - $el("button", { textContent: "Clear", onclick: () => { + $el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }), + $el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), + $el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => { if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); app.graph.clear(); } }}), - $el("button", { textContent: "Load Default", onclick: () => { + $el("button", { id: "comfy-load-default-button", textContent: "Load Default", onclick: () => { if (!confirmClear.value || confirm("Load default workflow?")) { app.loadGraphData() } From 74fc7b772656a59b344508480632d9d45f9127de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 16 Apr 2023 01:36:15 -0400 Subject: [PATCH 043/190] custom_nodes paths can now be set in the extra_model_paths.yaml --- extra_model_paths.yaml.example | 2 +- folder_paths.py | 7 +++++-- main.py | 15 ++++++++------- nodes.py | 17 +++++++++-------- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index af784fd69..f421f54dc 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -18,6 +18,6 @@ a111: #other_ui: # base_path: path/to/ui # checkpoints: models/checkpoints - +# custom_nodes: path/custom_nodes diff --git a/folder_paths.py b/folder_paths.py index ab3359347..61f446c96 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -12,8 +12,8 @@ except: folder_names_and_paths = {} - -models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") +base_path = os.path.dirname(os.path.realpath(__file__)) +models_dir = os.path.join(base_path, "models") folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions) folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) @@ -28,6 +28,9 @@ folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) + + output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") diff --git a/main.py b/main.py index 9c0a3d8a1..02c700ebc 100644 --- a/main.py +++ b/main.py @@ -81,6 +81,14 @@ if __name__ == "__main__": server = server.PromptServer(loop) q = execution.PromptQueue(server) + extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") + if os.path.isfile(extra_model_paths_config_path): + load_extra_path_config(extra_model_paths_config_path) + + if args.extra_model_paths_config: + for config_path in itertools.chain(*args.extra_model_paths_config): + load_extra_path_config(config_path) + init_custom_nodes() server.add_routes() hijack_progress(server) @@ -91,13 +99,6 @@ if __name__ == "__main__": dont_print = args.dont_print_server - extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") - if os.path.isfile(extra_model_paths_config_path): - load_extra_path_config(extra_model_paths_config_path) - - if args.extra_model_paths_config: - for config_path in itertools.chain(*args.extra_model_paths_config): - load_extra_path_config(config_path) if args.output_directory: output_dir = os.path.abspath(args.output_directory) diff --git a/nodes.py b/nodes.py index e6ad9434f..c775da00c 100644 --- a/nodes.py +++ b/nodes.py @@ -1178,15 +1178,16 @@ def load_custom_node(module_path): print(f"Cannot import {module_path} module for custom nodes:", e) def load_custom_nodes(): - CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes") - possible_modules = os.listdir(CUSTOM_NODE_PATH) - if "__pycache__" in possible_modules: - possible_modules.remove("__pycache__") + node_paths = folder_paths.get_folder_paths("custom_nodes") + for custom_node_path in node_paths: + possible_modules = os.listdir(custom_node_path) + if "__pycache__" in possible_modules: + possible_modules.remove("__pycache__") - for possible_module in possible_modules: - module_path = os.path.join(CUSTOM_NODE_PATH, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - load_custom_node(module_path) + for possible_module in possible_modules: + module_path = os.path.join(custom_node_path, possible_module) + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + load_custom_node(module_path) def init_custom_nodes(): load_custom_nodes() From 22bde7957e18e8f9c4fb206227a6117dae391417 Mon Sep 17 00:00:00 2001 From: Tomoaki Hayasaka Date: Mon, 17 Apr 2023 01:58:33 +0900 Subject: [PATCH 044/190] Fix "Ctrl+Enter doesn't work when textarea has focus" regression introduced in #491. --- web/extensions/core/keybinds.js | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/web/extensions/core/keybinds.js b/web/extensions/core/keybinds.js index 1825007a6..42c228017 100644 --- a/web/extensions/core/keybinds.js +++ b/web/extensions/core/keybinds.js @@ -5,12 +5,6 @@ app.registerExtension({ name: id, init() { const keybindListener = function(event) { - const target = event.composedPath()[0]; - - if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") { - return; - } - const modifierPressed = event.ctrlKey || event.metaKey; // Queue prompt using ctrl or command + enter @@ -19,6 +13,12 @@ app.registerExtension({ return; } + const target = event.composedPath()[0]; + + if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") { + return; + } + const modifierKeyIdMap = { "s": "#comfy-save-button", 83: "#comfy-save-button", From 0ab5c619eafa026d4be1a3f6bf462a6f7f9d25d6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 01:04:54 -0400 Subject: [PATCH 045/190] Clarify in README that it's AMD GPUs. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f610f9497..be2cb8ec5 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ Put your VAE in: models/vae At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10. -### AMD (Linux only) +### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2``` From 884ea653c8d6fe19b3724f45a04a0d74cd881f2f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 11:05:15 -0400 Subject: [PATCH 046/190] Add a way for nodes to set a custom CFG function. --- comfy/samplers.py | 5 ++++- comfy/sd.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index ed36442a9..05af6fe88 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -211,7 +211,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) - return uncond + (cond - uncond) * cond_scale + if "sampler_cfg_function" in model_options: + return model_options["sampler_cfg_function"](cond, uncond, cond_scale) + else: + return uncond + (cond - uncond) * cond_scale class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): diff --git a/comfy/sd.py b/comfy/sd.py index 9c632e240..1d7774742 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -250,6 +250,9 @@ class ModelPatcher: def set_model_tomesd(self, ratio): self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} + def set_model_sampler_cfg_function(self, sampler_cfg_function): + self.model_options["sampler_cfg_function"] = sampler_cfg_function + def model_dtype(self): return self.model.diffusion_model.dtype From 6f7852bc47de2fa432672a1b93c1727c0824d78b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 17:24:58 -0400 Subject: [PATCH 047/190] Add a LatentFromBatch node to pick a single latent from a batch. Works before and after sampling. --- nodes.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index c775da00c..c745ce280 100644 --- a/nodes.py +++ b/nodes.py @@ -510,6 +510,24 @@ class EmptyLatentImage: return ({"samples":latent}, ) +class LatentFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "rotate" + + CATEGORY = "latent" + + def rotate(self, samples, batch_index): + s = samples.copy() + s_in = samples["samples"] + batch_index = min(s_in.shape[0] - 1, batch_index) + s["samples"] = s_in[batch_index:batch_index + 1].clone() + s["batch_index"] = batch_index + return (s,) class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -685,7 +703,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") + batch_index = 0 + if "batch_index" in latent: + batch_index = latent["batch_index"] + + generator = torch.manual_seed(seed) + for i in range(batch_index + 1): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") if "noise_mask" in latent: noise_mask = latent['noise_mask'] @@ -1073,6 +1097,7 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, + "LatentFromBatch": LatentFromBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, From 7b5eb196dbf4248eb6c67af2843cacb28863ce2f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 17 Apr 2023 17:29:22 -0400 Subject: [PATCH 048/190] allows control arrow to edit attention in textarea --- web/extensions/core/editAttention.js | 117 +++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 web/extensions/core/editAttention.js diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js new file mode 100644 index 000000000..d943290ce --- /dev/null +++ b/web/extensions/core/editAttention.js @@ -0,0 +1,117 @@ +import { app } from "/scripts/app.js"; + +// Allows you to edit the attention weight by holding ctrl (or cmd) and using the up/down arrow keys + +const id = "Comfy.EditAttention"; +app.registerExtension({ +name:id, + init() { + function incrementWeight(weight, delta) { + const floatWeight = parseFloat(weight); + if (isNaN(floatWeight)) return weight; + const newWeight = floatWeight + delta; + if (newWeight < 0) return "0"; + return String(Number(newWeight.toFixed(10))); + } + + function findNearestEnclosure(text, cursorPos) { + let start = cursorPos, end = cursorPos; + let openCount = 0, closeCount = 0; + + // Find opening parenthesis before cursor + while (start >= 0) { + start--; + if (text[start] === "(" && openCount === closeCount) break; + if (text[start] === "(") openCount++; + if (text[start] === ")") closeCount++; + } + if (start < 0) return false; + + openCount = 0; + closeCount = 0; + + // Find closing parenthesis after cursor + while (end < text.length) { + if (text[end] === ")" && openCount === closeCount) break; + if (text[end] === "(") openCount++; + if (text[end] === ")") closeCount++; + end++; + } + if (end === text.length) return false; + + return { start: start + 1, end: end }; + } + + function addWeightToParentheses(text) { + const parenRegex = /^\((.*)\)$/; + const parenMatch = text.match(parenRegex); + + const floatRegex = /:([+-]?(\d*\.)?\d+([eE][+-]?\d+)?)/; + const floatMatch = text.match(floatRegex); + + if (parenMatch && !floatMatch) { + return `(${parenMatch[1]}:1.0)`; + } else { + return text; + } + }; + + function editAttention(event) { + const inputField = event.composedPath()[0]; + const delta = 0.1; + + if (inputField.tagName !== "TEXTAREA") return; + if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return; + if (!event.ctrlKey && !event.metaKey) return; + + event.preventDefault(); + + let start = inputField.selectionStart; + let end = inputField.selectionEnd; + let selectedText = inputField.value.substring(start, end); + + // If there is no selection, attempt to find the nearest enclosure + if (!selectedText) { + const nearestEnclosure = findNearestEnclosure(inputField.value, start); + if (nearestEnclosure) { + start = nearestEnclosure.start; + end = nearestEnclosure.end; + selectedText = inputField.value.substring(start, end); + } else { + return; + } + } + + // If the selection ends with a space, remove it + if (selectedText[selectedText.length - 1] === " ") { + selectedText = selectedText.substring(0, selectedText.length - 1); + end -= 1; + } + + // If there are parentheses left and right of the selection, select them + if (inputField.value[start - 1] === "(" && inputField.value[end] === ")") { + start -= 1; + end += 1; + selectedText = inputField.value.substring(start, end); + } + + // If the selection is not enclosed in parentheses, add them + if (selectedText[0] !== "(" || selectedText[selectedText.length - 1] !== ")") { + console.log("adding parentheses", inputField.value[start], inputField.value[end], selectedText); + selectedText = `(${selectedText})`; + } + + // If the selection does not have a weight, add a weight of 1.0 + selectedText = addWeightToParentheses(selectedText); + + // Increment the weight + const weightDelta = event.key === "ArrowUp" ? delta : -delta; + const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => { + return prefix + incrementWeight(weight, weightDelta) + suffix; + }); + + inputField.setRangeText(updatedText, start, end, "select"); + } + window.addEventListener("keydown", editAttention); + }, +}); From f03dade5ab8f17a165d63efc205eb34a2330b7d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 18:19:57 -0400 Subject: [PATCH 049/190] Fix bug. --- nodes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index c745ce280..06b69f453 100644 --- a/nodes.py +++ b/nodes.py @@ -708,8 +708,9 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, batch_index = latent["batch_index"] generator = torch.manual_seed(seed) - for i in range(batch_index + 1): + for i in range(batch_index): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") if "noise_mask" in latent: noise_mask = latent['noise_mask'] From b8c636b10d39e77742f3f435bf6b85c3aa806583 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 18:21:24 -0400 Subject: [PATCH 050/190] Lower how much CTRL+arrow key changes the number. --- web/extensions/core/editAttention.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index d943290ce..fe395c3ca 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -58,7 +58,7 @@ name:id, function editAttention(event) { const inputField = event.composedPath()[0]; - const delta = 0.1; + const delta = 0.025; if (inputField.tagName !== "TEXTAREA") return; if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return; From 79ba0399d8d70bc655269fc3318455a70d14e180 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 17 Apr 2023 19:02:08 -0400 Subject: [PATCH 051/190] selects current word automatically --- web/extensions/core/editAttention.js | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index fe395c3ca..55201953b 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -70,7 +70,7 @@ name:id, let end = inputField.selectionEnd; let selectedText = inputField.value.substring(start, end); - // If there is no selection, attempt to find the nearest enclosure + // If there is no selection, attempt to find the nearest enclosure, or select the current word if (!selectedText) { const nearestEnclosure = findNearestEnclosure(inputField.value, start); if (nearestEnclosure) { @@ -78,7 +78,18 @@ name:id, end = nearestEnclosure.end; selectedText = inputField.value.substring(start, end); } else { - return; + // Select the current word, find the start and end of the word (first space before and after) + start = inputField.value.substring(0, start).lastIndexOf(" ") + 1; + end = inputField.value.substring(end).indexOf(" ") + end; + // Remove all punctuation at the end and beginning of the word + while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { + start++; + } + while (inputField.value[end - 1].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { + end--; + } + selectedText = inputField.value.substring(start, end); + if (!selectedText) return; } } @@ -97,7 +108,6 @@ name:id, // If the selection is not enclosed in parentheses, add them if (selectedText[0] !== "(" || selectedText[selectedText.length - 1] !== ")") { - console.log("adding parentheses", inputField.value[start], inputField.value[end], selectedText); selectedText = `(${selectedText})`; } From a962222992479057b104cdd06bf399d2a2cae2fa Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 17 Apr 2023 23:40:44 -0400 Subject: [PATCH 052/190] correctly checks end of the text --- web/extensions/core/editAttention.js | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index 55201953b..206d0830a 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -79,8 +79,16 @@ name:id, selectedText = inputField.value.substring(start, end); } else { // Select the current word, find the start and end of the word (first space before and after) - start = inputField.value.substring(0, start).lastIndexOf(" ") + 1; - end = inputField.value.substring(end).indexOf(" ") + end; + const wordStart = inputField.value.substring(0, start).lastIndexOf(" ") + 1; + const wordEnd = inputField.value.substring(end).indexOf(" "); + // If there is no space after the word, select to the end of the string + if (wordEnd === -1) { + end = inputField.value.length; + } else { + end += wordEnd; + } + start = wordStart; + // Remove all punctuation at the end and beginning of the word while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { start++; From a7c7da68dc8a5e6bf1e316b6b36c4a61c7571445 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 18 Apr 2023 00:22:05 -0600 Subject: [PATCH 053/190] Editattention setting (#533) * Add editAttention delta setting * Update editAttention.js * Update web/extensions/core/editAttention.js Co-authored-by: Karun * Update editAttention.js * Update editAttention.js * Fix setting value --------- Co-authored-by: Karun --- web/extensions/core/editAttention.js | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index 206d0830a..66d4a8373 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -2,10 +2,21 @@ import { app } from "/scripts/app.js"; // Allows you to edit the attention weight by holding ctrl (or cmd) and using the up/down arrow keys -const id = "Comfy.EditAttention"; app.registerExtension({ -name:id, + name: "Comfy.EditAttention", init() { + const editAttentionDelta = app.ui.settings.addSetting({ + id: "Comfy.EditAttention.Delta", + name: "Ctrl+up/down precision", + type: "slider", + attrs: { + min: 0.01, + max: 2, + step: 0.01, + }, + defaultValue: 0.1, + }); + function incrementWeight(weight, delta) { const floatWeight = parseFloat(weight); if (isNaN(floatWeight)) return weight; @@ -58,7 +69,7 @@ name:id, function editAttention(event) { const inputField = event.composedPath()[0]; - const delta = 0.025; + const delta = parseFloat(editAttentionDelta.value); if (inputField.tagName !== "TEXTAREA") return; if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return; @@ -125,7 +136,7 @@ name:id, // Increment the weight const weightDelta = event.key === "ArrowUp" ? delta : -delta; const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => { - return prefix + incrementWeight(weight, weightDelta) + suffix; + return prefix + incrementWeight(weight, weightDelta) + suffix; }); inputField.setRangeText(updatedText, start, end, "select"); From b016e2769f0a16fcba21c020023413cad68f704b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 18 Apr 2023 02:25:57 -0400 Subject: [PATCH 054/190] Saner range of values. --- web/extensions/core/editAttention.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index 66d4a8373..bebc80b12 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -11,10 +11,10 @@ app.registerExtension({ type: "slider", attrs: { min: 0.01, - max: 2, + max: 0.5, step: 0.01, }, - defaultValue: 0.1, + defaultValue: 0.05, }); function incrementWeight(weight, delta) { From 472b1cc0d881c4009e5a89e0893c5835f3a4c47d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 18 Apr 2023 19:34:07 -0400 Subject: [PATCH 055/190] Add a github action to use pip xformers package for dependencies. --- .../windows_release_cu118_dependencies_2.yml | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/workflows/windows_release_cu118_dependencies_2.yml diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml new file mode 100644 index 000000000..a88449527 --- /dev/null +++ b/.github/workflows/windows_release_cu118_dependencies_2.yml @@ -0,0 +1,30 @@ +name: "Windows Release cu118 dependencies 2" + +on: + workflow_dispatch: +# push: +# branches: +# - master + +jobs: + build_dependencies: + runs-on: windows-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10.9' + + - shell: bash + run: | + python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip install --no-cache-dir ./temp_wheel_dir/* + echo installed basic + ls -lah temp_wheel_dir + mv temp_wheel_dir cu118_python_deps + tar cf cu118_python_deps.tar cu118_python_deps + + - uses: actions/cache/save@v3 + with: + path: cu118_python_deps.tar + key: ${{ runner.os }}-build-cu118 From 3696d1699a6fece2485c063317cf65abbcddb79b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 09:36:19 -0400 Subject: [PATCH 056/190] Add support for GLIGEN textbox model. --- comfy/gligen.py | 343 ++++++++++++++++++ comfy/ldm/modules/attention.py | 16 + .../modules/diffusionmodules/openaimodel.py | 2 + comfy/model_management.py | 6 +- comfy/samplers.py | 57 ++- comfy/sd.py | 22 +- folder_paths.py | 2 + models/gligen/put_gligen_models_here | 0 nodes.py | 71 +++- 9 files changed, 491 insertions(+), 28 deletions(-) create mode 100644 comfy/gligen.py create mode 100644 models/gligen/put_gligen_models_here diff --git a/comfy/gligen.py b/comfy/gligen.py new file mode 100644 index 000000000..8770383e5 --- /dev/null +++ b/comfy/gligen.py @@ -0,0 +1,343 @@ +import torch +from torch import nn, einsum +from ldm.modules.attention import CrossAttention +from inspect import isfunction + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * torch.nn.functional.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class GatedCrossAttentionDense(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + self.attn = CrossAttention( + query_dim=query_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + x = x + self.scale * \ + torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs) + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) + + return x + + +class GatedSelfAttentionDense(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj + # feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = CrossAttention( + query_dim=query_dim, + context_dim=query_dim, + heads=n_heads, + dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + N_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn( + self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :] + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) + + return x + + +class GatedSelfAttentionDense2(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj + # feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = CrossAttention( + query_dim=query_dim, context_dim=query_dim, dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + B, N_visual, _ = x.shape + B, N_ground, _ = objs.shape + + objs = self.linear(objs) + + # sanity check + size_v = math.sqrt(N_visual) + size_g = math.sqrt(N_ground) + assert int(size_v) == size_v, "Visual tokens must be square rootable" + assert int(size_g) == size_g, "Grounding tokens must be square rootable" + size_v = int(size_v) + size_g = int(size_g) + + # select grounding token and resize it to visual token size as residual + out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[ + :, N_visual:, :] + out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g) + out = torch.nn.functional.interpolate( + out, (size_v, size_v), mode='bicubic') + residual = out.reshape(B, -1, N_visual).permute(0, 2, 1) + + # add residual to visual feature + x = x + self.scale * torch.tanh(self.alpha_attn) * residual + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) + + return x + + +class FourierEmbedder(): + def __init__(self, num_freqs=64, temperature=100): + + self.num_freqs = num_freqs + self.temperature = temperature + self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) + + @torch.no_grad() + def __call__(self, x, cat_dim=-1): + "x: arbitrary shape of tensor. dim: cat dim" + out = [] + for freq in self.freq_bands: + out.append(torch.sin(freq * x)) + out.append(torch.cos(freq * x)) + return torch.cat(out, cat_dim) + + +class PositionNet(nn.Module): + def __init__(self, in_dim, out_dim, fourier_freqs=8): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy + + self.linears = nn.Sequential( + nn.Linear(self.in_dim + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_positive_feature = torch.nn.Parameter( + torch.zeros([self.in_dim])) + self.null_position_feature = torch.nn.Parameter( + torch.zeros([self.position_dim])) + + def forward(self, boxes, masks, positive_embeddings): + B, N, _ = boxes.shape + masks = masks.unsqueeze(-1) + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C + + # learnable null embedding + positive_null = self.null_positive_feature.view(1, 1, -1) + xyxy_null = self.null_position_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + positive_embeddings = positive_embeddings * \ + masks + (1 - masks) * positive_null + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + objs = self.linears( + torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + assert objs.shape == torch.Size([B, N, self.out_dim]) + return objs + + +class Gligen(nn.Module): + def __init__(self, modules, position_net, key_dim): + super().__init__() + self.module_list = nn.ModuleList(modules) + self.position_net = position_net + self.key_dim = key_dim + self.max_objs = 30 + + def _set_position(self, boxes, masks, positive_embeddings): + objs = self.position_net(boxes, masks, positive_embeddings) + + def func(key, x): + module = self.module_list[key] + return module(x, objs) + return func + + def set_position(self, latent_image_shape, position_params, device): + batch, c, h, w = latent_image_shape + masks = torch.zeros([self.max_objs], device="cpu") + boxes = [] + positive_embeddings = [] + for p in position_params: + x1 = (p[4]) / w + y1 = (p[3]) / h + x2 = (p[4] + p[2]) / w + y2 = (p[3] + p[1]) / h + masks[len(boxes)] = 1.0 + boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)] + positive_embeddings += [p[0]] + append_boxes = [] + append_conds = [] + if len(boxes) < self.max_objs: + append_boxes = [torch.zeros( + [self.max_objs - len(boxes), 4], device="cpu")] + append_conds = [torch.zeros( + [self.max_objs - len(boxes), self.key_dim], device="cpu")] + + box_out = torch.cat( + boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1) + masks = masks.unsqueeze(0).repeat(batch, 1) + conds = torch.cat(positive_embeddings + + append_conds).unsqueeze(0).repeat(batch, 1, 1) + return self._set_position( + box_out.to(device), + masks.to(device), + conds.to(device)) + + def set_empty(self, latent_image_shape, device): + batch, c, h, w = latent_image_shape + masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1) + box_out = torch.zeros([self.max_objs, 4], + device="cpu").repeat(batch, 1, 1) + conds = torch.zeros([self.max_objs, self.key_dim], + device="cpu").repeat(batch, 1, 1) + return self._set_position( + box_out.to(device), + masks.to(device), + conds.to(device)) + + def cleanup(self): + pass + + def get_models(self): + return [self] + +def load_gligen(sd): + sd_k = sd.keys() + output_list = [] + key_dim = 768 + for a in ["input_blocks", "middle_block", "output_blocks"]: + for b in range(20): + k_temp = filter(lambda k: "{}.{}.".format(a, b) + in k and ".fuser." in k, sd_k) + k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp) + + n_sd = {} + for k in k_temp: + n_sd[k[1]] = sd[k[0]] + if len(n_sd) > 0: + query_dim = n_sd["linear.weight"].shape[0] + key_dim = n_sd["linear.weight"].shape[1] + + if key_dim == 768: # SD1.x + n_heads = 8 + d_head = query_dim // n_heads + else: + d_head = 64 + n_heads = query_dim // d_head + + gated = GatedSelfAttentionDense( + query_dim, key_dim, n_heads, d_head) + gated.load_state_dict(n_sd, strict=False) + output_list.append(gated) + + if "position_net.null_positive_feature" in sd_k: + in_dim = sd["position_net.null_positive_feature"].shape[0] + out_dim = sd["position_net.linears.4.weight"].shape[0] + + class WeightsLoader(torch.nn.Module): + pass + w = WeightsLoader() + w.position_net = PositionNet(in_dim, out_dim) + w.load_state_dict(sd, strict=False) + + gligen = Gligen(output_list, w.position_net, key_dim) + return gligen diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index c83387348..98dbda635 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) def _forward(self, x, context=None, transformer_options={}): + current_index = None + if "current_index" in transformer_options: + current_index = transformer_options["current_index"] + if "patches" in transformer_options: + transformer_patches = transformer_options["patches"] + else: + transformer_patches = {} + n = self.norm1(x) if "tomesd" in transformer_options: m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) @@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module): n = self.attn1(n, context=context if self.disable_self_attn else None) x += n + if "middle_patch" in transformer_patches: + patch = transformer_patches["middle_patch"] + for p in patch: + x = p(current_index, x) + n = self.norm2(x) n = self.attn2(n, context=context) x += n x = self.ff(self.norm3(x)) + x + + if current_index is not None: + transformer_options["current_index"] += 1 return x diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 8a4e8b3e1..4c69c8567 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -782,6 +782,8 @@ class UNetModel(nn.Module): :return: an [N x C x ...] Tensor of outputs. """ transformer_options["original_shape"] = list(x.shape) + transformer_options["current_index"] = 0 + assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" diff --git a/comfy/model_management.py b/comfy/model_management.py index 76455e4a2..a0d1313d2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -176,7 +176,7 @@ def load_model_gpu(model): model_accelerated = True return current_loaded_model -def load_controlnet_gpu(models): +def load_controlnet_gpu(control_models): global current_gpu_controlnets global vram_state if vram_state == VRAMState.CPU: @@ -186,6 +186,10 @@ def load_controlnet_gpu(models): #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return + models = [] + for m in control_models: + models += m.get_models() + for m in current_gpu_controlnets: if m not in models: m.cpu() diff --git a/comfy/samplers.py b/comfy/samplers.py index 05af6fe88..31968e185 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con control = None if 'control' in cond[1]: control = cond[1]['control'] - return (input_x, mult, conditionning, area, control) + + patches = None + if 'gligen' in cond[1]: + gligen = cond[1]['gligen'] + patches = {} + gligen_type = gligen[0] + gligen_model = gligen[1] + if gligen_type == "position": + gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device) + else: + gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device) + + patches['middle_patch'] = [gligen_patch] + + return (input_x, mult, conditionning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con def can_concat_cond(c1, c2): if c1[0].shape != c2[0].shape: return False + + #control if (c1[4] is None) != (c2[4] is None): return False if c1[4] is not None: if c1[4] is not c2[4]: return False + #patches + if (c1[5] is None) != (c2[5] is None): + return False + if (c1[5] is not None): + if c1[5] is not c2[5]: + return False + return cond_equal_size(c1[2], c2[2]) def cond_cat(c_list): @@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con cond_or_uncond = [] area = [] control = None + patches = None for x in to_batch: o = to_run.pop(x) p = o[0] @@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con area += [p[3]] cond_or_uncond += [o[1]] control = p[4] + patches = p[5] batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) @@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if control is not None: c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) + transformer_options = {} if 'transformer_options' in model_options: - c['transformer_options'] = model_options['transformer_options'] + transformer_options = model_options['transformer_options'].copy() + + if patches is not None: + transformer_options["patches"] = patches + + c['transformer_options'] = transformer_options output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x @@ -309,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c): n = c[1].copy() conds += [[smallest[0], n]] - -def apply_control_net_to_equal_area(conds, uncond): +def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] cond_other = [] uncond_cnets = [] @@ -318,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond): for t in range(len(conds)): x = conds[t] if 'area' not in x[1]: - if 'control' in x[1] and x[1]['control'] is not None: - cond_cnets.append(x[1]['control']) + if name in x[1] and x[1][name] is not None: + cond_cnets.append(x[1][name]) else: cond_other.append((x, t)) for t in range(len(uncond)): x = uncond[t] if 'area' not in x[1]: - if 'control' in x[1] and x[1]['control'] is not None: - uncond_cnets.append(x[1]['control']) + if name in x[1] and x[1][name] is not None: + uncond_cnets.append(x[1][name]) else: uncond_other.append((x, t)) @@ -336,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond): for x in range(len(cond_cnets)): temp = uncond_other[x % len(uncond_other)] o = temp[0] - if 'control' in o[1] and o[1]['control'] is not None: + if name in o[1] and o[1][name] is not None: n = o[1].copy() - n['control'] = cond_cnets[x] + n[name] = uncond_fill_func(cond_cnets, x) uncond += [[o[0], n]] else: n = o[1].copy() - n['control'] = cond_cnets[x] + n[name] = uncond_fill_func(cond_cnets, x) uncond[temp[1]] = [o[0], n] + def encode_adm(noise_augmentor, conds, batch_size, device): for t in range(len(conds)): x = conds[t] @@ -378,6 +409,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): return conds + class KSampler: SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", @@ -466,7 +498,8 @@ class KSampler: for c in negative: create_cond_with_same_area_if_none(positive, c) - apply_control_net_to_equal_area(positive, negative) + apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) if self.model.model.diffusion_model.dtype == torch.float16: precision_scope = torch.autocast diff --git a/comfy/sd.py b/comfy/sd.py index 1d7774742..211acd70e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -13,6 +13,7 @@ from .t2i_adapter import adapter from . import utils from . import clip_vision +from . import gligen def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) @@ -378,7 +379,7 @@ class CLIP: def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) - def encode_from_tokens(self, tokens): + def encode_from_tokens(self, tokens, return_pooled=False): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) try: @@ -388,6 +389,10 @@ class CLIP: except Exception as e: self.patcher.unpatch_model() raise e + if return_pooled: + eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__) + pooled = cond[:, eos_token_index] + return cond, pooled return cond def encode(self, text): @@ -564,10 +569,10 @@ class ControlNet: c.strength = self.strength return c - def get_control_models(self): + def get_models(self): out = [] if self.previous_controlnet is not None: - out += self.previous_controlnet.get_control_models() + out += self.previous_controlnet.get_models() out.append(self.control_model) return out @@ -737,10 +742,10 @@ class T2IAdapter: del self.cond_hint self.cond_hint = None - def get_control_models(self): + def get_models(self): out = [] if self.previous_controlnet is not None: - out += self.previous_controlnet.get_control_models() + out += self.previous_controlnet.get_models() return out def load_t2i_adapter(t2i_data): @@ -787,6 +792,13 @@ def load_clip(ckpt_path, embedding_directory=None): clip.load_from_state_dict(clip_data) return clip +def load_gligen(ckpt_path): + data = utils.load_torch_file(ckpt_path) + model = gligen.load_gligen(data) + if model_management.should_use_fp16(): + model = model.half() + return model + def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): with open(config_path, 'r') as stream: config = yaml.safe_load(stream) diff --git a/folder_paths.py b/folder_paths.py index 61f446c96..3c4ad3711 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -26,6 +26,8 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")] folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) +folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) + folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) diff --git a/models/gligen/put_gligen_models_here b/models/gligen/put_gligen_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 06b69f453..8555f272a 100644 --- a/nodes.py +++ b/nodes.py @@ -490,6 +490,51 @@ class unCLIPConditioning: c.append(n) return (c, ) +class GLIGENLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}} + + RETURN_TYPES = ("GLIGEN",) + FUNCTION = "load_gligen" + + CATEGORY = "_for_testing/gligen" + + def load_gligen(self, gligen_name): + gligen_path = folder_paths.get_full_path("gligen", gligen_name) + gligen = comfy.sd.load_gligen(gligen_path) + return (gligen,) + +class GLIGENTextBoxApply: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning_to": ("CONDITIONING", ), + "clip": ("CLIP", ), + "gligen_textbox_model": ("GLIGEN", ), + "text": ("STRING", {"multiline": True}), + "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "_for_testing/gligen" + + def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): + c = [] + cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) + for t in conditioning_to: + n = [t[0], t[1].copy()] + position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)] + prev = [] + if "gligen" in n[1]: + prev = n[1]['gligen'][2] + + n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params) + c.append(n) + return (c, ) class EmptyLatentImage: def __init__(self, device="cpu"): @@ -731,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative_copy = [] control_nets = [] + def get_models(cond): + models = [] + for c in cond: + if 'control' in c[1]: + models += [c[1]['control']] + if 'gligen' in c[1]: + models += [c[1]['gligen'][1]] + return models + for p in positive: t = p[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) - if 'control' in p[1]: - control_nets += [p[1]['control']] positive_copy += [[t] + p[1:]] for n in negative: t = n[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) - if 'control' in n[1]: - control_nets += [n[1]['control']] negative_copy += [[t] + n[1:]] - control_net_models = [] - for x in control_nets: - control_net_models += x.get_control_models() - comfy.model_management.load_controlnet_gpu(control_net_models) + models = get_models(positive) + get_models(negative) + comfy.model_management.load_controlnet_gpu(models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) @@ -761,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() - for c in control_nets: - c.cleanup() + for m in models: + m.cleanup() out = latent.copy() out["samples"] = samples @@ -1128,6 +1176,9 @@ NODE_CLASS_MAPPINGS = { "VAEEncodeTiled": VAEEncodeTiled, "TomePatchModel": TomePatchModel, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, + "GLIGENLoader": GLIGENLoader, + "GLIGENTextBoxApply": GLIGENTextBoxApply, + "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, } From 781b724ac667e42900c331988f356a85670c0ec5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 11:30:18 -0400 Subject: [PATCH 057/190] Add GLIGEN model link to colab. --- notebooks/comfyui_colab.ipynb | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c088de89c..c1982d8be 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -138,6 +138,11 @@ "# Controlnet Preprocessor nodes by Fannovel16\n", "#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n", "\n", + "\n", + "# GLIGEN\n", + "#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n", + "\n", + "\n", "# ESRGAN upscale model\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", From 2d546d510d1f7919bbae3ac08108e0d05e9c0bae Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 11:47:49 -0400 Subject: [PATCH 058/190] Add gligen entry to extra_model_paths example. --- extra_model_paths.yaml.example | 1 + 1 file changed, 1 insertion(+) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index f421f54dc..ac1ffe9d2 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -18,6 +18,7 @@ a111: #other_ui: # base_path: path/to/ui # checkpoints: models/checkpoints +# gligen: models/gligen # custom_nodes: path/custom_nodes From 96b57a9ad6447b95921b91e5f52fb3684f73514f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 21:11:38 -0400 Subject: [PATCH 059/190] Don't pass adm to model when it doesn't support it. --- comfy/samplers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 31968e185..19ebc97d9 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -36,8 +36,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con strength = cond[1]['strength'] adm_cond = None - if 'adm' in cond[1]: - adm_cond = cond[1]['adm'] + if 'adm_encoded' in cond[1]: + adm_cond = cond[1]['adm_encoded'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] mult = torch.ones_like(input_x) * strength @@ -405,7 +405,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): else: adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) x[1] = x[1].copy() - x[1]["adm"] = torch.cat([adm_out] * batch_size) + x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) return conds From 94e9798a4b614627805be197aa3da415a1de7ee4 Mon Sep 17 00:00:00 2001 From: omar92 Date: Thu, 20 Apr 2023 06:19:56 +0200 Subject: [PATCH 060/190] when drag from node input or output show all possible nodes that you can connect --- web/extensions/core/slotDefaults.js | 50 ++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 0b6a0a150..3ff5fdb06 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -1,21 +1,49 @@ import { app } from "/scripts/app.js"; - +import { ComfyWidgets } from "/scripts/widgets.js"; // Adds defaults for quickly adding nodes with middle click on the input/output app.registerExtension({ name: "Comfy.SlotDefaults", init() { LiteGraph.middle_click_slot_add_default_node = true; - LiteGraph.slot_types_default_in = { - MODEL: "CheckpointLoaderSimple", - LATENT: "EmptyLatentImage", - VAE: "VAELoader", - }; + }, + async beforeRegisterNodeDef(nodeType, nodeData, app) { + var nodeId = nodeData.name; + var inputs = []; + //if (nodeData["input"]["optional"] != undefined) { + // inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]); + //} else { + inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logica to create node with optional inputs + //} + for (const inputKey in inputs) { + var input = (inputs[inputKey]); + //make sure input[0] is a string + if (typeof input[0] !== "string") continue; + + // for (const slotKey in inputs[inputKey]) { + var type = input[0] + if (type in ComfyWidgets) { + var customProperties = input[1] + //console.log(customProperties) + if (!(customProperties?.forceInput)) continue; //ignore widgets that don't force input + } + + if (!(type in LiteGraph.slot_types_default_out)) { + LiteGraph.slot_types_default_out[type] = ["Reroute"]; + } + if (LiteGraph.slot_types_default_out[type].includes(nodeId)) continue; + LiteGraph.slot_types_default_out[type].push(nodeId); + // } + } + + var outputs = nodeData["output"]; + for (const key in outputs) { + var type = outputs[key]; + if (!(type in LiteGraph.slot_types_default_in)) { + LiteGraph.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'() + } + LiteGraph.slot_types_default_in[type].push(nodeId); + } - LiteGraph.slot_types_default_out = { - LATENT: "VAEDecode", - IMAGE: "SaveImage", - CLIP: "CLIPTextEncode", - }; }, }); From 5229c1f972b4130d5d0ddc19362604c6ec57d1fd Mon Sep 17 00:00:00 2001 From: omar92 Date: Thu, 20 Apr 2023 21:13:14 +0200 Subject: [PATCH 061/190] add option on the settings to change the number of the suggestions --- web/extensions/core/slotDefaults.js | 61 ++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 3ff5fdb06..04baadc6a 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -4,46 +4,69 @@ import { ComfyWidgets } from "/scripts/widgets.js"; app.registerExtension({ name: "Comfy.SlotDefaults", + suggestionsNumber: null, init() { LiteGraph.middle_click_slot_add_default_node = true; + this.suggestionsNumber = app.ui.settings.addSetting({ + id: "Comfy.NodeSuggestions.number", + name: "number of nodes suggestions", + type: "slider", + attrs: { + min: 1, + max: 100, + step: 1, + }, + defaultValue: 5, + onChange: (newVal, oldVal) => { + this.setDefaults(newVal); + } + }); }, + slot_types_default_out: {}, + slot_types_default_in: {}, async beforeRegisterNodeDef(nodeType, nodeData, app) { - var nodeId = nodeData.name; + var nodeId = nodeData.name; var inputs = []; - //if (nodeData["input"]["optional"] != undefined) { - // inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]); - //} else { - inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logica to create node with optional inputs - //} + inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logical to create node with optional inputs for (const inputKey in inputs) { var input = (inputs[inputKey]); - //make sure input[0] is a string if (typeof input[0] !== "string") continue; - // for (const slotKey in inputs[inputKey]) { var type = input[0] if (type in ComfyWidgets) { var customProperties = input[1] - //console.log(customProperties) if (!(customProperties?.forceInput)) continue; //ignore widgets that don't force input } - if (!(type in LiteGraph.slot_types_default_out)) { - LiteGraph.slot_types_default_out[type] = ["Reroute"]; + if (!(type in this.slot_types_default_out)) { + this.slot_types_default_out[type] = ["Reroute"]; } - if (LiteGraph.slot_types_default_out[type].includes(nodeId)) continue; - LiteGraph.slot_types_default_out[type].push(nodeId); - // } - } + if (this.slot_types_default_out[type].includes(nodeId)) continue; + this.slot_types_default_out[type].push(nodeId); + } var outputs = nodeData["output"]; for (const key in outputs) { var type = outputs[key]; - if (!(type in LiteGraph.slot_types_default_in)) { - LiteGraph.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'() + if (!(type in this.slot_types_default_in)) { + this.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'() } - LiteGraph.slot_types_default_in[type].push(nodeId); - } + this.slot_types_default_in[type].push(nodeId); + } + var maxNum = this.suggestionsNumber ? this.suggestionsNumber.value : 5; + this.setDefaults(maxNum); }, + setDefaults(maxNum) { + + LiteGraph.slot_types_default_out = {}; + LiteGraph.slot_types_default_in = {}; + + for (const type in this.slot_types_default_out) { + LiteGraph.slot_types_default_out[type] = this.slot_types_default_out[type].slice(0, maxNum); + } + for (const type in this.slot_types_default_in) { + LiteGraph.slot_types_default_in[type] = this.slot_types_default_in[type].slice(0, maxNum); + } + } }); From 31e60adb2802874a5889623a83149faa32924a98 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 20 Apr 2023 17:30:10 -0400 Subject: [PATCH 062/190] Add GLIGEN example to README. --- README.md | 1 + nodes.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index be2cb8ec5..bf16006bf 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/) - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) +- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - Starts up very fast. - Works fully offline: will never download anything. - [Config file](extra_model_paths.yaml.example) to set the search paths for models. diff --git a/nodes.py b/nodes.py index 8555f272a..48c3ee9c3 100644 --- a/nodes.py +++ b/nodes.py @@ -498,7 +498,7 @@ class GLIGENLoader: RETURN_TYPES = ("GLIGEN",) FUNCTION = "load_gligen" - CATEGORY = "_for_testing/gligen" + CATEGORY = "loaders" def load_gligen(self, gligen_name): gligen_path = folder_paths.get_full_path("gligen", gligen_name) @@ -520,7 +520,7 @@ class GLIGENTextBoxApply: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "_for_testing/gligen" + CATEGORY = "conditioning/gligen" def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): c = [] From d2ef3465ca838e528008cb5e20b40d25079d5176 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 20 Apr 2023 18:23:51 -0600 Subject: [PATCH 063/190] Improve current word selection --- web/extensions/core/editAttention.js | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index bebc80b12..cc51a04e5 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -89,24 +89,17 @@ app.registerExtension({ end = nearestEnclosure.end; selectedText = inputField.value.substring(start, end); } else { - // Select the current word, find the start and end of the word (first space before and after) - const wordStart = inputField.value.substring(0, start).lastIndexOf(" ") + 1; - const wordEnd = inputField.value.substring(end).indexOf(" "); - // If there is no space after the word, select to the end of the string - if (wordEnd === -1) { - end = inputField.value.length; - } else { - end += wordEnd; + // Select the current word, find the start and end of the word + const delimiters = " .,\\/!?%^*;:{}=-_`~()\r\n\t"; + + while (!delimiters.includes(inputField.value[start - 1]) && start > 0) { + start--; + } + + while (!delimiters.includes(inputField.value[end]) && end < inputField.value.length) { + end++; } - start = wordStart; - // Remove all punctuation at the end and beginning of the word - while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { - start++; - } - while (inputField.value[end - 1].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { - end--; - } selectedText = inputField.value.substring(start, end); if (!selectedText) return; } From 907010e0824eeab12c5948e5afa4df6d0934be9a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 20 Apr 2023 23:58:25 -0400 Subject: [PATCH 064/190] Remove some useless code. --- comfy/samplers.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 19ebc97d9..15527224e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -7,23 +7,6 @@ from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale): - if len(uncond[0]) == len(cond[0]) and x.shape[0] * x.shape[2] * x.shape[3] < (96 * 96): #TODO check memory instead - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - else: - cond = self.inner_model(x, sigma, cond=cond) - uncond = self.inner_model(x, sigma, cond=uncond) - return uncond + (cond - uncond) * cond_scale - - #The main sampling function shared by all the samplers #Returns predicted noise def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}): From 98ae4bbfdee1ea9da62e3d22a3c6428032a78398 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 20 Apr 2023 23:55:20 -0600 Subject: [PATCH 065/190] Remove brackets if weight == 1 --- web/extensions/core/editAttention.js | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index cc51a04e5..b937bb103 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -128,8 +128,13 @@ app.registerExtension({ // Increment the weight const weightDelta = event.key === "ArrowUp" ? delta : -delta; - const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => { - return prefix + incrementWeight(weight, weightDelta) + suffix; + const updatedText = selectedText.replace(/\((.*):(\d+(?:\.\d+)?)\)/, (match, text, weight) => { + weight = incrementWeight(weight, weightDelta); + if (weight == 1) { + return text; + } else { + return `(${text}:${weight})`; + } }); inputField.setRangeText(updatedText, start, end, "select"); From 989acd769a6b5f3a5d6e3cd03fafbd9668c2dbdf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 21 Apr 2023 23:43:38 -0400 Subject: [PATCH 066/190] Cleanup. --- web/extensions/core/slotDefaults.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 04baadc6a..3ec605900 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -54,7 +54,7 @@ app.registerExtension({ this.slot_types_default_in[type].push(nodeId); } - var maxNum = this.suggestionsNumber ? this.suggestionsNumber.value : 5; + var maxNum = this.suggestionsNumber.value; this.setDefaults(maxNum); }, setDefaults(maxNum) { From 6908f9c94992b32fbb96be0f6cd8c5b362d72a77 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 22 Apr 2023 14:30:39 -0400 Subject: [PATCH 067/190] This makes pytorch2.0 attention perform a bit faster. --- comfy/ldm/modules/attention.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 98dbda635..c27d032a3 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -455,11 +455,7 @@ class CrossAttentionPytorch(nn.Module): b, _, _ = q.shape q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), + lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2), (q, k, v), ) @@ -468,10 +464,7 @@ class CrossAttentionPytorch(nn.Module): if exists(mask): raise NotImplementedError out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) + out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head) ) return self.to_out(out) From ee030d281bbd25d385ba9ca10badb66b487cca21 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 22 Apr 2023 16:02:26 -0700 Subject: [PATCH 068/190] Add support for multiple unique inpainting masks This enables workflows like "Inpaint at full resolution" when using batch sizes greater than 1. --- nodes.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/nodes.py b/nodes.py index 48c3ee9c3..9335d5243 100644 --- a/nodes.py +++ b/nodes.py @@ -171,24 +171,28 @@ class VAEEncodeForInpaint: def encode(self, vae, pixels, mask): x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 - mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0] + if len(mask.shape) < 3: + mask = mask.unsqueeze(0).unsqueeze(0) + elif len(mask.shape) < 4: + mask = mask.unsqueeze(1) + mask = torch.nn.functional.interpolate(mask, size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: pixels = pixels[:,:x,:y,:] - mask = mask[:x,:y] + mask = mask[:,:x,:y,:] #grow mask by a few pixels to keep things seamless in latent space kernel_tensor = torch.ones((1, 1, 6, 6)) - mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1) - m = (1.0 - mask.round()) + mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1) + m = (1.0 - mask.round()).squeeze(1) for i in range(3): pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] *= m pixels[:,:,:,i] += 0.5 t = vae.encode(pixels) - return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) + return ({"samples":t, "noise_mask": (mask_erosion[:,:x,:y,:].round())}, ) class CheckpointLoader: @classmethod @@ -759,10 +763,15 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent['noise_mask'] - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + if len(noise_mask.shape) < 3: + noise_mask = noise_mask.unsqueeze(0).unsqueeze(0) + elif len(noise_mask.shape) < 4: + noise_mask = noise_mask.unsqueeze(1) + noise_mask = torch.nn.functional.interpolate(noise_mask, size=(noise.shape[2], noise.shape[3]), mode="bilinear") noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) + if noise_mask.shape[0] < latent_image.shape[0]: + noise_mask = noise_mask.repeat(latent_image.shape[0] // noise_mask.shape[0], 1, 1, 1) noise_mask = noise_mask.to(device) real_model = None From c8355ed39ff39a10eb7a3d262f278dc99ad2e73b Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 23 Apr 2023 10:31:21 +0100 Subject: [PATCH 069/190] use window.name instead of session storage - prevents duplicate stealing session id --- web/scripts/api.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/scripts/api.js b/web/scripts/api.js index 2b90c2abc..d29faa5ba 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -35,7 +35,7 @@ class ComfyApi extends EventTarget { } let opened = false; - let existingSession = sessionStorage["Comfy.SessionId"] || ""; + let existingSession = window.name; if (existingSession) { existingSession = "?clientId=" + existingSession; } @@ -75,7 +75,7 @@ class ComfyApi extends EventTarget { case "status": if (msg.data.sid) { this.clientId = msg.data.sid; - sessionStorage["Comfy.SessionId"] = this.clientId; + window.name = this.clientId; } this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); break; From 5282f5643476ba0f55197c3ca8b72ce43525b025 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 23 Apr 2023 12:35:25 -0400 Subject: [PATCH 070/190] Implement Linear hypernetworks. Add a HypernetworkLoader node to use hypernetworks. --- comfy/ldm/modules/attention.py | 69 +++++++++++++--- comfy/model_management.py | 3 + comfy/samplers.py | 10 ++- comfy/sd.py | 23 ++++++ comfy/utils.py | 7 +- comfy_extras/nodes_hypernetwork.py | 87 +++++++++++++++++++++ folder_paths.py | 1 + models/hypernetworks/put_hypernetworks_here | 0 nodes.py | 1 + 9 files changed, 185 insertions(+), 16 deletions(-) create mode 100644 comfy_extras/nodes_hypernetwork.py create mode 100644 models/hypernetworks/put_hypernetworks_here diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index c27d032a3..ce7180d91 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads query = self.to_q(x) context = default(context, x) key = self.to_k(context) - value = self.to_v(context) + if value is not None: + value = self.to_v(value) + else: + value = self.to_v(context) + del context, x query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) @@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) - v_in = self.to_v(context) + if value is not None: + v_in = self.to_v(value) + del value + else: + v_in = self.to_v(context) del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) @@ -350,13 +358,17 @@ class CrossAttention(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) @@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -447,11 +463,15 @@ class CrossAttentionPytorch(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module): transformer_patches = {} n = self.norm1(x) + if self.disable_self_attn: + context_attn1 = context + else: + context_attn1 = None + value_attn1 = None + + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + if context_attn1 is None: + context_attn1 = n + value_attn1 = context_attn1 + for p in patch: + n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1) + if "tomesd" in transformer_options: m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) - n = u(self.attn1(m(n), context=context if self.disable_self_attn else None)) + n = u(self.attn1(m(n), context=context_attn1, value=value_attn1)) else: - n = self.attn1(n, context=context if self.disable_self_attn else None) + n = self.attn1(n, context=context_attn1, value=value_attn1) x += n if "middle_patch" in transformer_patches: @@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module): x = p(current_index, x) n = self.norm2(x) - n = self.attn2(n, context=context) + + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] + value_attn2 = context_attn2 + for p in patch: + n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2) + + n = self.attn2(n, context=context_attn2, value=value_attn2) x += n x = self.ff(self.norm3(x)) + x diff --git a/comfy/model_management.py b/comfy/model_management.py index a0d1313d2..6e3a03530 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -133,6 +133,7 @@ def unload_model(): #never unload models from GPU on high vram if vram_state != VRAMState.HIGH_VRAM: current_loaded_model.model.cpu() + current_loaded_model.model_patches_to("cpu") current_loaded_model.unpatch_model() current_loaded_model = None @@ -156,6 +157,8 @@ def load_model_gpu(model): except Exception as e: model.unpatch_model() raise e + + model.model_patches_to(get_torch_device()) current_loaded_model = model if vram_state == VRAMState.CPU: pass diff --git a/comfy/samplers.py b/comfy/samplers.py index 15527224e..b860f25f1 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -197,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con transformer_options = model_options['transformer_options'].copy() if patches is not None: - transformer_options["patches"] = patches + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches c['transformer_options'] = transformer_options diff --git a/comfy/sd.py b/comfy/sd.py index 211acd70e..92dbb931d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -254,6 +254,29 @@ class ModelPatcher: def set_model_sampler_cfg_function(self, sampler_cfg_function): self.model_options["sampler_cfg_function"] = sampler_cfg_function + + def set_model_patch(self, patch, name): + to = self.model_options["transformer_options"] + if "patches" not in to: + to["patches"] = {} + to["patches"][name] = to["patches"].get(name, []) + [patch] + + def set_model_attn1_patch(self, patch): + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch): + self.set_model_patch(patch, "attn2_patch") + + def model_patches_to(self, device): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + patch_list[i] = patch_list[i].to(device) + def model_dtype(self): return self.model.diffusion_model.dtype diff --git a/comfy/utils.py b/comfy/utils.py index 0380b91dd..68f93403c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,11 +1,14 @@ import torch -def load_torch_file(ckpt): +def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: - pl_sd = torch.load(ckpt, map_location="cpu") + if safe_load: + pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) + else: + pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py new file mode 100644 index 000000000..db2f8695c --- /dev/null +++ b/comfy_extras/nodes_hypernetwork.py @@ -0,0 +1,87 @@ +import comfy.utils +import folder_paths +import torch + +def load_hypernetwork_patch(path, strength): + sd = comfy.utils.load_torch_file(path, safe_load=True) + activation_func = sd.get('activation_func', 'linear') + is_layer_norm = sd.get('is_layer_norm', False) + use_dropout = sd.get('use_dropout', False) + activate_output = sd.get('activate_output', False) + last_layer_dropout = sd.get('last_layer_dropout', False) + + if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False: + print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) + return None + + out = {} + + for d in sd: + try: + dim = int(d) + except: + continue + + output = [] + for index in [0, 1]: + attn_weights = sd[dim][index] + keys = attn_weights.keys() + + linears = filter(lambda a: a.endswith(".weight"), keys) + linears = sorted(list(map(lambda a: a[:-len(".weight")], linears))) + layers = [] + + for lin_name in linears: + lin_weight = attn_weights['{}.weight'.format(lin_name)] + lin_bias = attn_weights['{}.bias'.format(lin_name)] + layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) + layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) + layers += [layer] + + output.append(torch.nn.Sequential(*layers)) + out[dim] = torch.nn.ModuleList(output) + + class hypernetwork_patch: + def __init__(self, hypernet, strength): + self.hypernet = hypernet + self.strength = strength + def __call__(self, current_index, q, k, v): + dim = k.shape[-1] + if dim in self.hypernet: + hn = self.hypernet[dim] + k = k + hn[0](k) * self.strength + v = v + hn[1](v) * self.strength + + return q, k, v + + def to(self, device): + for d in self.hypernet.keys(): + self.hypernet[d] = self.hypernet[d].to(device) + return self + + return hypernetwork_patch(out, strength) + +class HypernetworkLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_hypernetwork" + + CATEGORY = "_for_testing" + + def load_hypernetwork(self, model, hypernetwork_name, strength): + hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) + model_hypernetwork = model.clone() + patch = load_hypernetwork_patch(hypernetwork_path, strength) + if patch is not None: + model_hypernetwork.set_model_attn1_patch(patch) + model_hypernetwork.set_model_attn2_patch(patch) + return (model_hypernetwork,) + +NODE_CLASS_MAPPINGS = { + "HypernetworkLoader": HypernetworkLoader +} diff --git a/folder_paths.py b/folder_paths.py index 3c4ad3711..bb0d65524 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) +folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") diff --git a/models/hypernetworks/put_hypernetworks_here b/models/hypernetworks/put_hypernetworks_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 48c3ee9c3..6ca73fa0c 100644 --- a/nodes.py +++ b/nodes.py @@ -1268,6 +1268,7 @@ def load_custom_nodes(): def init_custom_nodes(): load_custom_nodes() + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) From 2a09e2aa27620c492f694b66cc10c5f41b101c12 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sun, 23 Apr 2023 20:02:08 +0200 Subject: [PATCH 071/190] refactor/split various bits of code for sampling --- comfy/sample.py | 62 +++++++++++++++++++++++++++++++++++++++++++++ comfy/samplers.py | 64 +++++++++++++++++++++++++++-------------------- nodes.py | 60 +++++++------------------------------------- 3 files changed, 108 insertions(+), 78 deletions(-) create mode 100644 comfy/sample.py diff --git a/comfy/sample.py b/comfy/sample.py new file mode 100644 index 000000000..ede89890b --- /dev/null +++ b/comfy/sample.py @@ -0,0 +1,62 @@ +import torch +import comfy.model_management + + +def prepare_noise(latent, seed, disable_noise): + latent_image = latent["samples"] + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_index = 0 + if "batch_index" in latent: + batch_index = latent["batch_index"] + + generator = torch.manual_seed(seed) + for i in range(batch_index): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + return noise + +def create_mask(latent, noise): + noise_mask = None + device = comfy.model_management.get_torch_device() + if "noise_mask" in latent: + noise_mask = latent['noise_mask'] + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = noise_mask.round() + noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = noise_mask.to(device) + return noise_mask + +def broadcast_cond(cond, noise): + device = comfy.model_management.get_torch_device() + copy = [] + for p in cond: + t = p[0] + if t.shape[0] < noise.shape[0]: + t = torch.cat([t] * noise.shape[0]) + t = t.to(device) + copy += [[t] + p[1:]] + return copy + +def load_c_nets(positive, negative): + def get_models(cond): + models = [] + for c in cond: + if 'control' in c[1]: + models += [c[1]['control']] + if 'gligen' in c[1]: + models += [c[1]['gligen'][1]] + return models + + return get_models(positive) + get_models(negative) + +def load_additional_models(positive, negative): + models = load_c_nets(positive, negative) + comfy.model_management.load_controlnet_gpu(models) + return models + +def cleanup_additional_models(models): + for m in models: + m.cleanup() \ No newline at end of file diff --git a/comfy/samplers.py b/comfy/samplers.py index 15527224e..541a8db8d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -392,6 +392,38 @@ def encode_adm(noise_augmentor, conds, batch_size, device): return conds +def calculate_sigmas(model, steps, scheduler, sampler): + """ + Returns a tensor containing the sigmas corresponding to the given model, number of steps, scheduler type and sample technique + """ + if not (isinstance(model, CompVisVDenoiser) or isinstance(model, k_diffusion_external.CompVisDenoiser)): + model = CFGNoisePredictor(model) + if model.inner_model.parameterization == "v": + model = CompVisVDenoiser(model, quantize=True) + else: + model = k_diffusion_external.CompVisDenoiser(model, quantize=True) + + sigmas = None + + discard_penultimate_sigma = False + if sampler in ['dpm_2', 'dpm_2_ancestral']: + steps += 1 + discard_penultimate_sigma = True + + if scheduler == "karras": + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.sigma_min), sigma_max=float(model.sigma_max)) + elif scheduler == "normal": + sigmas = model.get_sigmas(steps) + elif scheduler == "simple": + sigmas = simple_scheduler(model, steps) + elif scheduler == "ddim_uniform": + sigmas = ddim_scheduler(model, steps) + else: + print("error invalid scheduler", scheduler) + + if discard_penultimate_sigma: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + return sigmas class KSampler: SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] @@ -421,41 +453,19 @@ class KSampler: self.denoise = denoise self.model_options = model_options - def _calculate_sigmas(self, steps): - sigmas = None - - discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral']: - steps += 1 - discard_penultimate_sigma = True - - if self.scheduler == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device) - elif self.scheduler == "normal": - sigmas = self.model_wrap.get_sigmas(steps).to(self.device) - elif self.scheduler == "simple": - sigmas = simple_scheduler(self.model_wrap, steps).to(self.device) - elif self.scheduler == "ddim_uniform": - sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device) - else: - print("error invalid scheduler", self.scheduler) - - if discard_penultimate_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - return sigmas - def set_steps(self, steps, denoise=None): self.steps = steps if denoise is None or denoise > 0.9999: - self.sigmas = self._calculate_sigmas(steps) + self.sigmas = calculate_sigmas(self.model_wrap, steps, self.scheduler, self.sampler).to(self.device) else: new_steps = int(steps/denoise) - sigmas = self._calculate_sigmas(new_steps) + sigmas = calculate_sigmas(self.model_wrap, new_steps, self.scheduler, self.sampler).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): - sigmas = self.sigmas + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None): + if sigmas is None: + sigmas = self.sigmas sigma_min = self.sigma_min if last_step is not None and last_step < (len(sigmas) - 1): diff --git a/nodes.py b/nodes.py index 48c3ee9c3..601661864 100644 --- a/nodes.py +++ b/nodes.py @@ -16,6 +16,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "co import comfy.diffusers_convert import comfy.samplers +import comfy.sample import comfy.sd import comfy.utils @@ -739,31 +740,12 @@ class SetLatentNoiseMask: s["noise_mask"] = mask return (s,) - def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): - latent_image = latent["samples"] - noise_mask = None device = comfy.model_management.get_torch_device() + latent_image = latent["samples"] - if disable_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") - else: - batch_index = 0 - if "batch_index" in latent: - batch_index = latent["batch_index"] - - generator = torch.manual_seed(seed) - for i in range(batch_index): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - - if "noise_mask" in latent: - noise_mask = latent['noise_mask'] - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") - noise_mask = noise_mask.round() - noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) - noise_mask = noise_mask.to(device) + noise = comfy.sample.prepare_noise(latent, seed, disable_noise) + noise_mask = comfy.sample.create_mask(latent, noise) real_model = None comfy.model_management.load_model_gpu(model) @@ -772,34 +754,10 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise = noise.to(device) latent_image = latent_image.to(device) - positive_copy = [] - negative_copy = [] + positive_copy = comfy.sample.broadcast_cond(positive, noise) + negative_copy = comfy.sample.broadcast_cond(negative, noise) - control_nets = [] - def get_models(cond): - models = [] - for c in cond: - if 'control' in c[1]: - models += [c[1]['control']] - if 'gligen' in c[1]: - models += [c[1]['gligen'][1]] - return models - - for p in positive: - t = p[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(device) - positive_copy += [[t] + p[1:]] - for n in negative: - t = n[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(device) - negative_copy += [[t] + n[1:]] - - models = get_models(positive) + get_models(negative) - comfy.model_management.load_controlnet_gpu(models) + models = comfy.sample.load_additional_models(positive, negative) if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) @@ -809,8 +767,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() - for m in models: - m.cleanup() + + comfy.sample.cleanup_additional_models(models) out = latent.copy() out["samples"] = samples From 5818539743bd390a282a19d7e480177c31bc222b Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sun, 23 Apr 2023 20:09:09 +0200 Subject: [PATCH 072/190] add docstrings --- comfy/sample.py | 25 ++++++++++++++----------- nodes.py | 6 +++++- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index ede89890b..981781b53 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,22 +2,21 @@ import torch import comfy.model_management -def prepare_noise(latent, seed, disable_noise): +def prepare_noise(latent, seed): + """creates random noise given a LATENT and a seed""" latent_image = latent["samples"] - if disable_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") - else: - batch_index = 0 - if "batch_index" in latent: - batch_index = latent["batch_index"] + batch_index = 0 + if "batch_index" in latent: + batch_index = latent["batch_index"] - generator = torch.manual_seed(seed) - for i in range(batch_index): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + generator = torch.manual_seed(seed) + for i in range(batch_index): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") return noise def create_mask(latent, noise): + """creates a mask for a given LATENT and noise""" noise_mask = None device = comfy.model_management.get_torch_device() if "noise_mask" in latent: @@ -30,6 +29,7 @@ def create_mask(latent, noise): return noise_mask def broadcast_cond(cond, noise): + """broadcasts conditioning to the noise batch size""" device = comfy.model_management.get_torch_device() copy = [] for p in cond: @@ -41,6 +41,7 @@ def broadcast_cond(cond, noise): return copy def load_c_nets(positive, negative): + """loads control nets in positive and negative conditioning""" def get_models(cond): models = [] for c in cond: @@ -53,10 +54,12 @@ def load_c_nets(positive, negative): return get_models(positive) + get_models(negative) def load_additional_models(positive, negative): + """loads additional models in positive and negative conditioning""" models = load_c_nets(positive, negative) comfy.model_management.load_controlnet_gpu(models) return models def cleanup_additional_models(models): + """cleanup additional models that were loaded""" for m in models: m.cleanup() \ No newline at end of file diff --git a/nodes.py b/nodes.py index a70668fd7..b8c6d350f 100644 --- a/nodes.py +++ b/nodes.py @@ -744,7 +744,11 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, device = comfy.model_management.get_torch_device() latent_image = latent["samples"] - noise = comfy.sample.prepare_noise(latent, seed, disable_noise) + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + noise = comfy.sample.prepare_noise(latent, seed) + noise_mask = comfy.sample.create_mask(latent, noise) real_model = None From f7a821881476cbd52a513877a9ffe35e6702b850 Mon Sep 17 00:00:00 2001 From: ltdrdata <128333288+ltdrdata@users.noreply.github.com> Date: Mon, 24 Apr 2023 04:58:55 +0900 Subject: [PATCH 073/190] Add clipspace feature. (#541) * Add clipspace feature. * feat: copy content to clipspace * feat: paste content from clipspace Extend validation to allow for validating annotated_path in addition to other parameters. Add support for annotated_filepath in folder_paths function. Generalize the '/upload/image' API to allow for uploading images to the 'input', 'temp', or 'output' directories. * rename contentClipboard -> clipspace * Do deep copy for imgs on copy to clipspace. * add original_imgs into clipspace * Preserve the original image when 'imgs' are modified * robust patch & refactoring folder_paths about annotated_filepath * Only show the Paste menu if the ComfyApp.clipspace is not empty * instant refresh on paste force triggering 'changed' on paste action * subfolder fix on paste logic attach subfolder if subfolder isn't empty --------- Co-authored-by: Lt.Dr.Data --- execution.py | 8 ++++- folder_paths.py | 40 ++++++++++++++++++++++ nodes.py | 8 ++--- server.py | 15 ++++++--- web/scripts/app.js | 83 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 145 insertions(+), 9 deletions(-) diff --git a/execution.py b/execution.py index 73be6db03..b062deeb1 100644 --- a/execution.py +++ b/execution.py @@ -11,6 +11,7 @@ import torch import nodes import comfy.model_management +import folder_paths def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() @@ -250,7 +251,12 @@ def validate_inputs(prompt, item): return (False, "Value bigger than max. {}, {}".format(class_type, x)) if isinstance(type_input, list): - if val not in type_input: + is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]") + if is_annotated_path: + if not folder_paths.exists_annotated_filepath(val): + return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val)) + + elif val not in type_input: return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (True, "") diff --git a/folder_paths.py b/folder_paths.py index bb0d65524..99a016695 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -69,6 +69,46 @@ def get_directory_by_type(type_name): return None +# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format +# otherwise use default_path as base_dir +def touch_annotated_filepath(name): + if name.endswith("[output]"): + base_dir = get_output_directory() + name = name[:-9] + elif name.endswith("[input]"): + base_dir = get_input_directory() + name = name[:-8] + elif name.endswith("[temp]"): + base_dir = get_temp_directory() + name = name[:-7] + else: + return name, None + + return name, base_dir + + +def get_annotated_filepath(name, default_dir=None): + name, base_dir = touch_annotated_filepath(name) + + if base_dir is None: + if default_dir is not None: + base_dir = default_dir + else: + base_dir = get_input_directory() # fallback path + + return os.path.join(base_dir, name) + + +def exists_annotated_filepath(name): + name, base_dir = touch_annotated_filepath(name) + + if base_dir is None: + base_dir = get_input_directory() # fallback path + + filepath = os.path.join(base_dir, name) + return os.path.exists(filepath) + + def add_model_folder_path(folder_name, full_folder_path): global folder_names_and_paths if folder_name in folder_names_and_paths: diff --git a/nodes.py b/nodes.py index 6ca73fa0c..b8b6280d6 100644 --- a/nodes.py +++ b/nodes.py @@ -975,7 +975,7 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image, input_dir) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -990,7 +990,7 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image, input_dir) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) @@ -1011,7 +1011,7 @@ class LoadImageMask: FUNCTION = "load_image" def load_image(self, image, channel): input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image, input_dir) i = Image.open(image_path) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") @@ -1029,7 +1029,7 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image, input_dir) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) diff --git a/server.py b/server.py index b5403670f..1c5c17916 100644 --- a/server.py +++ b/server.py @@ -112,13 +112,20 @@ class PromptServer(): @routes.post("/upload/image") async def upload_image(request): - upload_dir = folder_paths.get_input_directory() + post = await request.post() + image = post.get("image") + + if post.get("type") is None: + upload_dir = folder_paths.get_input_directory() + elif post.get("type") == "input": + upload_dir = folder_paths.get_input_directory() + elif post.get("type") == "temp": + upload_dir = folder_paths.get_temp_directory() + elif post.get("type") == "output": + upload_dir = folder_paths.get_output_directory() if not os.path.exists(upload_dir): os.makedirs(upload_dir) - - post = await request.post() - image = post.get("image") if image and image.file: filename = image.filename diff --git a/web/scripts/app.js b/web/scripts/app.js index f158f3457..b3e88d46f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -20,6 +20,12 @@ export class ComfyApp { */ #processingQueue = false; + /** + * Content Clipboard + * @type {serialized node object} + */ + static clipspace = null; + constructor() { this.ui = new ComfyUI(this); @@ -130,6 +136,83 @@ export class ComfyApp { ); } } + + options.push( + { + content: "Copy (Clipspace)", + callback: (obj) => { + var widgets = null; + if(this.widgets) { + widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + let img = new Image(); + var imgs = undefined; + if(this.imgs != undefined) { + img.src = this.imgs[0].src; + imgs = [img]; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': imgs, + 'images': this.images + }; + } + }); + + if(ComfyApp.clipspace != null) { + options.push( + { + content: "Paste (Clipspace)", + callback: () => { + if(ComfyApp.clipspace != null) { + if(ComfyApp.clipspace.widgets != null && this.widgets != null) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop) { + prop.value = value; + } + }); + } + + // image paste + if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) { + var filename = ""; + if(this.images && ComfyApp.clipspace.images) { + this.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.images != undefined) { + const clip_image = ComfyApp.clipspace.images[0]; + if(clip_image.subfolder != '') + filename = `${clip_image.subfolder}/`; + filename += `${clip_image.filename} [${clip_image.type}]`; + } + else if(ComfyApp.clipspace.widgets != undefined) { + const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); + if(index_in_clip >= 0) { + filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`; + } + } + + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) { + this.imgs = ComfyApp.clipspace.imgs; + + this.widgets[index].value = filename; + if(this.widgets_values != undefined) { + this.widgets_values[index] = filename; + } + } + } + this.trigger('changed'); + } + } + } + ); + } }; } From ccad603b2e6862a4a719bc34dc6bd32e65a539ad Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 23 Apr 2023 16:03:26 -0400 Subject: [PATCH 074/190] Add a way for nodes to validate their own inputs. --- execution.py | 21 +++++++++++---------- folder_paths.py | 6 +++--- nodes.py | 32 +++++++++++++++++++++++--------- web/scripts/app.js | 2 +- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/execution.py b/execution.py index b062deeb1..115efcbda 100644 --- a/execution.py +++ b/execution.py @@ -11,7 +11,6 @@ import torch import nodes import comfy.model_management -import folder_paths def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() @@ -250,14 +249,15 @@ def validate_inputs(prompt, item): if "max" in info[1] and val > info[1]["max"]: return (False, "Value bigger than max. {}, {}".format(class_type, x)) - if isinstance(type_input, list): - is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]") - if is_annotated_path: - if not folder_paths.exists_annotated_filepath(val): - return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val)) - - elif val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + if hasattr(obj_class, "VALIDATE_INPUTS"): + input_data_all = get_input_data(inputs, obj_class, unique_id) + ret = obj_class.VALIDATE_INPUTS(**input_data_all) + if ret != True: + return (False, "{}, {}".format(class_type, ret)) + else: + if isinstance(type_input, list): + if val not in type_input: + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (True, "") def validate_prompt(prompt): @@ -279,7 +279,8 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o) valid = m[0] reason = m[1] - except: + except Exception as e: + print(traceback.format_exc()) valid = False reason = "Parsing error" diff --git a/folder_paths.py b/folder_paths.py index 99a016695..e5b89492c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -71,7 +71,7 @@ def get_directory_by_type(type_name): # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # otherwise use default_path as base_dir -def touch_annotated_filepath(name): +def annotated_filepath(name): if name.endswith("[output]"): base_dir = get_output_directory() name = name[:-9] @@ -88,7 +88,7 @@ def touch_annotated_filepath(name): def get_annotated_filepath(name, default_dir=None): - name, base_dir = touch_annotated_filepath(name) + name, base_dir = annotated_filepath(name) if base_dir is None: if default_dir is not None: @@ -100,7 +100,7 @@ def get_annotated_filepath(name, default_dir=None): def exists_annotated_filepath(name): - name, base_dir = touch_annotated_filepath(name) + name, base_dir = annotated_filepath(name) if base_dir is None: base_dir = get_input_directory() # fallback path diff --git a/nodes.py b/nodes.py index b8b6280d6..d1133d1d8 100644 --- a/nodes.py +++ b/nodes.py @@ -974,8 +974,7 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -989,20 +988,27 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + return True + class LoadImageMask: + _color_channels = ["alpha", "red", "green", "blue"] @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() return {"required": {"image": (sorted(os.listdir(input_dir)), ), - "channel": (["alpha", "red", "green", "blue"], ),} + "channel": (s._color_channels, ),} } CATEGORY = "mask" @@ -1010,8 +1016,7 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") @@ -1028,13 +1033,22 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image, channel): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + if channel not in s._color_channels: + return "Invalid color channel: {}".format(channel) + + return True + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] diff --git a/web/scripts/app.js b/web/scripts/app.js index b3e88d46f..a161bf40e 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -172,7 +172,7 @@ export class ComfyApp { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); if (prop) { - prop.value = value; + prop.callback(value); } }); } From 0ac319fd81bcecea2aa35743da28088832e44707 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 23 Apr 2023 22:44:38 -0400 Subject: [PATCH 075/190] Don't delete all outputs when execution gets interrupted. --- execution.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/execution.py b/execution.py index 115efcbda..31a208e78 100644 --- a/execution.py +++ b/execution.py @@ -40,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = unique_id return input_data_all -def recursive_execute(server, prompt, outputs, current_item, extra_data={}): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return [] - - executed = [] + return for x in inputs: input_data = inputs[x] @@ -57,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: @@ -72,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) if "result" in outputs[unique_id]: outputs[unique_id] = outputs[unique_id]["result"] - return executed + [unique_id] + executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item @@ -158,7 +156,7 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) - executed = [] + executed = set() try: to_execute = [] for x in prompt: @@ -181,12 +179,12 @@ class PromptExecutor: except: valid = False if valid: - executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: print(traceback.format_exc()) to_delete = [] for o in self.outputs: - if o not in current_outputs: + if (o not in current_outputs) and (o not in executed): to_delete += [o] if o in self.old_prompt: d = self.old_prompt.pop(o) @@ -194,11 +192,9 @@ class PromptExecutor: for o in to_delete: d = self.outputs.pop(o) del d - else: - executed = set(executed) + finally: for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) - finally: self.server.last_node_id = None if self.server.client_id is not None: self.server.send_sync("executing", { "node": None }, self.server.client_id) From f1b87f50fa9c274f2dd9dbe24b082aa83ef0b028 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 01:50:56 -0400 Subject: [PATCH 076/190] Add hypernetworks path config to extra_model_paths.yaml.example --- extra_model_paths.yaml.example | 1 + 1 file changed, 1 insertion(+) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index ac1ffe9d2..fa5418a68 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -13,6 +13,7 @@ a111: models/ESRGAN models/SwinIR embeddings: embeddings + hypernetworks: models/hypernetworks controlnet: models/ControlNet #other_ui: From 4e345b31f692d5fb89009bf3352c922c2abe30e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 02:36:06 -0400 Subject: [PATCH 077/190] Support all known hypernetworks. --- comfy_extras/nodes_hypernetwork.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index db2f8695c..c08c2c811 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -10,7 +10,17 @@ def load_hypernetwork_patch(path, strength): activate_output = sd.get('activate_output', False) last_layer_dropout = sd.get('last_layer_dropout', False) - if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False: + valid_activation = { + "linear": torch.nn.Identity, + "relu": torch.nn.ReLU, + "leakyrelu": torch.nn.LeakyReLU, + "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid, + } + + if activation_func not in valid_activation: print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) return None @@ -28,15 +38,27 @@ def load_hypernetwork_patch(path, strength): keys = attn_weights.keys() linears = filter(lambda a: a.endswith(".weight"), keys) - linears = sorted(list(map(lambda a: a[:-len(".weight")], linears))) + linears = list(map(lambda a: a[:-len(".weight")], linears)) layers = [] - for lin_name in linears: + for i in range(len(linears)): + lin_name = linears[i] + last_layer = (i == (len(linears) - 1)) + penultimate_layer = (i == (len(linears) - 2)) + lin_weight = attn_weights['{}.weight'.format(lin_name)] lin_bias = attn_weights['{}.bias'.format(lin_name)] layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) - layers += [layer] + layers.append(layer) + if activation_func != "linear": + if (not last_layer) or (activate_output): + layers.append(valid_activation[activation_func]()) + if is_layer_norm: + layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) + if use_dropout: + if (not last_layer) and (not penultimate_layer or last_layer_dropout): + layers.append(torch.nn.Dropout(p=0.3)) output.append(torch.nn.Sequential(*layers)) out[dim] = torch.nn.ModuleList(output) From 463bde66a1d22b02858ac6f148d7fa3e6d9c4322 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 03:08:51 -0400 Subject: [PATCH 078/190] Add hypernetwork example link to readme. Move hypernetwork loader node to loaders. --- README.md | 1 + comfy_extras/nodes_hypernetwork.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bf16006bf..5b6346a67 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. - Embeddings/Textual inversion - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) +- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/) - Loading full workflows (with seeds) from generated PNG files. - Saving/Loading workflows as Json files. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index c08c2c811..0c7250e43 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -93,7 +93,7 @@ class HypernetworkLoader: RETURN_TYPES = ("MODEL",) FUNCTION = "load_hypernetwork" - CATEGORY = "_for_testing" + CATEGORY = "loaders" def load_hypernetwork(self, model, hypernetwork_name, strength): hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) From d9b1595f8552384dd08374d34c4d4127e0b1a4e6 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Mon, 24 Apr 2023 12:53:10 +0200 Subject: [PATCH 079/190] made sample functions more explicit --- comfy/sample.py | 55 +++++++++++++++++++++---------------------------- nodes.py | 7 +++++-- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 981781b53..84eefcb7b 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,30 +2,25 @@ import torch import comfy.model_management -def prepare_noise(latent, seed): - """creates random noise given a LATENT and a seed""" - latent_image = latent["samples"] - batch_index = 0 - if "batch_index" in latent: - batch_index = latent["batch_index"] - +def prepare_noise(latent_image, seed, skip=0): + """ + creates random noise given a latent image and a seed. + optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ generator = torch.manual_seed(seed) - for i in range(batch_index): + for _ in range(skip): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") return noise -def create_mask(latent, noise): - """creates a mask for a given LATENT and noise""" - noise_mask = None +def prepare_mask(noise_mask, noise): + """ensures noise mask is of proper dimensions""" device = comfy.model_management.get_torch_device() - if "noise_mask" in latent: - noise_mask = latent['noise_mask'] - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") - noise_mask = noise_mask.round() - noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) - noise_mask = noise_mask.to(device) + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = noise_mask.round() + noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = noise_mask.to(device) return noise_mask def broadcast_cond(cond, noise): @@ -40,22 +35,20 @@ def broadcast_cond(cond, noise): copy += [[t] + p[1:]] return copy -def load_c_nets(positive, negative): - """loads control nets in positive and negative conditioning""" - def get_models(cond): - models = [] - for c in cond: - if 'control' in c[1]: - models += [c[1]['control']] - if 'gligen' in c[1]: - models += [c[1]['gligen'][1]] - return models - - return get_models(positive) + get_models(negative) +def get_models_from_cond(cond, model_type): + models = [] + for c in cond: + if model_type in c[1]: + models += [c[1][model_type]] + return models def load_additional_models(positive, negative): """loads additional models in positive and negative conditioning""" - models = load_c_nets(positive, negative) + models = [] + models += get_models_from_cond(positive, "control") + models += get_models_from_cond(negative, "control") + models += get_models_from_cond(positive, "gligen") + models += get_models_from_cond(negative, "gligen") comfy.model_management.load_controlnet_gpu(models) return models diff --git a/nodes.py b/nodes.py index b8c6d350f..f9bedc97e 100644 --- a/nodes.py +++ b/nodes.py @@ -747,9 +747,12 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - noise = comfy.sample.prepare_noise(latent, seed) + skip = latent["batch_index"] if "batch_index" in latent else 0 + noise = comfy.sample.prepare_noise(latent_image, seed, skip) - noise_mask = comfy.sample.create_mask(latent, noise) + noise_mask = None + if "noise_mask" in latent: + noise_mask = comfy.sample.prepare_mask(latent["noise_mask"], noise) real_model = None comfy.model_management.load_model_gpu(model) From c8c9926eeb0b25dba86f3d9e574e8527c090fc37 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 24 Apr 2023 11:55:44 +0100 Subject: [PATCH 080/190] Add progress to vae decode tiled --- comfy/sd.py | 12 +++++++++--- comfy/utils.py | 4 +++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 92dbb931d..2aadefadc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,6 +1,7 @@ import torch import contextlib import copy +from tqdm.auto import tqdm import sd1_clip import sd2_clip @@ -437,11 +438,16 @@ class VAE: self.device = device def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): + it_1 = -(samples.shape[2] // -(tile_y * 2 - overlap)) * -(samples.shape[3] // -(tile_x // 2 - overlap)) + it_2 = -(samples.shape[2] // -(tile_y // 2 - overlap)) * -(samples.shape[3] // -(tile_x * 2 - overlap)) + it_3 = -(samples.shape[2] // -(tile_y - overlap)) * -(samples.shape[3] // -(tile_x - overlap)) + pbar = tqdm(total=samples.shape[0] * (it_1 + it_2 + it_3)) + decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( - (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) + - utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) + - utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8)) + (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + + utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + + utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) / 3.0) / 2.0, min=0.0, max=1.0) return output diff --git a/comfy/utils.py b/comfy/utils.py index 68f93403c..c7c6a08c5 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -63,7 +63,7 @@ def common_upscale(samples, width, height, upscale_method, crop): return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) @torch.inference_mode() -def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3): +def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") for b in range(samples.shape[0]): s = samples[b:b+1] @@ -83,6 +83,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask + if pbar is not None: + pbar.update(1) output[b:b+1] = out/out_div return output From 0b07b2cc0f94fc2b8ebe656dfb3768c6f67866f1 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Mon, 24 Apr 2023 21:47:57 +0200 Subject: [PATCH 081/190] gligen tuple --- comfy/sample.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 84eefcb7b..09ab20cd2 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -44,11 +44,10 @@ def get_models_from_cond(cond, model_type): def load_additional_models(positive, negative): """loads additional models in positive and negative conditioning""" - models = [] - models += get_models_from_cond(positive, "control") - models += get_models_from_cond(negative, "control") - models += get_models_from_cond(positive, "gligen") - models += get_models_from_cond(negative, "gligen") + control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control") + gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") + gligen = [x[1] for x in gligen] + models = control_nets + gligen comfy.model_management.load_controlnet_gpu(models) return models From 36acce58e71bbe1bf835c2ec380dc7ac0c5b4752 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 18:13:18 -0400 Subject: [PATCH 082/190] Auto increase the size of the image upload widget when there's an image. --- web/scripts/widgets.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 2acc5f2c0..238ad59dd 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -270,6 +270,9 @@ export const ComfyWidgets = { app.graph.setDirtyCanvas(true); }; img.src = `/view?filename=${name}&type=input`; + if ((node.size[1] - node.imageOffset) < 100) { + node.size[1] = 250 + node.imageOffset; + } } // Add our own callback to the combo widget to render an image when it changes From 7983b3a975c26b93601c8b6fa9a0a333b35794bd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 22:45:35 -0400 Subject: [PATCH 083/190] This is cleaner this way. --- comfy/samplers.py | 59 ++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 46bdb82a0..26597ebba 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -400,38 +400,6 @@ def encode_adm(noise_augmentor, conds, batch_size, device): return conds -def calculate_sigmas(model, steps, scheduler, sampler): - """ - Returns a tensor containing the sigmas corresponding to the given model, number of steps, scheduler type and sample technique - """ - if not (isinstance(model, CompVisVDenoiser) or isinstance(model, k_diffusion_external.CompVisDenoiser)): - model = CFGNoisePredictor(model) - if model.inner_model.parameterization == "v": - model = CompVisVDenoiser(model, quantize=True) - else: - model = k_diffusion_external.CompVisDenoiser(model, quantize=True) - - sigmas = None - - discard_penultimate_sigma = False - if sampler in ['dpm_2', 'dpm_2_ancestral']: - steps += 1 - discard_penultimate_sigma = True - - if scheduler == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.sigma_min), sigma_max=float(model.sigma_max)) - elif scheduler == "normal": - sigmas = model.get_sigmas(steps) - elif scheduler == "simple": - sigmas = simple_scheduler(model, steps) - elif scheduler == "ddim_uniform": - sigmas = ddim_scheduler(model, steps) - else: - print("error invalid scheduler", scheduler) - - if discard_penultimate_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - return sigmas class KSampler: SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] @@ -461,13 +429,36 @@ class KSampler: self.denoise = denoise self.model_options = model_options + def calculate_sigmas(self, steps): + sigmas = None + + discard_penultimate_sigma = False + if self.sampler in ['dpm_2', 'dpm_2_ancestral']: + steps += 1 + discard_penultimate_sigma = True + + if self.scheduler == "karras": + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) + elif self.scheduler == "normal": + sigmas = self.model_wrap.get_sigmas(steps) + elif self.scheduler == "simple": + sigmas = simple_scheduler(self.model_wrap, steps) + elif self.scheduler == "ddim_uniform": + sigmas = ddim_scheduler(self.model_wrap, steps) + else: + print("error invalid scheduler", self.scheduler) + + if discard_penultimate_sigma: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + return sigmas + def set_steps(self, steps, denoise=None): self.steps = steps if denoise is None or denoise > 0.9999: - self.sigmas = calculate_sigmas(self.model_wrap, steps, self.scheduler, self.sampler).to(self.device) + self.sigmas = self.calculate_sigmas(steps).to(self.device) else: new_steps = int(steps/denoise) - sigmas = calculate_sigmas(self.model_wrap, new_steps, self.scheduler, self.sampler).to(self.device) + sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] From c50208a703c6eba2363b08c4cb62e903a3012710 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 23:25:51 -0400 Subject: [PATCH 084/190] Refactor more code to sample.py --- comfy/sample.py | 47 ++++++++++++++++++++++++++++++++++++----------- nodes.py | 28 ++++------------------------ 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 09ab20cd2..d6848f9d5 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,5 +1,6 @@ import torch import comfy.model_management +import comfy.samplers def prepare_noise(latent_image, seed, skip=0): @@ -13,24 +14,22 @@ def prepare_noise(latent_image, seed, skip=0): noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") return noise -def prepare_mask(noise_mask, noise): +def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" - device = comfy.model_management.get_torch_device() - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(shape[2], shape[3]), mode="bilinear") noise_mask = noise_mask.round() - noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = torch.cat([noise_mask] * shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * shape[0]) noise_mask = noise_mask.to(device) return noise_mask -def broadcast_cond(cond, noise): - """broadcasts conditioning to the noise batch size""" - device = comfy.model_management.get_torch_device() +def broadcast_cond(cond, batch, device): + """broadcasts conditioning to the batch size""" copy = [] for p in cond: t = p[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) + if t.shape[0] < batch: + t = torch.cat([t] * batch) t = t.to(device) copy += [[t] + p[1:]] return copy @@ -54,4 +53,30 @@ def load_additional_models(positive, negative): def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: - m.cleanup() \ No newline at end of file + m.cleanup() + +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None): + device = comfy.model_management.get_torch_device() + + if noise_mask is not None: + noise_mask = prepare_mask(noise_mask, noise.shape, device) + + real_model = None + comfy.model_management.load_model_gpu(model) + real_model = model.model + + noise = noise.to(device) + latent_image = latent_image.to(device) + + positive_copy = broadcast_cond(positive, noise.shape[0], device) + negative_copy = broadcast_cond(negative, noise.shape[0], device) + + models = load_additional_models(positive, negative) + + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) + + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas) + samples = samples.cpu() + + cleanup_additional_models(models) + return samples diff --git a/nodes.py b/nodes.py index f787fcf8a..0083f6ef8 100644 --- a/nodes.py +++ b/nodes.py @@ -752,31 +752,11 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = None if "noise_mask" in latent: - noise_mask = comfy.sample.prepare_mask(latent["noise_mask"], noise) - - real_model = None - comfy.model_management.load_model_gpu(model) - real_model = model.model - - noise = noise.to(device) - latent_image = latent_image.to(device) - - positive_copy = comfy.sample.broadcast_cond(positive, noise) - negative_copy = comfy.sample.broadcast_cond(negative, noise) - - models = comfy.sample.load_additional_models(positive, negative) - - if sampler_name in comfy.samplers.KSampler.SAMPLERS: - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - else: - #other samplers - pass - - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) - samples = samples.cpu() - - comfy.sample.cleanup_additional_models(models) + noise_mask = latent["noise_mask"] + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask) out = latent.copy() out["samples"] = samples return (out, ) From aa57136dae83887e005ab6b0222dce4667b61bee Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Apr 2023 01:12:40 -0400 Subject: [PATCH 085/190] Some fixes to the batch masks PR. --- comfy/sample.py | 7 ++++--- nodes.py | 10 +++------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index d6848f9d5..5e4d26142 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,7 +1,7 @@ import torch import comfy.model_management import comfy.samplers - +import math def prepare_noise(latent_image, seed, skip=0): """ @@ -16,10 +16,11 @@ def prepare_noise(latent_image, seed, skip=0): def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(shape[2], shape[3]), mode="bilinear") + noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * shape[0]) + if noise_mask.shape[0] < shape[0]: + noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]] noise_mask = noise_mask.to(device) return noise_mask diff --git a/nodes.py b/nodes.py index b0b61d676..0a9513bed 100644 --- a/nodes.py +++ b/nodes.py @@ -172,16 +172,12 @@ class VAEEncodeForInpaint: def encode(self, vae, pixels, mask): x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 - if len(mask.shape) < 3: - mask = mask.unsqueeze(0).unsqueeze(0) - elif len(mask.shape) < 4: - mask = mask.unsqueeze(1) - mask = torch.nn.functional.interpolate(mask, size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: pixels = pixels[:,:x,:y,:] - mask = mask[:,:x,:y,:] + mask = mask[:,:,:x,:y] #grow mask by a few pixels to keep things seamless in latent space kernel_tensor = torch.ones((1, 1, 6, 6)) @@ -193,7 +189,7 @@ class VAEEncodeForInpaint: pixels[:,:,:,i] += 0.5 t = vae.encode(pixels) - return ({"samples":t, "noise_mask": (mask_erosion[:,:x,:y,:].round())}, ) + return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) class CheckpointLoader: @classmethod From 07194297fd41729f8b95352a710b9039ca2c99e8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Apr 2023 14:02:17 -0400 Subject: [PATCH 086/190] Python 3.7 support. --- comfy_extras/chainner_models/architecture/block.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 1abe1ed8f..214642cc4 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -4,7 +4,10 @@ from __future__ import annotations from collections import OrderedDict -from typing import Literal +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal import torch import torch.nn as nn From ee3a12d283d76212f6771a9cace21d4a469c1ee8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Apr 2023 19:18:50 -0400 Subject: [PATCH 087/190] Update litegraph from upstream. --- web/lib/litegraph.core.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 4189a48c0..20ec35476 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -9953,11 +9953,11 @@ LGraphNode.prototype.executeAction = function(action) } break; case "slider": - var range = w.options.max - w.options.min; + var old_value = w.value; var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1); if(w.options.read_only) break; w.value = w.options.min + (w.options.max - w.options.min) * nvalue; - if (w.callback) { + if (old_value != w.value) { setTimeout(function() { inner_value_change(w, w.value); }, 20); @@ -10044,7 +10044,7 @@ LGraphNode.prototype.executeAction = function(action) if (event.click_time < 200 && delta == 0) { this.prompt("Value",w.value,function(v) { // check if v is a valid equation or a number - if (/^[0-9+\-*/()\s]+$/.test(v)) { + if (/^[0-9+\-*/()\s]+|\d+\.\d+$/.test(v)) { try {//solve the equation if possible v = eval(v); } catch (e) { } From 54251ad85e484d4e36df849dcd529837c775d690 Mon Sep 17 00:00:00 2001 From: Jake D <122334950+jwd-dev@users.noreply.github.com> Date: Wed, 26 Apr 2023 01:22:36 -0400 Subject: [PATCH 088/190] Colored MultilineWidget (#524) * fixes colors and z-index * light mode fix * Update widgets.js --- web/scripts/widgets.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 238ad59dd..c0e73ffa1 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -136,9 +136,11 @@ function addMultilineWidget(node, name, opts, app) { left: `${t.a * margin + t.e}px`, top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, + background: (!node.color)?'':node.color, height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, position: "absolute", - zIndex: 1, + color: (!node.color)?'':'white', + zIndex: app.graph._nodes.indexOf(node), fontSize: `${t.d * 10.0}px`, }); this.inputEl.hidden = !visible; From 951c0c2bbe11e48956a7c619faf0c2cc6e3abff5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Apr 2023 02:05:57 -0400 Subject: [PATCH 089/190] Don't keep cached outputs for removed nodes. --- execution.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/execution.py b/execution.py index 31a208e78..2c97e70d2 100644 --- a/execution.py +++ b/execution.py @@ -152,6 +152,15 @@ class PromptExecutor: self.server.client_id = None with torch.inference_mode(): + #delete cached outputs if nodes don't exist for them + to_delete = [] + for o in self.outputs: + if o not in prompt: + to_delete += [o] + for o in to_delete: + d = self.outputs.pop(o) + del d + for x in prompt: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) From 3a1f9dba20c89038b71d6ff74d4e600d375283b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Apr 2023 02:13:56 -0400 Subject: [PATCH 090/190] If IS_CHANGED returns exception delete the output instead of crashing. --- execution.py | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/execution.py b/execution.py index 2c97e70d2..c19c10bc6 100644 --- a/execution.py +++ b/execution.py @@ -97,40 +97,44 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item is_changed_old = '' is_changed = '' + to_delete = False if hasattr(class_def, 'IS_CHANGED'): if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: is_changed_old = old_prompt[unique_id]['is_changed'] if 'is_changed' not in prompt[unique_id]: input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: - is_changed = class_def.IS_CHANGED(**input_data_all) - prompt[unique_id]['is_changed'] = is_changed + try: + is_changed = class_def.IS_CHANGED(**input_data_all) + prompt[unique_id]['is_changed'] = is_changed + except: + to_delete = True else: is_changed = prompt[unique_id]['is_changed'] if unique_id not in outputs: return True - to_delete = False - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] + if not to_delete: + if is_changed != is_changed_old: + to_delete = True + elif unique_id not in old_prompt: + to_delete = True + elif inputs == old_prompt[unique_id]['inputs']: + for x in inputs: + input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + if input_unique_id in outputs: + to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) + else: + to_delete = True + if to_delete: + break + else: + to_delete = True if to_delete: d = outputs.pop(unique_id) From 5a971cecdbacb849340f2ea7b3bcd80cc6032d1a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 27 Apr 2023 04:38:44 -0400 Subject: [PATCH 091/190] Add callback to sampler function. Callback format is: callback(step, x0, x) --- comfy/extra_samplers/uni_pc.py | 6 ++++-- comfy/sample.py | 4 ++-- comfy/samplers.py | 22 ++++++++++++++++------ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index e96cfc93a..2952be62d 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -712,7 +712,7 @@ class UniPC: def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform', method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, corrector=False, + atol=0.0078, rtol=0.05, corrector=False, callback=None ): t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start @@ -766,6 +766,8 @@ class UniPC: if model_x is None: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x + if callback is not None: + callback(step_index, model_prev_list[-1], x) else: raise NotImplementedError() if denoise_to_zero: @@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex order = min(3, len(timesteps) - 1) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True) + x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) return x diff --git a/comfy/sample.py b/comfy/sample.py index 5e4d26142..f4132bbed 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -56,7 +56,7 @@ def cleanup_additional_models(models): for m in models: m.cleanup() -def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None): +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None): device = comfy.model_management.get_torch_device() if noise_mask is not None: @@ -76,7 +76,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas) + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback) samples = samples.cpu() cleanup_additional_models(models) diff --git a/comfy/samplers.py b/comfy/samplers.py index 26597ebba..fc19ddcfc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -462,7 +462,7 @@ class KSampler: self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): if sigmas is None: sigmas = self.sigmas sigma_min = self.sigma_min @@ -527,9 +527,9 @@ class KSampler: with precision_scope(model_management.get_autocast_device(self.device)): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback) elif self.sampler == "uni_pc_bh2": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2') elif self.sampler == "ddim": timesteps = [] for s in range(sigmas.shape[0]): @@ -537,6 +537,11 @@ class KSampler: noise_mask = None if denoise_mask is not None: noise_mask = 1.0 - denoise_mask + + ddim_callback = None + if callback is not None: + ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None) + sampler = DDIMSampler(self.model, device=self.device) sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise) @@ -550,6 +555,7 @@ class KSampler: eta=0.0, x_T=z_enc, x0=latent_image, + img_callback=ddim_callback, denoise_function=sampling_function, extra_args=extra_args, mask=noise_mask, @@ -563,13 +569,17 @@ class KSampler: noise = noise * sigmas[0] + k_callback = None + if callback is not None: + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"]) + if latent_image is not None: noise += latent_image if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback) elif self.sampler == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args) + samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback) else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args) + samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback) return samples.to(torch.float32) From e958dfdd4d34ad160c50a32e01b5ce08c4e62a29 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 27 Apr 2023 10:59:47 -0400 Subject: [PATCH 092/190] Make notebook work on python3.7 --- notebooks/comfyui_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c1982d8be..fecfa6707 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -47,7 +47,7 @@ " !git pull\n", "\n", "!echo -= Install dependencies =-\n", - "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118" + "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" ] }, { From e214c917ae889b278a05fa6e8b8c42d2cc8818fa Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Tue, 25 Apr 2023 00:15:25 -0700 Subject: [PATCH 093/190] Add Condition by Mask node This PR adds support for a Condition by Mask node. This node allows conditioning to be limited to a non-rectangle area. --- comfy/samplers.py | 88 +++++++++++++++++++++++++++++++++++++++-------- nodes.py | 28 +++++++++++++++ 2 files changed, 101 insertions(+), 15 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index fc19ddcfc..6fa754b90 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,7 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -23,21 +24,34 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con adm_cond = cond[1]['adm_encoded'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - mult = torch.ones_like(input_x) * strength + if 'mask' in cond[1]: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask = cond[1]['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if mask.shape[0] != input_x.shape[0]: + mask = mask.repeat(input_x.shape[0], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in cond[1]: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) conditionning = {} conditionning['c_crossattn'] = cond[0] if cond_concat_in is not None and len(cond_concat_in) > 0: @@ -301,6 +315,47 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def resolve_cond_masks(conditions, h, w, device): + # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. + # While we're doing this, we can also resolve the mask device and scaling for performance reasons + for i in range(len(conditions)): + c = conditions[i] + if 'mask' in c[1]: + mask = c[1]['mask'] + mask = mask.to(device=device) + modified = c[1].copy() + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if mask.shape[2] != h or mask.shape[3] != w: + mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) + + if 'area' not in modified: + bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) + if torch.max(bounds) == 0: + # Handle the edge-case of an all black mask (where masks_to_boxes would error) + area = (0, 0, 0, 0) + else: + box = masks_to_boxes(bounds)[0].type(torch.int) + H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) + # Make sure the height and width are divisible by 8 + if X % 8 != 0: + newx = X // 8 * 8 + W = W + (X - newx) + X = newx + if Y % 8 != 0: + newy = Y // 8 * 8 + H = H + (Y - newy) + Y = newy + if H % 8 != 0: + H = H + (8 - (H % 8)) + if W % 8 != 0: + W = W + (8 - (W % 8)) + area = (int(H), int(W), int(Y), (X)) + modified['area'] = area + + modified['mask'] = mask + conditions[i] = [c[0], modified] + def create_cond_with_same_area_if_none(conds, c): if 'area' not in c[1]: return @@ -461,7 +516,6 @@ class KSampler: sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): if sigmas is None: sigmas = self.sigmas @@ -484,6 +538,10 @@ class KSampler: positive = positive[:] negative = negative[:] + + resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) + resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) diff --git a/nodes.py b/nodes.py index 0a9513bed..be02f4676 100644 --- a/nodes.py +++ b/nodes.py @@ -85,6 +85,32 @@ class ConditioningSetArea: c.append(n) return (c, ) +class ConditioningSetMask: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "mask": ("MASK", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + c = [] + if len(mask.shape) < 3: + mask = mask.unsqueeze(0) + for t in conditioning: + n = [t[0], t[1].copy()] + _, h, w = mask.shape + n[1]['mask'] = mask + n[1]['strength'] = strength + n[1]['min_sigma'] = min_sigma + n[1]['max_sigma'] = max_sigma + c.append(n) + return (c, ) + class VAEDecode: def __init__(self, device="cpu"): self.device = device @@ -1115,6 +1141,7 @@ NODE_CLASS_MAPPINGS = { "ImagePadForOutpaint": ImagePadForOutpaint, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, + "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, "LatentComposite": LatentComposite, @@ -1164,6 +1191,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CLIPSetLastLayer": "CLIP Set Last Layer", "ConditioningCombine": "Conditioning (Combine)", "ConditioningSetArea": "Conditioning (Set Area)", + "ConditioningSetMask": "Conditioning (Set Mask)", "ControlNetApply": "Apply ControlNet", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", From 27bf9392ac1ef07776d31895b748c7ea84969115 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 08:35:20 -0400 Subject: [PATCH 094/190] Switch stable standalone dependencies to stable xformers. Switch nightly standalone to cu121. --- .github/workflows/windows_release_cu118_dependencies_2.yml | 2 +- .github/workflows/windows_release_nightly_pytorch.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml index a88449527..42adee9e7 100644 --- a/.github/workflows/windows_release_cu118_dependencies_2.yml +++ b/.github/workflows/windows_release_cu118_dependencies_2.yml @@ -17,7 +17,7 @@ jobs: - shell: bash run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir python -m pip install --no-cache-dir ./temp_wheel_dir/* echo installed basic ls -lah temp_wheel_dir diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 291d754e3..32d2f320b 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -30,7 +30,7 @@ jobs: echo 'import site' >> ./python310._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python310._pth From e543ecad6991fc7e71dd2042b439aefb9c0722de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 08:50:12 -0400 Subject: [PATCH 095/190] Fix the nightly build not being packaged correctly. --- .ci/nightly/update_windows/update.py | 65 ------------------- .ci/nightly/update_windows/update_comfyui.bat | 2 - ...update_comfyui_and_python_dependencies.bat | 2 +- .../README_VERY_IMPORTANT.txt | 27 -------- .ci/nightly/windows_base_files/run_cpu.bat | 2 - .../windows_release_nightly_pytorch.yml | 2 + 6 files changed, 3 insertions(+), 97 deletions(-) delete mode 100755 .ci/nightly/update_windows/update.py delete mode 100755 .ci/nightly/update_windows/update_comfyui.bat delete mode 100755 .ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt delete mode 100755 .ci/nightly/windows_base_files/run_cpu.bat diff --git a/.ci/nightly/update_windows/update.py b/.ci/nightly/update_windows/update.py deleted file mode 100755 index c09f29a80..000000000 --- a/.ci/nightly/update_windows/update.py +++ /dev/null @@ -1,65 +0,0 @@ -import pygit2 -from datetime import datetime -import sys - -def pull(repo, remote_name='origin', branch='master'): - for remote in repo.remotes: - if remote.name == remote_name: - remote.fetch() - remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target - merge_result, _ = repo.merge_analysis(remote_master_id) - # Up to date, do nothing - if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE: - return - # We can just fastforward - elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD: - repo.checkout_tree(repo.get(remote_master_id)) - try: - master_ref = repo.lookup_reference('refs/heads/%s' % (branch)) - master_ref.set_target(remote_master_id) - except KeyError: - repo.create_branch(branch, repo.get(remote_master_id)) - repo.head.set_target(remote_master_id) - elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL: - repo.merge(remote_master_id) - - if repo.index.conflicts is not None: - for conflict in repo.index.conflicts: - print('Conflicts found in:', conflict[0].path) - raise AssertionError('Conflicts, ahhhhh!!') - - user = repo.default_signature - tree = repo.index.write_tree() - commit = repo.create_commit('HEAD', - user, - user, - 'Merge!', - tree, - [repo.head.target, remote_master_id]) - # We need to do this or git CLI will think we are still merging. - repo.state_cleanup() - else: - raise AssertionError('Unknown merge analysis result') - - -repo = pygit2.Repository(str(sys.argv[1])) -ident = pygit2.Signature('comfyui', 'comfy@ui') -try: - print("stashing current changes") - repo.stash(ident) -except KeyError: - print("nothing to stash") -backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) -print("creating backup branch: {}".format(backup_branch_name)) -repo.branches.local.create(backup_branch_name, repo.head.peel()) - -print("checking out master branch") -branch = repo.lookup_branch('master') -ref = repo.lookup_reference(branch.name) -repo.checkout(ref) - -print("pulling latest changes") -pull(repo) - -print("Done!") - diff --git a/.ci/nightly/update_windows/update_comfyui.bat b/.ci/nightly/update_windows/update_comfyui.bat deleted file mode 100755 index 60d1e694f..000000000 --- a/.ci/nightly/update_windows/update_comfyui.bat +++ /dev/null @@ -1,2 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -pause diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index c5e0c6be7..c345a6992 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt deleted file mode 100755 index 656b9db43..000000000 --- a/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt +++ /dev/null @@ -1,27 +0,0 @@ -HOW TO RUN: - -if you have a NVIDIA gpu: - -run_nvidia_gpu.bat - - - -To run it in slow CPU mode: - -run_cpu.bat - - - -IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints - -You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt - - - -RECOMMENDED WAY TO UPDATE: -To update the ComfyUI code: update\update_comfyui.bat - - - -To update ComfyUI with the python dependencies: -update\update_comfyui_and_python_dependencies.bat diff --git a/.ci/nightly/windows_base_files/run_cpu.bat b/.ci/nightly/windows_base_files/run_cpu.bat deleted file mode 100755 index c3ba41721..000000000 --- a/.ci/nightly/windows_base_files/run_cpu.bat +++ /dev/null @@ -1,2 +0,0 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build -pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 32d2f320b..4d686ded8 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -46,6 +46,8 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/nightly/update_windows/* ./update/ + cp -r ComfyUI/.ci/nightly/windows_base_files/* ./ cd .. From ab9a9deff48b5780bd105dfd6d19f5f8333ef608 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 09:03:39 -0400 Subject: [PATCH 096/190] Fix nightly CI builds. No cu121 builds for windows yet. --- .../update_windows/update_comfyui_and_python_dependencies.bat | 2 +- .github/workflows/windows_release_nightly_pytorch.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index c345a6992..b4989534f 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 4d686ded8..f23cae6d5 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -30,7 +30,7 @@ jobs: echo 'import site' >> ./python310._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python310._pth From 3baded9892a6ac02f57caaf68053791ec0e14c5a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 14:28:57 -0400 Subject: [PATCH 097/190] Basic torch_directml support. Use --directml to use it. --- comfy/cli_args.py | 1 + comfy/model_management.py | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b24054ce0..05b9c5e08 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") +parser.add_argument("--directml", action="store_true", help="Use torch-directml.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 6e3a03530..339111c4d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -20,6 +20,13 @@ total_vram_available_mb = -1 accelerate_enabled = False xpu_available = False +directml_enabled = False +if args.directml: + import torch_directml + print("Using directml") + directml_enabled = True + # torch_directml.disable_tiled_resources(True) + try: import torch try: @@ -217,6 +224,9 @@ def unload_if_low_vram(model): def get_torch_device(): global xpu_available + global directml_enabled + if directml_enabled: + return torch_directml.device() if vram_state == VRAMState.MPS: return torch.device("mps") if vram_state == VRAMState.CPU: @@ -234,8 +244,14 @@ def get_autocast_device(dev): def xformers_enabled(): + global xpu_available + global directml_enabled if vram_state == VRAMState.CPU: return False + if xpu_available: + return False + if directml_enabled: + return False return XFORMERS_IS_AVAILABLE @@ -251,6 +267,7 @@ def pytorch_attention_enabled(): def get_free_memory(dev=None, torch_free_too=False): global xpu_available + global directml_enabled if dev is None: dev = get_torch_device() @@ -258,7 +275,10 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: - if xpu_available: + if directml_enabled: + mem_free_total = 1024 * 1024 * 1024 #TODO + mem_free_torch = mem_free_total + elif xpu_available: mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) mem_free_torch = mem_free_total else: @@ -293,9 +313,14 @@ def mps_mode(): def should_use_fp16(): global xpu_available + global directml_enabled + if FORCE_FP32: return False + if directml_enabled: + return False + if cpu_mode() or mps_mode() or xpu_available: return False #TODO ? From 0306371e54ddb7472622eb43ed2180a109be6e6b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:18:54 -0400 Subject: [PATCH 098/190] Add "Installing" link to top of readme. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 5b6346a67..00b228497 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend. This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) +### [Installing](#installing) + ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Fully supports SD1.x and SD2.x From cab80973d187903d9c415cfcc2575e4616befaa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:19:56 -0400 Subject: [PATCH 099/190] Fix Readme. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 00b228497..3b3824714 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A powerful and modular stable diffusion GUI and backend. This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) -### [Installing](#installing) +### [Installing ComfyUI](#installing) ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. From 2ca934f7d4df3e4fa5a74172e5bbc1dd5e1a2ff9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:51:35 -0400 Subject: [PATCH 100/190] You can now select the device index with: --directml id Like this for example: --directml 1 --- comfy/cli_args.py | 2 +- comfy/model_management.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 05b9c5e08..764427165 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -10,7 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") -parser.add_argument("--directml", action="store_true", help="Use torch-directml.") +parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 339111c4d..9497ae7af 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -21,10 +21,15 @@ accelerate_enabled = False xpu_available = False directml_enabled = False -if args.directml: +if args.directml is not None: import torch_directml - print("Using directml") directml_enabled = True + device_index = args.directml + if device_index < 0: + directml_device = torch_directml.device() + else: + directml_device = torch_directml.device(device_index) + print("Using directml with device:", torch_directml.device_name(device_index)) # torch_directml.disable_tiled_resources(True) try: @@ -226,7 +231,8 @@ def get_torch_device(): global xpu_available global directml_enabled if directml_enabled: - return torch_directml.device() + global directml_device + return directml_device if vram_state == VRAMState.MPS: return torch.device("mps") if vram_state == VRAMState.CPU: From 056e5545ffafc7c396cd18d0737a9d5e40f81552 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 00:28:48 -0400 Subject: [PATCH 101/190] Don't try to get vram from xpu or cuda when directml is enabled. --- comfy/model_management.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9497ae7af..db5d368e1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -34,13 +34,16 @@ if args.directml is not None: try: import torch - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) + if directml_enabled: + total_vram = 4097 #TODO + else: + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True + total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) + except: + total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) if not args.normalvram and not args.cpu: if total_vram <= 4096: From af02393c2a7134861df57e5843fc17498c65a795 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 29 Apr 2023 00:16:58 -0700 Subject: [PATCH 102/190] Default to sampling entire image By default, when applying a mask to a condition, the entire image will still be used for sampling. The new "set_area_to_bounds" option on the node will allow the user to automatically limit conditioning to the bounds of the mask. I've also removed the dependency on torchvision for calculating bounding boxes. I've taken the opportunity to fix some frustrating details in the other version: 1. An all-0 mask will no longer cause an error 2. Indices are returned as integers instead of floats so they can be used to index into tensors. --- comfy/samplers.py | 42 ++++++++++++++++++++++++++++++++---------- nodes.py | 4 +++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6fa754b90..f8701c879 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,7 +6,6 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps -from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -31,8 +30,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if mask.shape[0] != input_x.shape[0]: - mask = mask.repeat(input_x.shape[0], 1, 1) + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) mult = mask * strength @@ -315,6 +313,29 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def get_mask_aabb(masks): + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device, dtype=torch.int) + + b = masks.shape[0] + + bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int) + is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool) + for i in range(b): + mask = masks[i] + if mask.numel() == 0: + continue + if torch.max(mask != 0) == False: + is_empty[i] = True + continue + y, x = torch.where(mask) + bounding_boxes[i, 0] = torch.min(x) + bounding_boxes[i, 1] = torch.min(y) + bounding_boxes[i, 2] = torch.max(x) + bounding_boxes[i, 3] = torch.max(y) + + return bounding_boxes, is_empty + def resolve_cond_masks(conditions, h, w, device): # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. # While we're doing this, we can also resolve the mask device and scaling for performance reasons @@ -329,13 +350,14 @@ def resolve_cond_masks(conditions, h, w, device): if mask.shape[2] != h or mask.shape[3] != w: mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) - if 'area' not in modified: + if modified.get("set_area_to_bounds", False): bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) - if torch.max(bounds) == 0: - # Handle the edge-case of an all black mask (where masks_to_boxes would error) - area = (0, 0, 0, 0) + boxes, is_empty = get_mask_aabb(bounds) + if is_empty[0]: + # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway) + modified['area'] = (8, 8, 0, 0) else: - box = masks_to_boxes(bounds)[0].type(torch.int) + box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) # Make sure the height and width are divisible by 8 if X % 8 != 0: @@ -350,8 +372,8 @@ def resolve_cond_masks(conditions, h, w, device): H = H + (8 - (H % 8)) if W % 8 != 0: W = W + (8 - (W % 8)) - area = (int(H), int(W), int(Y), (X)) - modified['area'] = area + area = (int(H), int(W), int(Y), int(X)) + modified['area'] = area modified['mask'] = mask conditions[i] = [c[0], modified] diff --git a/nodes.py b/nodes.py index be02f4676..12fa7e5a3 100644 --- a/nodes.py +++ b/nodes.py @@ -90,6 +90,7 @@ class ConditioningSetMask: def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), "mask": ("MASK", ), + "set_area_to_bounds": ([False, True],), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) @@ -97,7 +98,7 @@ class ConditioningSetMask: CATEGORY = "conditioning" - def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0): c = [] if len(mask.shape) < 3: mask = mask.unsqueeze(0) @@ -105,6 +106,7 @@ class ConditioningSetMask: n = [t[0], t[1].copy()] _, h, w = mask.shape n[1]['mask'] = mask + n[1]['set_area_to_bounds'] = set_area_to_bounds n[1]['strength'] = strength n[1]['min_sigma'] = min_sigma n[1]['max_sigma'] = max_sigma From ffd0f9f417d94bce03ea863131df9e6a86a89ada Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 29 Apr 2023 17:19:14 +0100 Subject: [PATCH 103/190] Search filter by type --- web/extensions/core/slotDefaults.js | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 3ec605900..9401678b0 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -6,6 +6,7 @@ app.registerExtension({ name: "Comfy.SlotDefaults", suggestionsNumber: null, init() { + LiteGraph.search_filter_enabled = true; LiteGraph.middle_click_slot_add_default_node = true; this.suggestionsNumber = app.ui.settings.addSetting({ id: "Comfy.NodeSuggestions.number", @@ -43,6 +44,14 @@ app.registerExtension({ } if (this.slot_types_default_out[type].includes(nodeId)) continue; this.slot_types_default_out[type].push(nodeId); + + // Input types have to be stored as lower case + // Store each node that can handle this input type + const lowerType = type.toLocaleLowerCase(); + if (!(lowerType in LiteGraph.registered_slot_in_types)) { + LiteGraph.registered_slot_in_types[lowerType] = { nodes: [] }; + } + LiteGraph.registered_slot_in_types[lowerType].nodes.push(nodeType.comfyClass); } var outputs = nodeData["output"]; @@ -53,6 +62,16 @@ app.registerExtension({ } this.slot_types_default_in[type].push(nodeId); + + // Store each node that can handle this output type + if (!(type in LiteGraph.registered_slot_out_types)) { + LiteGraph.registered_slot_out_types[type] = { nodes: [] }; + } + LiteGraph.registered_slot_out_types[type].nodes.push(nodeType.comfyClass); + + if(!LiteGraph.slot_types_out.includes(type)) { + LiteGraph.slot_types_out.push(type); + } } var maxNum = this.suggestionsNumber.value; this.setDefaults(maxNum); From 15a4c0db3b11c75350268950d8d0da175e72440d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 29 Apr 2023 17:29:07 +0100 Subject: [PATCH 104/190] - button hover style - ensure context menu is always above everything --- web/style.css | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/web/style.css b/web/style.css index 2cbf02c0c..eced33d29 100644 --- a/web/style.css +++ b/web/style.css @@ -120,7 +120,7 @@ body { .comfy-menu > button, .comfy-menu-btns button, .comfy-menu .comfy-list button, -.comfy-modal button{ +.comfy-modal button { color: var(--input-text); background-color: var(--comfy-input-bg); border-radius: 8px; @@ -129,6 +129,15 @@ body { margin-top: 2px; } +.comfy-menu > button:hover, +.comfy-menu-btns button:hover, +.comfy-menu .comfy-list button:hover, +.comfy-modal button:hover, +.comfy-settings-btn:hover { + filter: brightness(1.2); + cursor: pointer; +} + .comfy-menu span.drag-handle { width: 10px; height: 20px; @@ -284,4 +293,7 @@ button.comfy-queue-btn { top: 0; right: 2px; } - \ No newline at end of file + + .litecontextmenu { + z-index: 9999 !important; +} \ No newline at end of file From 071011aebed2b636865dacacf6213d6714d6d80c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:06:53 -0400 Subject: [PATCH 105/190] Mask strength should be separate from area strength. --- comfy/samplers.py | 5 ++++- nodes.py | 6 ++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f8701c879..10527fb1c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -26,10 +26,13 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if 'mask' in cond[1]: # Scale the mask to the size of the input # The mask should have been resized as we began the sampling process + mask_strength = 1.0 + if "mask_strength" in cond[1]: + mask_strength = cond[1]["mask_strength"] mask = cond[1]['mask'] assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) diff --git a/nodes.py b/nodes.py index 12fa7e5a3..b4069c836 100644 --- a/nodes.py +++ b/nodes.py @@ -98,7 +98,7 @@ class ConditioningSetMask: CATEGORY = "conditioning" - def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, mask, set_area_to_bounds, strength): c = [] if len(mask.shape) < 3: mask = mask.unsqueeze(0) @@ -107,9 +107,7 @@ class ConditioningSetMask: _, h, w = mask.shape n[1]['mask'] = mask n[1]['set_area_to_bounds'] = set_area_to_bounds - n[1]['strength'] = strength - n[1]['min_sigma'] = min_sigma - n[1]['max_sigma'] = max_sigma + n[1]['mask_strength'] = strength c.append(n) return (c, ) From c66db067630c57ec037b906b6b3f766d1153522b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:19:14 -0400 Subject: [PATCH 106/190] Make ConditioningSetMask area option a bit more clear. Make ConditioningSetArea override the set_area_to_bounds. --- nodes.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index b4069c836..c9d660738 100644 --- a/nodes.py +++ b/nodes.py @@ -80,6 +80,7 @@ class ConditioningSetArea: n = [t[0], t[1].copy()] n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) n[1]['strength'] = strength + n[1]['set_area_to_bounds'] = False n[1]['min_sigma'] = min_sigma n[1]['max_sigma'] = max_sigma c.append(n) @@ -90,16 +91,19 @@ class ConditioningSetMask: def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), "mask": ("MASK", ), - "set_area_to_bounds": ([False, True],), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" CATEGORY = "conditioning" - def append(self, conditioning, mask, set_area_to_bounds, strength): + def append(self, conditioning, mask, set_cond_area, strength): c = [] + set_area_to_bounds = False + if set_cond_area != "default": + set_area_to_bounds = True if len(mask.shape) < 3: mask = mask.unsqueeze(0) for t in conditioning: From 4cea9aecdab6bbd7b5801c64c27368ee3203a9ad Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:53:03 -0400 Subject: [PATCH 107/190] Make nodes easier to resize. --- web/lib/litegraph.core.js | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 20ec35476..d471c0f50 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,10 +5880,10 @@ LGraphNode.prototype.executeAction = function(action) node.resizable !== false && isInsideRectangle( e.canvasX, e.canvasY, - node.pos[0] + node.size[0] - 5, - node.pos[1] + node.size[1] - 5, - 10, - 10 + node.pos[0] + node.size[0] - 15, + node.pos[1] + node.size[1] - 15, + 20, + 20 ) ) { this.graph.beforeChange(); @@ -6428,10 +6428,10 @@ LGraphNode.prototype.executeAction = function(action) isInsideRectangle( e.canvasX, e.canvasY, - node.pos[0] + node.size[0] - 5, - node.pos[1] + node.size[1] - 5, - 5, - 5 + node.pos[0] + node.size[0] - 15, + node.pos[1] + node.size[1] - 15, + 15, + 15 ) ) { this.canvas.style.cursor = "se-resize"; From a2e18b15046456c86b0d550d515c737f976d03d6 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sun, 30 Apr 2023 18:59:58 +0200 Subject: [PATCH 108/190] allow disabling of progress bar when sampling --- comfy/samplers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 10527fb1c..1b486f803 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -541,7 +541,7 @@ class KSampler: sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False): if sigmas is None: sigmas = self.sigmas sigma_min = self.sigma_min @@ -610,9 +610,9 @@ class KSampler: with precision_scope(model_management.get_autocast_device(self.device)): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) elif self.sampler == "uni_pc_bh2": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2') + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) elif self.sampler == "ddim": timesteps = [] for s in range(sigmas.shape[0]): @@ -659,10 +659,10 @@ class KSampler: if latent_image is not None: noise += latent_image if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) elif self.sampler == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback) + samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback) + samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar) return samples.to(torch.float32) From 20123624933cd559dc903f0b7c97566113018a1b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Apr 2023 13:02:07 -0400 Subject: [PATCH 109/190] Adjust node resize area depending on outputs. --- web/lib/litegraph.core.js | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index d471c0f50..2bc6af0c3 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -3628,6 +3628,18 @@ return size; }; + LGraphNode.prototype.inResizeCorner = function(canvasX, canvasY) { + var rows = this.outputs ? this.outputs.length : 1; + var outputs_offset = (this.constructor.slot_start_y || 0) + rows * LiteGraph.NODE_SLOT_HEIGHT; + return isInsideRectangle(canvasX, + canvasY, + this.pos[0] + this.size[0] - 15, + this.pos[1] + Math.max(this.size[1] - 15, outputs_offset), + 20, + 20 + ); + } + /** * returns all the info available about a property of this node. * @@ -5877,14 +5889,7 @@ LGraphNode.prototype.executeAction = function(action) if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && - node.resizable !== false && - isInsideRectangle( e.canvasX, - e.canvasY, - node.pos[0] + node.size[0] - 15, - node.pos[1] + node.size[1] - 15, - 20, - 20 - ) + node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) ) { this.graph.beforeChange(); this.resizing_node = node; @@ -6424,16 +6429,7 @@ LGraphNode.prototype.executeAction = function(action) //Search for corner if (this.canvas) { - if ( - isInsideRectangle( - e.canvasX, - e.canvasY, - node.pos[0] + node.size[0] - 15, - node.pos[1] + node.size[1] - 15, - 15, - 15 - ) - ) { + if (node.inResizeCorner(e.canvasX, e.canvasY)) { this.canvas.style.cursor = "se-resize"; } else { this.canvas.style.cursor = "crosshair"; From 29c8f1a3442aad7d615430f8484b85de995c141c Mon Sep 17 00:00:00 2001 From: FizzleDorf <1fizzledorf@gmail.com> Date: Sun, 30 Apr 2023 17:33:15 -0400 Subject: [PATCH 110/190] Conditioning Average (#495) * first commit * fixed a bunch of things missing in initial commit. * parameters renamed for clarity * renamed node, attempted update cond list * to_strength removed, it is now normalized * removed comments and prints. Attempted to apply to every cond in list again but no luck * fixed repeating frames after batch using deepcopy * Revert "fixed repeating frames after batch using deepcopy" This reverts commit 1086d6a0e1f5c5c9247312872402ff8e60358fe1. * Rewrite addWeighted to use torch.mul iteratively. --------- Co-authored-by: City <125218114+city96@users.noreply.github.com> --- nodes.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/nodes.py b/nodes.py index c9d660738..fc3d2f183 100644 --- a/nodes.py +++ b/nodes.py @@ -59,6 +59,27 @@ class ConditioningCombine: def combine(self, conditioning_1, conditioning_2): return (conditioning_1 + conditioning_2, ) +class ConditioningAverage : + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning_from": ("CONDITIONING", ), "conditioning_to": ("CONDITIONING", ), + "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "addWeighted" + + CATEGORY = "conditioning" + + def addWeighted(self, conditioning_from, conditioning_to, conditioning_from_strength): + out = [] + for i in range(min(len(conditioning_from),len(conditioning_to))): + t0 = conditioning_from[i] + t1 = conditioning_to[i] + tw = torch.mul(t0[0],(1-conditioning_from_strength)) + torch.mul(t1[0],conditioning_from_strength) + n = [tw, t0[1].copy()] + out.append(n) + return (out, ) + class ConditioningSetArea: @classmethod def INPUT_TYPES(s): @@ -1143,6 +1164,7 @@ NODE_CLASS_MAPPINGS = { "ImageScale": ImageScale, "ImageInvert": ImageInvert, "ImagePadForOutpaint": ImagePadForOutpaint, + "ConditioningAverage ": ConditioningAverage , "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetMask": ConditioningSetMask, @@ -1194,6 +1216,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CLIPTextEncode": "CLIP Text Encode (Prompt)", "CLIPSetLastLayer": "CLIP Set Last Layer", "ConditioningCombine": "Conditioning (Combine)", + "ConditioningAverage ": "Conditioning (Average)", "ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetMask": "Conditioning (Set Mask)", "ControlNetApply": "Apply ControlNet", From 0aa667ed33aae800880153a91c283ac457d0b31c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Apr 2023 17:28:55 -0400 Subject: [PATCH 111/190] Fix ConditioningAverage. --- nodes.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/nodes.py b/nodes.py index fc3d2f183..53e0f74bf 100644 --- a/nodes.py +++ b/nodes.py @@ -62,21 +62,30 @@ class ConditioningCombine: class ConditioningAverage : @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning_from": ("CONDITIONING", ), "conditioning_to": ("CONDITIONING", ), - "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}) + return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ), + "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "addWeighted" CATEGORY = "conditioning" - def addWeighted(self, conditioning_from, conditioning_to, conditioning_from_strength): + def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength): out = [] - for i in range(min(len(conditioning_from),len(conditioning_to))): - t0 = conditioning_from[i] - t1 = conditioning_to[i] - tw = torch.mul(t0[0],(1-conditioning_from_strength)) + torch.mul(t1[0],conditioning_from_strength) - n = [tw, t0[1].copy()] + + if len(conditioning_from) > 1: + print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") + + cond_from = conditioning_from[0][0] + + for i in range(len(conditioning_to)): + t1 = conditioning_to[i][0] + t0 = cond_from[:,:t1.shape[1]] + if t0.shape[1] < t1.shape[1]: + t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1) + + tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength)) + n = [tw, conditioning_to[i][1].copy()] out.append(n) return (out, ) From b04e16ef5a7cd9cbf80d272a455bd34e869a6ec8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Apr 2023 18:19:03 -0400 Subject: [PATCH 112/190] Make default workflow use an existing checkpoint if no SD1.5 checkpoint. --- web/scripts/app.js | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index a161bf40e..ada1708dc 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -971,8 +971,10 @@ export class ComfyApp { loadGraphData(graphData) { this.clean(); + let reset_invalid_values = false; if (!graphData) { graphData = structuredClone(defaultGraph); + reset_invalid_values = true; } const missingNodeTypes = []; @@ -1058,6 +1060,13 @@ export class ComfyApp { } } } + if (reset_invalid_values) { + if (widget.type == "combo") { + if (!widget.options.values.includes(widget.value) && widget.options.values.length > 0) { + widget.value = widget.options.values[0]; + } + } + } } } From 6aae1f497f680355b0e51242c4195cf75803056d Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 1 May 2023 13:16:19 -0400 Subject: [PATCH 113/190] style context menu fix graphdialog background, and palette template --- web/extensions/core/colorPalette.js | 17 +++++++++++++++ web/style.css | 34 ++++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 41541a8d8..2f2238a2b 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -232,10 +232,27 @@ app.registerExtension({ "name": "My Color Palette", "colors": { "node_slot": { + }, + "litegraph_base": { + }, + "comfy_base": { } } }; + // Copy over missing keys from default color palette + const defaultColorPalette = colorPalettes[defaultColorPaletteId]; + for (const key in defaultColorPalette.colors.litegraph_base) { + if (!colorPalette.colors.litegraph_base[key]) { + colorPalette.colors.litegraph_base[key] = ""; + } + } + for (const key in defaultColorPalette.colors.comfy_base) { + if (!colorPalette.colors.comfy_base[key]) { + colorPalette.colors.comfy_base[key] = ""; + } + } + return completeColorPalette(colorPalette); }; diff --git a/web/style.css b/web/style.css index eced33d29..6ef3a4c21 100644 --- a/web/style.css +++ b/web/style.css @@ -257,8 +257,11 @@ button.comfy-queue-btn { } } +/* Input popup */ + .graphdialog { min-height: 1em; + background-color: var(--comfy-menu-bg); } .graphdialog .name { @@ -282,18 +285,37 @@ button.comfy-queue-btn { border-radius: 12px 0 0 12px; } +/* Context menu */ + .litegraph .litemenu-entry.has_submenu { position: relative; padding-right: 20px; - } +} - .litemenu-entry.has_submenu::after { +.litemenu-entry.has_submenu::after { content: ">"; position: absolute; top: 0; right: 2px; - } - - .litecontextmenu { +} + +.litecontextmenu { z-index: 9999 !important; -} \ No newline at end of file +} + +.litegraph.litecontextmenu { + background-color: var(--comfy-menu-bg) !important; + filter: brightness(95%); + color: var(--input-text) !important; +} + +.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) { + background-color: var(--comfy-menu-bg) !important; + filter: brightness(155%); + color: var(--input-text) !important; +} + +.litegraph.litecontextmenu .litemenu-entry.submenu { + background-color: var(--comfy-menu-bg) !important; + color: var(--input-text) !important; +} From d3293c833947928456cd69a67c5e7d602216f997 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 May 2023 15:47:10 -0400 Subject: [PATCH 114/190] Properly disable all progress bars when disable_pbar=True --- comfy/extra_samplers/uni_pc.py | 8 ++++---- comfy/ldm/models/diffusion/ddim.py | 8 +++++--- comfy/sample.py | 4 ++-- comfy/samplers.py | 3 ++- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 2952be62d..78bab5936 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -712,7 +712,7 @@ class UniPC: def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform', method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, corrector=False, callback=None + atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False ): t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start @@ -723,7 +723,7 @@ class UniPC: # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): - for step_index in trange(steps): + for step_index in trange(steps, disable=disable_pbar): if self.noise_mask is not None: x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: @@ -835,7 +835,7 @@ def expand_dims(v, dims): -def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'): +def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): to_zero = False if sigmas[-1] == 0: timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] @@ -879,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex order = min(3, len(timesteps) - 1) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback) + x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) return x diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index e00ffd3f5..deab76f21 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -81,6 +81,7 @@ class DDIMSampler(object): extra_args=None, to_zero=True, end_step=None, + disable_pbar=False, **kwargs ): self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose) @@ -103,7 +104,8 @@ class DDIMSampler(object): denoise_function=denoise_function, extra_args=extra_args, to_zero=to_zero, - end_step=end_step + end_step=end_step, + disable_pbar=disable_pbar ) return samples, intermediates @@ -185,7 +187,7 @@ class DDIMSampler(object): mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None): + ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False): device = self.model.betas.device b = shape[0] if x_T is None: @@ -204,7 +206,7 @@ class DDIMSampler(object): total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] # print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step) + iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar) for i, step in enumerate(iterator): index = total_steps - i - 1 diff --git a/comfy/sample.py b/comfy/sample.py index f4132bbed..bd38585ac 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -56,7 +56,7 @@ def cleanup_additional_models(models): for m in models: m.cleanup() -def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None): +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False): device = comfy.model_management.get_torch_device() if noise_mask is not None: @@ -76,7 +76,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback) + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar) samples = samples.cpu() cleanup_additional_models(models) diff --git a/comfy/samplers.py b/comfy/samplers.py index 1b486f803..b30fc3d9b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -643,7 +643,8 @@ class KSampler: extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, - end_step=sigmas.shape[0] - 1) + end_step=sigmas.shape[0] - 1, + disable_pbar=disable_pbar) else: extra_args["denoise_mask"] = denoise_mask From 81bee39ca0540aa7bbab275bb6bb9f156e72addd Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 1 May 2023 15:57:10 -0400 Subject: [PATCH 115/190] style everything styles searchbox, should be actually everything --- web/style.css | 43 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/web/style.css b/web/style.css index 6ef3a4c21..df220cc02 100644 --- a/web/style.css +++ b/web/style.css @@ -299,23 +299,52 @@ button.comfy-queue-btn { right: 2px; } -.litecontextmenu { +.litegraph.litecontextmenu, +.litegraph.litecontextmenu.dark { z-index: 9999 !important; -} - -.litegraph.litecontextmenu { background-color: var(--comfy-menu-bg) !important; filter: brightness(95%); - color: var(--input-text) !important; } .litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) { background-color: var(--comfy-menu-bg) !important; filter: brightness(155%); + color: var(--input-text); +} + +.litegraph.litecontextmenu .litemenu-entry.submenu, +.litegraph.litecontextmenu.dark .litemenu-entry.submenu { + background-color: var(--comfy-menu-bg) !important; + color: var(--input-text); +} + +.litegraph.litecontextmenu input { + background-color: var(--comfy-input-bg) !important; color: var(--input-text) !important; } -.litegraph.litecontextmenu .litemenu-entry.submenu { +/* Search box */ + +.litegraph.litesearchbox { + z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; - color: var(--input-text) !important; + overflow: hidden; +} + +.litegraph.litesearchbox input, +.litegraph.litesearchbox select { + background-color: var(--comfy-input-bg) !important; + color: var(--input-text); +} + +.litegraph.lite-search-item { + color: var(--input-text); + background-color: var(--comfy-input-bg); + filter: brightness(80%); + padding-left: 0.2em; +} + +.litegraph.lite-search-item.generic_type { + color: var(--input-text); + filter: brightness(50%); } From 9c335a553fd9f8d4c3c97eeaec5dca89a2a900f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 May 2023 18:11:58 -0400 Subject: [PATCH 116/190] LoKR support. --- comfy/sd.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 92dbb931d..3eb50cc95 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -111,6 +111,8 @@ def load_lora(path, to_load): loaded_keys.add(A_name) loaded_keys.add(B_name) + + ######## loha hada_w1_a_name = "{}.hada_w1_a".format(x) hada_w1_b_name = "{}.hada_w1_b".format(x) hada_w2_a_name = "{}.hada_w2_a".format(x) @@ -132,6 +134,54 @@ def load_lora(path, to_load): loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_b_name) + + ######## lokr + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + for x in lora.keys(): if x not in loaded_keys: print("lora key not loaded", x) @@ -315,6 +365,33 @@ class ModelPatcher: final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + elif len(v) == 8: #lokr + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(w1_a.float(), w1_b.float()) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(w2_a.float(), w2_b.float()) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float()) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha *= v[2] / dim + + weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device) else: #loha w1a = v[0] w1b = v[1] From 35f636b6c741045d25d645ecb95a6e8e2c04d6eb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 00:53:15 -0400 Subject: [PATCH 117/190] Expose grow_mask_by in VAEEncodeForInpaint. The mask is dilated by grow_mask_by pixels after being applied to the pixel space image. This helps reduce seams caused by inpainting. Higher value means less seams. --- nodes.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 53e0f74bf..4f0b7bfe8 100644 --- a/nodes.py +++ b/nodes.py @@ -5,6 +5,7 @@ import sys import json import hashlib import traceback +import math from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -223,13 +224,13 @@ class VAEEncodeForInpaint: @classmethod def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}} + return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "latent/inpaint" - def encode(self, vae, pixels, mask): + def encode(self, vae, pixels, mask, grow_mask_by=6): x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") @@ -240,8 +241,14 @@ class VAEEncodeForInpaint: mask = mask[:,:,:x,:y] #grow mask by a few pixels to keep things seamless in latent space - kernel_tensor = torch.ones((1, 1, 6, 6)) - mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1) + if grow_mask_by == 0: + mask_erosion = mask + else: + kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by)) + padding = math.ceil((grow_mask_by - 1) / 2) + + mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1) + m = (1.0 - mask.round()).squeeze(1) for i in range(3): pixels[:,:,:,i] -= 0.5 From a307c3f12c7816885802ae4ad2ffc6a14e550540 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 09:40:57 -0400 Subject: [PATCH 118/190] Update nightly pytorch standalone to python 3.11.3 cu121. --- .../update_comfyui_and_python_dependencies.bat | 2 +- .github/workflows/windows_release_nightly_pytorch.yml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index b4989534f..94f5d1023 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index f23cae6d5..b6a18ec0a 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -19,21 +19,21 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v4 with: - python-version: '3.10.9' + python-version: '3.11.3' - shell: bash run: | cd .. cp -r ComfyUI ComfyUI_copy - curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip + curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip unzip python_embeded.zip -d python_embeded cd python_embeded - echo 'import site' >> ./python310._pth + echo 'import site' >> ./python311._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* - sed -i '1i../ComfyUI' ./python310._pth + sed -i '1i../ComfyUI' ./python311._pth cd .. From 66c8aa5c3ee601dbca396f66fe86703977b908b5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 13:31:43 -0400 Subject: [PATCH 119/190] Make unet work with any input shape. --- .../modules/diffusionmodules/openaimodel.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c69c8567..0393dc013 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): support it as an extra input. """ - def forward(self, x, emb, context=None, transformer_options={}): + def forward(self, x, emb, context=None, transformer_options={}, output_shape=None): for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) + elif isinstance(layer, Upsample): + x = layer(x, output_shape=output_shape) else: x = layer(x) return x @@ -105,14 +107,21 @@ class Upsample(nn.Module): if use_conv: self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) - def forward(self, x): + def forward(self, x, output_shape=None): + print("upsample", output_shape) assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2] + if output_shape is not None: + shape[1] = output_shape[3] + shape[2] = output_shape[4] else: - x = F.interpolate(x, scale_factor=2, mode="nearest") + shape = [x.shape[2] * 2, x.shape[3] * 2] + if output_shape is not None: + shape[0] = output_shape[2] + shape[1] = output_shape[3] + + x = F.interpolate(x, size=shape, mode="nearest") if self.use_conv: x = self.conv(x) return x @@ -813,9 +822,14 @@ class UNetModel(nn.Module): ctrl = control['output'].pop() if ctrl is not None: hsp += ctrl + h = th.cat([h, hsp], dim=1) del hsp - h = module(h, emb, context, transformer_options) + if len(hs) > 0: + output_shape = hs[-1].shape + else: + output_shape = None + h = module(h, emb, context, transformer_options, output_shape) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) From ba8a4c3667eda95649d8bfa906186d42e9ac6835 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 14:16:27 -0400 Subject: [PATCH 120/190] Change latent resolution step to 8. --- .../modules/diffusionmodules/openaimodel.py | 1 - nodes.py | 72 +++++++++---------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 0393dc013..25309dbd7 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -108,7 +108,6 @@ class Upsample(nn.Module): self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) def forward(self, x, output_shape=None): - print("upsample", output_shape) assert x.shape[1] == self.channels if self.dims == 3: shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2] diff --git a/nodes.py b/nodes.py index 4f0b7bfe8..80d508854 100644 --- a/nodes.py +++ b/nodes.py @@ -94,10 +94,10 @@ class ConditioningSetArea: @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), - "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), + "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) @@ -188,16 +188,21 @@ class VAEEncode: CATEGORY = "latent" - def encode(self, vae, pixels): - x = (pixels.shape[1] // 64) * 64 - y = (pixels.shape[2] // 64) * 64 + @staticmethod + def vae_encode_crop_pixels(pixels): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 if pixels.shape[1] != x or pixels.shape[2] != y: - pixels = pixels[:,:x,:y,:] + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + return pixels + + def encode(self, vae, pixels): + pixels = self.vae_encode_crop_pixels(pixels) t = vae.encode(pixels[:,:,:,:3]) - return ({"samples":t}, ) - class VAEEncodeTiled: def __init__(self, device="cpu"): self.device = device @@ -211,13 +216,10 @@ class VAEEncodeTiled: CATEGORY = "_for_testing" def encode(self, vae, pixels): - x = (pixels.shape[1] // 64) * 64 - y = (pixels.shape[2] // 64) * 64 - if pixels.shape[1] != x or pixels.shape[2] != y: - pixels = pixels[:,:x,:y,:] + pixels = VAEEncode.vae_encode_crop_pixels(pixels) t = vae.encode_tiled(pixels[:,:,:,:3]) - return ({"samples":t}, ) + class VAEEncodeForInpaint: def __init__(self, device="cpu"): self.device = device @@ -231,14 +233,16 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // 64) * 64 - y = (pixels.shape[2] // 64) * 64 + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - pixels = pixels[:,:x,:y,:] - mask = mask[:,:,:x,:y] + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] #grow mask by a few pixels to keep things seamless in latent space if grow_mask_by == 0: @@ -610,8 +614,8 @@ class EmptyLatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -649,8 +653,8 @@ class LatentUpscale: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), "crop": (s.crop_methods,)}} RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" @@ -752,8 +756,8 @@ class LatentCrop: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), - "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), }} @@ -778,16 +782,6 @@ class LatentCrop: new_width = width // 8 to_x = new_width + x to_y = new_height + y - def enforce_image_dim(d, to_d, max_d): - if to_d > max_d: - leftover = (to_d - max_d) % 8 - to_d = max_d - d -= leftover - return (d, to_d) - - #make sure size is always multiple of 64 - x, to_x = enforce_image_dim(x, to_x, samples.shape[3]) - y, to_y = enforce_image_dim(y, to_y, samples.shape[2]) s['samples'] = samples[:,:,y:to_y, x:to_x] return (s,) @@ -1105,10 +1099,10 @@ class ImagePadForOutpaint: return { "required": { "image": ("IMAGE",), - "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), - "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), - "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), - "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), + "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}), } } From 06ad35b4932fe6cc4382d8b1dfa79fef8284362a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 2 May 2023 19:18:07 +0100 Subject: [PATCH 121/190] added progress to encode + upscale --- comfy/sd.py | 12 +++++++++--- comfy_extras/nodes_upscale_model.py | 8 +++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 2aadefadc..06d6c1a56 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -491,9 +491,15 @@ class VAE: model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4) - samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4) - samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4) + + it_1 = -(pixel_samples.shape[2] // -(tile_y * 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x // 2 - overlap)) + it_2 = -(pixel_samples.shape[2] // -(tile_y // 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x * 2 - overlap)) + it_3 = -(pixel_samples.shape[2] // -(tile_y - overlap)) * -(pixel_samples.shape[3] // -(tile_x - overlap)) + pbar = tqdm(total=(it_1 + it_2 + it_3)) + + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples /= 3.0 self.first_stage_model = self.first_stage_model.cpu() samples = samples.cpu() diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index d8754698c..4fc7dcd77 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -4,6 +4,7 @@ from comfy import model_management import torch import comfy.utils import folder_paths +from tqdm.auto import tqdm class UpscaleModelLoader: @classmethod @@ -37,7 +38,12 @@ class ImageUpscaleWithModel: device = model_management.get_torch_device() upscale_model.to(device) in_img = image.movedim(-1,-3).to(device) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale) + + tile = 128 + 64 + overlap = 8 + its = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + pbar = tqdm(total=its) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return (s,) From 93c64afaa92b425fc863b80ee0b7c618705d7d60 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 23:00:49 -0400 Subject: [PATCH 122/190] Use sampler callback instead of tqdm hook for progress bar. --- comfy/utils.py | 23 +++++++++++++++++++++++ main.py | 12 ++++-------- nodes.py | 6 +++++- 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 68f93403c..7f3c3978c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -86,3 +86,26 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am output[b:b+1] = out/out_div return output + + +PROGRESS_BAR_HOOK = None +def set_progress_bar_global_hook(function): + global PROGRESS_BAR_HOOK + PROGRESS_BAR_HOOK = function + +class ProgressBar: + def __init__(self, total): + global PROGRESS_BAR_HOOK + self.total = total + self.current = 0 + self.hook = PROGRESS_BAR_HOOK + + def update_absolute(self, value): + if value > self.total: + value = self.total + self.current = value + if self.hook is not None: + self.hook(self.current, self.total) + + def update(self, value): + self.update_absolute(self.current + value) diff --git a/main.py b/main.py index 02c700ebc..f369b82f3 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import shutil import threading from comfy.cli_args import args +import comfy.utils if os.name == "nt": import logging @@ -39,14 +40,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) def hijack_progress(server): - from tqdm.auto import tqdm - orig_func = getattr(tqdm, "update") - def wrapped_func(*args, **kwargs): - pbar = args[0] - v = orig_func(*args, **kwargs) - server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id) - return v - setattr(tqdm, "update", wrapped_func) + def hook(value, total): + server.send_sync("progress", { "value": value, "max": total}, server.client_id) + comfy.utils.set_progress_bar_global_hook(hook) def cleanup_temp(): temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") diff --git a/nodes.py b/nodes.py index 80d508854..90c943fe3 100644 --- a/nodes.py +++ b/nodes.py @@ -815,9 +815,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] + pbar = comfy.utils.ProgressBar(steps) + def callback(step, x0, x): + pbar.update_absolute(step + 1) + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask) + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback) out = latent.copy() out["samples"] = samples return (out, ) From 27df74101e6e5bb761364b718d57313388b49182 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 17:33:19 +0100 Subject: [PATCH 123/190] reduce duplication --- comfy/sd.py | 14 +++++--------- comfy/utils.py | 6 ++++++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 06d6c1a56..87b380b1c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -438,10 +438,8 @@ class VAE: self.device = device def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): - it_1 = -(samples.shape[2] // -(tile_y * 2 - overlap)) * -(samples.shape[3] // -(tile_x // 2 - overlap)) - it_2 = -(samples.shape[2] // -(tile_y // 2 - overlap)) * -(samples.shape[3] // -(tile_x * 2 - overlap)) - it_3 = -(samples.shape[2] // -(tile_y - overlap)) * -(samples.shape[3] // -(tile_x - overlap)) - pbar = tqdm(total=samples.shape[0] * (it_1 + it_2 + it_3)) + steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) + pbar = tqdm(total=steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( @@ -492,11 +490,9 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - it_1 = -(pixel_samples.shape[2] // -(tile_y * 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x // 2 - overlap)) - it_2 = -(pixel_samples.shape[2] // -(tile_y // 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x * 2 - overlap)) - it_3 = -(pixel_samples.shape[2] // -(tile_y - overlap)) * -(pixel_samples.shape[3] // -(tile_x - overlap)) - pbar = tqdm(total=(it_1 + it_2 + it_3)) - + steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + pbar = tqdm(total=steps) + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy/utils.py b/comfy/utils.py index c7c6a08c5..82d3aa0d8 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -62,6 +62,12 @@ def common_upscale(samples, width, height, upscale_method, crop): s = samples return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) +def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): + it_1 = -(height // -(tile_y * 2 - overlap)) * -(width // -(tile_x // 2 - overlap)) + it_2 = -(height // -(tile_y // 2 - overlap)) * -(width // -(tile_x * 2 - overlap)) + it_3 = -(height // -(tile_y - overlap)) * -(width // -(tile_x - overlap)) + return it_1 + it_2 + it_3 + @torch.inference_mode() def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") From 908dc1d5a8717073f44d136d6d2b4f983ea07d40 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 May 2023 12:58:10 -0400 Subject: [PATCH 124/190] Add a total_steps value to sampler callback. --- comfy/extra_samplers/uni_pc.py | 2 +- comfy/samplers.py | 8 +++++--- comfy/utils.py | 4 +++- nodes.py | 4 ++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 78bab5936..2ff10caf1 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -767,7 +767,7 @@ class UniPC: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x if callback is not None: - callback(step_index, model_prev_list[-1], x) + callback(step_index, model_prev_list[-1], x, steps) else: raise NotImplementedError() if denoise_to_zero: diff --git a/comfy/samplers.py b/comfy/samplers.py index b30fc3d9b..dcf93cca2 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -623,7 +623,8 @@ class KSampler: ddim_callback = None if callback is not None: - ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None) + total_steps = len(timesteps) - 1 + ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps) sampler = DDIMSampler(self.model, device=self.device) sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) @@ -654,13 +655,14 @@ class KSampler: noise = noise * sigmas[0] k_callback = None + total_steps = len(sigmas) - 1 if callback is not None: - k_callback = lambda x: callback(x["i"], x["denoised"], x["x"]) + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) if latent_image is not None: noise += latent_image if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) elif self.sampler == "dpm_adaptive": samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) else: diff --git a/comfy/utils.py b/comfy/utils.py index 7f3c3978c..f1ff97792 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -100,7 +100,9 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK - def update_absolute(self, value): + def update_absolute(self, value, total=None): + if total is not None: + self.total = total if value > self.total: value = self.total self.current = value diff --git a/nodes.py b/nodes.py index 90c943fe3..c2bc36855 100644 --- a/nodes.py +++ b/nodes.py @@ -816,8 +816,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = latent["noise_mask"] pbar = comfy.utils.ProgressBar(steps) - def callback(step, x0, x): - pbar.update_absolute(step + 1) + def callback(step, x0, x, total_steps): + pbar.update_absolute(step + 1, total_steps) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, From 8912623ea9929848b813f1aeafee0fa9e1281817 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 18:19:22 +0100 Subject: [PATCH 125/190] use comfy progress bar --- comfy/sd.py | 6 +++--- comfy_extras/nodes_upscale_model.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 32499f600..e4c5282d7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -516,7 +516,7 @@ class VAE: def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) - pbar = tqdm(total=steps) + pbar = utils.ProgressBar(steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( @@ -568,8 +568,8 @@ class VAE: pixel_samples = pixel_samples.movedim(-1,1).to(self.device) steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) - pbar = tqdm(total=steps) - + pbar = utils.ProgressBar(steps) + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 4fc7dcd77..dfd1994a6 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -41,8 +41,8 @@ class ImageUpscaleWithModel: tile = 128 + 64 overlap = 8 - its = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) - pbar = tqdm(total=its) + steps = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) From 5eeecf3fd5adedfa5a92d3549f77a78be714c2a3 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 18:21:23 +0100 Subject: [PATCH 126/190] remove unused import --- comfy/sd.py | 1 - comfy_extras/nodes_upscale_model.py | 1 - 2 files changed, 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index e4c5282d7..d60b908b8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,7 +1,6 @@ import torch import contextlib import copy -from tqdm.auto import tqdm import sd1_clip import sd2_clip diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index dfd1994a6..f774b4b77 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -4,7 +4,6 @@ from comfy import model_management import torch import comfy.utils import folder_paths -from tqdm.auto import tqdm class UpscaleModelLoader: @classmethod From fcf513e0b6b599e23b7d6f9bde315be6f991652b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 May 2023 17:48:35 -0400 Subject: [PATCH 127/190] Refactor. --- comfy/sd.py | 6 +++++- comfy/utils.py | 6 ++---- comfy_extras/nodes_upscale_model.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d60b908b8..174ed35e5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -515,6 +515,8 @@ class VAE: def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) + steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) @@ -566,7 +568,9 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy/utils.py b/comfy/utils.py index 5c7143fd9..09e05d4ed 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,4 +1,5 @@ import torch +import math def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -63,10 +64,7 @@ def common_upscale(samples, width, height, upscale_method, crop): return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): - it_1 = -(height // -(tile_y * 2 - overlap)) * -(width // -(tile_x // 2 - overlap)) - it_2 = -(height // -(tile_y // 2 - overlap)) * -(width // -(tile_x * 2 - overlap)) - it_3 = -(height // -(tile_y - overlap)) * -(width // -(tile_x - overlap)) - return it_1 + it_2 + it_3 + return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) @torch.inference_mode() def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index f774b4b77..ab5b0ccfc 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -40,7 +40,7 @@ class ImageUpscaleWithModel: tile = 128 + 64 overlap = 8 - steps = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() From 7e51bbd07f809555cc50c4fdae3ef84720e5c86f Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 4 May 2023 19:42:07 +0100 Subject: [PATCH 128/190] automatic calculation of image pos from widgets --- web/scripts/app.js | 39 ++++++++++++++++++++++++++++++--------- web/scripts/widgets.js | 9 +-------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index ada1708dc..f0c0f9de4 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -263,6 +263,34 @@ export class ComfyApp { */ #addDrawBackgroundHandler(node) { const app = this; + + function getImageTop(node) { + let shiftY; + if (node.imageOffset != null) { + shiftY = node.imageOffset; + } else { + if (node.widgets?.length) { + const w = node.widgets[node.widgets.length - 1]; + shiftY = w.last_y; + if (w.computeSize) { + shiftY += w.computeSize()[1] + 4; + } else { + shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } else { + shiftY = node.computeSize()[1]; + } + } + return shiftY; + } + + node.prototype.setSizeForImage = function () { + const minHeight = getImageTop(this) + 220; + if (this.size[1] < minHeight) { + this.setSize([this.size[0], minHeight]); + } + }; + node.prototype.onDrawBackground = function (ctx) { if (!this.flags.collapsed) { const output = app.nodeOutputs[this.id + ""]; @@ -283,9 +311,7 @@ export class ComfyApp { ).then((imgs) => { if (this.images === output.images) { this.imgs = imgs.filter(Boolean); - if (this.size[1] < 100) { - this.size[1] = 250; - } + this.setSizeForImage?.(); app.graph.setDirtyCanvas(true); } }); @@ -310,12 +336,7 @@ export class ComfyApp { this.imageIndex = imageIndex = 0; } - let shiftY; - if (this.imageOffset != null) { - shiftY = this.imageOffset; - } else { - shiftY = this.computeSize()[1]; - } + const shiftY = getImageTop(this); let dw = this.size[0]; let dh = this.size[1]; diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index c0e73ffa1..cd471bc93 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -261,20 +261,13 @@ export const ComfyWidgets = { let uploadWidget; function showImage(name) { - // Position the image somewhere sensible - if (!node.imageOffset) { - node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75; - } - const img = new Image(); img.onload = () => { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; img.src = `/view?filename=${name}&type=input`; - if ((node.size[1] - node.imageOffset) < 100) { - node.size[1] = 250 + node.imageOffset; - } + node.setSizeForImage?.(); } // Add our own callback to the combo widget to render an image when it changes From bae4fb4a9dc944c10cca922dc4442eef57bbf583 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 May 2023 18:07:41 -0400 Subject: [PATCH 129/190] Fix imports. --- comfy/cldm/cldm.py | 10 +++---- comfy/gligen.py | 2 +- comfy/ldm/models/autoencoder.py | 8 +++--- comfy/ldm/models/diffusion/ddim.py | 2 +- comfy/ldm/models/diffusion/ddpm.py | 12 ++++----- comfy/ldm/modules/attention.py | 4 +-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- .../modules/diffusionmodules/openaimodel.py | 6 ++--- .../ldm/modules/diffusionmodules/upscaling.py | 4 +-- comfy/ldm/modules/diffusionmodules/util.py | 2 +- .../ldm/modules/encoders/noise_aug_modules.py | 4 +-- comfy/model_management.py | 2 +- comfy/sd.py | 26 +++++++++---------- 13 files changed, 42 insertions(+), 42 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index c60abf80b..cb660ee77 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -5,17 +5,17 @@ import torch import torch as th import torch.nn as nn -from ldm.modules.diffusionmodules.util import ( +from ..ldm.modules.diffusionmodules.util import ( conv_nd, linear, zero_module, timestep_embedding, ) -from ldm.modules.attention import SpatialTransformer -from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.util import log_txt_as_img, exists, instantiate_from_config +from ..ldm.modules.attention import SpatialTransformer +from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from ..ldm.models.diffusion.ddpm import LatentDiffusion +from ..ldm.util import log_txt_as_img, exists, instantiate_from_config class ControlledUnetModel(UNetModel): diff --git a/comfy/gligen.py b/comfy/gligen.py index 8770383e5..45b674503 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -1,6 +1,6 @@ import torch from torch import nn, einsum -from ldm.modules.attention import CrossAttention +from .ldm.modules.attention import CrossAttention from inspect import isfunction diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index bd698621c..1fb7ed879 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -3,11 +3,11 @@ import torch import torch.nn.functional as F from contextlib import contextmanager -from ldm.modules.diffusionmodules.model import Encoder, Decoder -from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder +from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution -from ldm.util import instantiate_from_config -from ldm.modules.ema import LitEma +from comfy.ldm.util import instantiate_from_config +from comfy.ldm.modules.ema import LitEma # class AutoencoderKL(pl.LightningModule): class AutoencoderKL(torch.nn.Module): diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index deab76f21..c279f2c18 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -4,7 +4,7 @@ import torch import numpy as np from tqdm import tqdm -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor +from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor class DDIMSampler(object): diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index d3f0eb2b2..0f484a7f1 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -19,12 +19,12 @@ from tqdm import tqdm from torchvision.utils import make_grid # from pytorch_lightning.utilities.distributed import rank_zero_only -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from ldm.models.diffusion.ddim import DDIMSampler +from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from comfy.ldm.modules.ema import LitEma +from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ..autoencoder import IdentityFirstStage, AutoencoderKL +from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from .ddim import DDIMSampler __conditioning_keys__ = {'concat': 'c_concat', diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ce7180d91..5eabecd65 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -6,7 +6,7 @@ from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any -from ldm.modules.diffusionmodules.util import checkpoint +from .diffusionmodules.util import checkpoint from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management @@ -21,7 +21,7 @@ if model_management.xformers_enabled(): import os _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") -from cli_args import args +from comfy.cli_args import args def exists(val): return val is not None diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 1599d386e..5e4d2b60f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -6,7 +6,7 @@ import numpy as np from einops import rearrange from typing import Optional, Any -from ldm.modules.attention import MemoryEfficientCrossAttention +from ..attention import MemoryEfficientCrossAttention from comfy import model_management if model_management.xformers_enabled_vae(): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 25309dbd7..4352b756d 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -6,7 +6,7 @@ import torch as th import torch.nn as nn import torch.nn.functional as F -from ldm.modules.diffusionmodules.util import ( +from .util import ( checkpoint, conv_nd, linear, @@ -15,8 +15,8 @@ from ldm.modules.diffusionmodules.util import ( normalization, timestep_embedding, ) -from ldm.modules.attention import SpatialTransformer -from ldm.util import exists +from ..attention import SpatialTransformer +from comfy.ldm.util import exists # dummy replace diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 038166620..709a7f52e 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -3,8 +3,8 @@ import torch.nn as nn import numpy as np from functools import partial -from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule -from ldm.util import default +from .util import extract_into_tensor, make_beta_schedule +from comfy.ldm.util import default class AbstractLowScaleModel(nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index daf35da7b..82ea3f0a6 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -15,7 +15,7 @@ import torch.nn as nn import numpy as np from einops import repeat -from ldm.util import instantiate_from_config +from comfy.ldm.util import instantiate_from_config def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index f99e7920a..b59bf204b 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -1,5 +1,5 @@ -from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation -from ldm.modules.diffusionmodules.openaimodel import Timestep +from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation +from ..diffusionmodules.openaimodel import Timestep import torch class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): diff --git a/comfy/model_management.py b/comfy/model_management.py index db5d368e1..e89f80d69 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,6 @@ import psutil from enum import Enum -from cli_args import args +from .cli_args import args class VRAMState(Enum): CPU = 0 diff --git a/comfy/sd.py b/comfy/sd.py index 174ed35e5..3543bdb77 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -2,8 +2,8 @@ import torch import contextlib import copy -import sd1_clip -import sd2_clip +from . import sd1_clip +from . import sd2_clip from comfy import model_management from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL @@ -446,10 +446,10 @@ class CLIP: else: params = {} - if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder": + if self.target_clip.endswith("FrozenOpenCLIPEmbedder"): clip = sd2_clip.SD2ClipModel tokenizer = sd2_clip.SD2Tokenizer - elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder": + elif self.target_clip.endswith("FrozenCLIPEmbedder"): clip = sd1_clip.SD1ClipModel tokenizer = sd1_clip.SD1Tokenizer @@ -896,9 +896,9 @@ def load_clip(ckpt_path, embedding_directory=None): clip_data = utils.load_torch_file(ckpt_path) config = {} if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: - config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' + config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' else: - config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder' + config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder' clip = CLIP(config=config, embedding_directory=embedding_directory) clip.load_from_state_dict(clip_data) return clip @@ -974,9 +974,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clip: clip_config = {} if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys: - clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' + clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' else: - clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder' + clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder' clip = CLIP(config=clip_config, embedding_directory=embedding_directory) w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] @@ -997,7 +997,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0] noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2" params["noise_schedule_config"] = noise_schedule_config - noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation" + noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation" if size == 1280: #h params["timestep_dim"] = 1024 elif size == 1024: #l @@ -1049,19 +1049,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1] unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] - sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} - model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} + sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} + model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} if noise_aug_config is not None: #SD2.x unclip model sd_config["noise_aug_config"] = noise_aug_config sd_config["image_size"] = 96 sd_config["embedding_dropout"] = 0.25 sd_config["conditioning_key"] = 'crossattn-adm' - model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" + model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" elif unet_config["in_channels"] > 4: #inpainting model sd_config["conditioning_key"] = "hybrid" sd_config["finetune_keys"] = None - model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion" else: sd_config["conditioning_key"] = "crossattn" From 1a31020081b22cb55e573f65a11bd4c2c96f17f1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 00:16:57 -0400 Subject: [PATCH 130/190] Support softsign hypernetwork. --- comfy_extras/nodes_hypernetwork.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 0c7250e43..c19b5e4c7 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -18,6 +18,7 @@ def load_hypernetwork_patch(path, strength): "swish": torch.nn.Hardswish, "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, + "softsign": torch.nn.Softsign, } if activation_func not in valid_activation: From 6ee11d7bc00bdbc109e3b84231aa74fc1799d543 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 00:19:35 -0400 Subject: [PATCH 131/190] Fix import. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index e89f80d69..3aea7ea8e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,6 @@ import psutil from enum import Enum -from .cli_args import args +from comfy.cli_args import args class VRAMState(Enum): CPU = 0 From af9cc1fb6a88e604700d3f57638ab23b9f607e9e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 01:28:48 -0400 Subject: [PATCH 132/190] Search recursively in subfolders for embeddings. --- comfy/sd1_clip.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 7f1217c3d..b1a392736 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -191,11 +191,20 @@ def safe_load_embed_zip(embed_path): del embed return out +def expand_directory_list(directories): + dirs = set() + for x in directories: + dirs.add(x) + for root, subdir, file in os.walk(x, followlinks=True): + dirs.add(root) + return list(dirs) def load_embed(embedding_name, embedding_directory): if isinstance(embedding_directory, str): embedding_directory = [embedding_directory] + embedding_directory = expand_directory_list(embedding_directory) + valid_file = None for embed_dir in embedding_directory: embed_path = os.path.join(embed_dir, embedding_name) From f31e31ee0a3d7da01f2b1f3b68047445c16e494a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 5 May 2023 10:12:06 +0100 Subject: [PATCH 133/190] Fix box shape Match card to litegraph selection --- web/scripts/app.js | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index ada1708dc..68eeb6329 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -703,7 +703,7 @@ export class ComfyApp { ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) - ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT); + ctx.rect(-6, -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT); else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed)) ctx.roundRect( -6, @@ -715,12 +715,11 @@ export class ComfyApp { else if (shape == LiteGraph.CARD_SHAPE) ctx.roundRect( -6, - -6 + LiteGraph.NODE_TITLE_HEIGHT, + -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, - this.round_radius * 2, - 2 - ); + [this.round_radius * 2,2,this.round_radius * 2,2] + ); else if (shape == LiteGraph.CIRCLE_SHAPE) ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2); ctx.strokeStyle = color; From de4623a8a4b8282f2d29d5a3ecbcb9840c3dc7ac Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 5 May 2023 10:34:09 +0100 Subject: [PATCH 134/190] actually fix card --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 68eeb6329..98c0e0799 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -718,7 +718,7 @@ export class ComfyApp { -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, - [this.round_radius * 2,2,this.round_radius * 2,2] + [this.round_radius * 2, this.round_radius * 2, 2, 2] ); else if (shape == LiteGraph.CIRCLE_SHAPE) ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2); From cb1551b819ecaa7d9044c13d0c8e8cfa4ff72830 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 18:01:21 -0400 Subject: [PATCH 135/190] Lowvram mode for gligen and fix some lowvram issues. --- comfy/gligen.py | 27 +++++++++++++++---- comfy/ldm/modules/attention.py | 3 --- .../modules/diffusionmodules/openaimodel.py | 19 ++++++++++--- comfy/model_management.py | 3 +++ 4 files changed, 41 insertions(+), 11 deletions(-) diff --git a/comfy/gligen.py b/comfy/gligen.py index 45b674503..8c7cb432e 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -242,14 +242,28 @@ class Gligen(nn.Module): self.position_net = position_net self.key_dim = key_dim self.max_objs = 30 + self.lowvram = False def _set_position(self, boxes, masks, positive_embeddings): + if self.lowvram == True: + self.position_net.to(boxes.device) + objs = self.position_net(boxes, masks, positive_embeddings) - def func(key, x): - module = self.module_list[key] - return module(x, objs) - return func + if self.lowvram == True: + self.position_net.cpu() + def func_lowvram(key, x): + module = self.module_list[key] + module.to(x.device) + r = module(x, objs) + module.cpu() + return r + return func_lowvram + else: + def func(key, x): + module = self.module_list[key] + return module(x, objs) + return func def set_position(self, latent_image_shape, position_params, device): batch, c, h, w = latent_image_shape @@ -294,8 +308,11 @@ class Gligen(nn.Module): masks.to(device), conds.to(device)) + def set_lowvram(self, value=True): + self.lowvram = value + def cleanup(self): - pass + self.lowvram = False def get_models(self): return [self] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 5eabecd65..573f4e1c6 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -572,9 +572,6 @@ class BasicTransformerBlock(nn.Module): x += n x = self.ff(self.norm3(x)) + x - - if current_index is not None: - transformer_options["current_index"] += 1 return x diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4352b756d..5aef23f33 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -88,6 +88,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): x = layer(x) return x +#This is needed because accelerate makes a copy of transformer_options which breaks "current_index" +def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): + for layer in ts: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context, transformer_options) + transformer_options["current_index"] += 1 + elif isinstance(layer, Upsample): + x = layer(x, output_shape=output_shape) + else: + x = layer(x) + return x class Upsample(nn.Module): """ @@ -805,13 +818,13 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): - h = module(h, emb, context, transformer_options) + h = forward_timestep_embed(module, h, emb, context, transformer_options) if control is not None and 'input' in control and len(control['input']) > 0: ctrl = control['input'].pop() if ctrl is not None: h += ctrl hs.append(h) - h = self.middle_block(h, emb, context, transformer_options) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) if control is not None and 'middle' in control and len(control['middle']) > 0: h += control['middle'].pop() @@ -828,7 +841,7 @@ class UNetModel(nn.Module): output_shape = hs[-1].shape else: output_shape = None - h = module(h, emb, context, transformer_options, output_shape) + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3aea7ea8e..7070912df 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -201,6 +201,9 @@ def load_controlnet_gpu(control_models): return if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: + for m in control_models: + if hasattr(m, 'set_lowvram'): + m.set_lowvram(True) #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return From 7a9268185cb6456890f6fe61bcc380b5cb21f614 Mon Sep 17 00:00:00 2001 From: WAS Date: Fri, 5 May 2023 18:06:54 -0700 Subject: [PATCH 136/190] Update README.md Add quick search explanation --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3b3824714..bfa8904df 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git | Q | Toggle visibility of the queue | | H | Toggle visibility of history | | R | Refresh graph | +| Double-Click LMB | Open node quick search palette | Ctrl can also be replaced with Cmd instead for MacOS users From 8e03c789a25470a88aa05bcc73b1fe226334926b Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 6 May 2023 16:59:40 -0400 Subject: [PATCH 137/190] auto-launch cli arg --- comfy/cli_args.py | 4 ++++ main.py | 13 +++---------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 764427165..cc4709f70 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -7,6 +7,7 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port. parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") @@ -30,3 +31,6 @@ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") args = parser.parse_args() + +if args.windows_standalone_build: + args.auto_launch = True diff --git a/main.py b/main.py index f369b82f3..eb97a2fb8 100644 --- a/main.py +++ b/main.py @@ -91,23 +91,16 @@ if __name__ == "__main__": threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() - address = args.listen - - dont_print = args.dont_print_server - - if args.output_directory: output_dir = os.path.abspath(args.output_directory) print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) - port = args.port - if args.quick_test_for_ci: exit(0) call_on_start = None - if args.windows_standalone_build: + if args.auto_launch: def startup_server(address, port): import webbrowser webbrowser.open("http://{}:{}".format(address, port)) @@ -115,10 +108,10 @@ if __name__ == "__main__": if os.name == "nt": try: - loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) + loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) except KeyboardInterrupt: pass else: - loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) + loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) cleanup_temp() From 678f933d382641933920e84414fe36f89d1da5a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 May 2023 19:00:49 -0400 Subject: [PATCH 138/190] maximum_batch_area for xformers. Remove useless code. --- comfy/model_management.py | 7 ++++++- nodes.py | 4 +--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7070912df..b0640d674 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -312,7 +312,12 @@ def maximum_batch_area(): return 0 memory_free = get_free_memory() / (1024 * 1024) - area = ((memory_free - 1024) * 0.9) / (0.6) + if xformers_enabled(): + #TODO: this needs to be tweaked + area = 50 * memory_free + else: + #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future + area = ((memory_free - 1024) * 0.9) / (0.6) return int(max(area, 0)) def cpu_mode(): diff --git a/nodes.py b/nodes.py index c2bc36855..ca0769ba7 100644 --- a/nodes.py +++ b/nodes.py @@ -105,15 +105,13 @@ class ConditioningSetArea: CATEGORY = "conditioning" - def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, width, height, x, y, strength): c = [] for t in conditioning: n = [t[0], t[1].copy()] n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) n[1]['strength'] = strength n[1]['set_area_to_bounds'] = False - n[1]['min_sigma'] = min_sigma - n[1]['max_sigma'] = max_sigma c.append(n) return (c, ) From 6fc4917634d457c07eb8b676da4fa88e0ef4704b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 May 2023 19:58:54 -0400 Subject: [PATCH 139/190] Make maximum_batch_area take into account python2.0 attention function. More conservative xformers maximum_batch_area. --- comfy/model_management.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b0640d674..39df8d9a7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -275,8 +275,17 @@ def xformers_enabled_vae(): return XFORMERS_ENABLED_VAE def pytorch_attention_enabled(): + global ENABLE_PYTORCH_ATTENTION return ENABLE_PYTORCH_ATTENTION +def pytorch_attention_flash_attention(): + global ENABLE_PYTORCH_ATTENTION + if ENABLE_PYTORCH_ATTENTION: + #TODO: more reliable way of checking for flash attention? + if torch.version.cuda: #pytorch flash attention only works on Nvidia + return True + return False + def get_free_memory(dev=None, torch_free_too=False): global xpu_available global directml_enabled @@ -312,9 +321,9 @@ def maximum_batch_area(): return 0 memory_free = get_free_memory() / (1024 * 1024) - if xformers_enabled(): + if xformers_enabled() or pytorch_attention_flash_attention(): #TODO: this needs to be tweaked - area = 50 * memory_free + area = 20 * memory_free else: #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future area = ((memory_free - 1024) * 0.9) / (0.6) From ae08fdb9990956f671d658aaf72a1eaf982b5b33 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 9 May 2023 03:37:36 +0900 Subject: [PATCH 140/190] Clipspace Menu and MaskEditor application. (#548) * Add clipspace feature. * feat: copy content to clipspace * feat: paste content from clipspace Extend validation to allow for validating annotated_path in addition to other parameters. Add support for annotated_filepath in folder_paths function. Generalize the '/upload/image' API to allow for uploading images to the 'input', 'temp', or 'output' directories. * rename contentClipboard -> clipspace * Do deep copy for imgs on copy to clipspace. * mask painting on clipspace * add original_imgs into clipspace * Preserve the original image when 'imgs' are modified * robust patch & refactoring folder_paths about annotated_filepath * wip * Only show the Paste menu if the ComfyApp.clipspace is not empty * clipspace feature added maskeditor feature added * instant refresh on paste force triggering 'changed' on paste action * enhance mask painting smooth drawing add brush_size +/- button * robust patch use mouseup event * robust patch again... * subfolder fix on paste logic attach subfolder if subfolder isn't empty * event listener patch add ], [ key event for brush size remove listener on close * Fix button positioning issue related to window height. Change brush size from button to slider. * clean commit * clean code * various bug fixes * paste action - prevent opening upload popup - ensure rendering after widget_value update * view api update - support annotated_filepath * maskeditor layout - prevent covering button by hidden div * remove dbg message * Add cursor functionality to display brush size * refactor: Replace brush preview feature with missionfloyd implementation * missionfloyd implementation * hiding brush preview off the canvas * change brush size on wheel event * keyup -> keydown event * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * Add support for channel-specific image data retrieval in /view API to fix alpha mask loading issue When loading an image with an alpha mask in JavaScript canvas, there is an issue where the alpha and RGB channels are premultiplied. To avoid reliance on JavaScript canvas, I added support for channel-specific image data retrieval in the "/view" API. This allows us to retrieve data for each channel separately and fix the alpha mask loading issue. The changes have been committed to the repository. * Enable brush preview for key and slider events * optimize * preview fix * robust patch * fix copy (clipspace) action imgs[0] copy -> whole imgs copy * support batch images on clipspace, maskeditor * copy/paste bug fixes for batch images enhance selector preview on clipspace menu add img_paste_mode option into clipspace menu * crash fix * print message if clipspace content cannot editable * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * make default img_paste_mode to 'selected' refactor space -> tab * save clipspace files to input/clipspace instead of temp * show "clipspace/filename.png" instead of 'filename.png [clipspace]' in LoadImage/LoadImageMask * refresh fix related to FILE_COMBO * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * adjust margin based on missionfloyd impelements * mouse event -> pointer event * pen, touch, mouse drawing patched and tested * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * add comment about touch event. --------- Co-authored-by: Lt.Dr.Data Co-authored-by: missionfloyd --- folder_paths.py | 9 + nodes.py | 8 +- server.py | 122 ++++++- web/extensions/core/clipspace.js | 166 +++++++++ web/extensions/core/maskeditor.js | 589 ++++++++++++++++++++++++++++++ web/scripts/app.js | 114 ++++-- web/scripts/ui.js | 1 + web/scripts/widgets.js | 14 + 8 files changed, 976 insertions(+), 47 deletions(-) create mode 100644 web/extensions/core/clipspace.js create mode 100644 web/extensions/core/maskeditor.js diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..0acd22674 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -57,6 +57,10 @@ def get_input_directory(): global input_directory return input_directory +def get_clipspace_directory(): + global input_directory + return input_directory+"/clipspace" + #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): @@ -66,6 +70,8 @@ def get_directory_by_type(type_name): return get_temp_directory() if type_name == "input": return get_input_directory() + if type_name == "clipspace": + return get_clipspace_directory() return None @@ -81,6 +87,9 @@ def annotated_filepath(name): elif name.endswith("[temp]"): base_dir = get_temp_directory() name = name[:-7] + elif name.endswith("[clipspace]"): + base_dir = get_clipspace_directory() + name = name[:-12] else: return name, None diff --git a/nodes.py b/nodes.py index ca0769ba7..1d9a5c872 100644 --- a/nodes.py +++ b/nodes.py @@ -973,8 +973,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), )}, + {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, )}, } CATEGORY = "image" @@ -1014,9 +1015,10 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), ), - "channel": (s._color_channels, ),} + {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, ), + "channel": (s._color_channels, ), } } CATEGORY = "mask" diff --git a/server.py b/server.py index 1c5c17916..48644d83a 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,9 @@ import execution import uuid import json import glob +from PIL import Image +from io import BytesIO + try: import aiohttp from aiohttp import web @@ -110,19 +113,26 @@ class PromptServer(): files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) + def get_dir_by_type(dir_type): + if dir_type is None: + type_dir = folder_paths.get_input_directory() + elif dir_type == "input": + type_dir = folder_paths.get_input_directory() + elif dir_type == "clipspace": + type_dir = folder_paths.get_clipspace_directory() + elif dir_type == "temp": + type_dir = folder_paths.get_temp_directory() + elif dir_type == "output": + type_dir = folder_paths.get_output_directory() + + return type_dir + @routes.post("/upload/image") async def upload_image(request): post = await request.post() image = post.get("image") - if post.get("type") is None: - upload_dir = folder_paths.get_input_directory() - elif post.get("type") == "input": - upload_dir = folder_paths.get_input_directory() - elif post.get("type") == "temp": - upload_dir = folder_paths.get_temp_directory() - elif post.get("type") == "output": - upload_dir = folder_paths.get_output_directory() + upload_dir = get_dir_by_type(post.get("type")) if not os.path.exists(upload_dir): os.makedirs(upload_dir) @@ -147,12 +157,62 @@ class PromptServer(): else: return web.Response(status=400) + @routes.post("/upload/mask") + async def upload_mask(request): + post = await request.post() + image = post.get("image") + original_image = post.get("original_image") + + upload_dir = get_dir_by_type(post.get("type")) + + if not os.path.exists(upload_dir): + os.makedirs(upload_dir) + + if image and image.file: + filename = image.filename + if not filename: + return web.Response(status=400) + + split = os.path.splitext(filename) + i = 1 + while os.path.exists(os.path.join(upload_dir, filename)): + filename = f"{split[0]} ({i}){split[1]}" + i += 1 + + filepath = os.path.join(upload_dir, filename) + + original_pil = Image.open(original_image.file).convert('RGBA') + mask_pil = Image.open(image.file).convert('RGBA') + + # alpha copy + new_alpha = mask_pil.getchannel('A') + original_pil.putalpha(new_alpha) + + original_pil.save(filepath) + + return web.json_response({"name": filename}) + else: + return web.Response(status=400) + @routes.get("/view") async def view_image(request): if "filename" in request.rel_url.query: - type = request.rel_url.query.get("type", "output") - output_dir = folder_paths.get_directory_by_type(type) + filename = request.rel_url.query["filename"] + filename,output_dir = folder_paths.annotated_filepath(filename) + + if request.rel_url.query.get("type", "input") and filename.startswith("clipspace/"): + output_dir = folder_paths.get_clipspace_directory() + filename = filename[10:] + + # validation for security: prevent accessing arbitrary path + if filename[0] == '/' or '..' in filename: + return web.Response(status=400) + + if output_dir is None: + type = request.rel_url.query.get("type", "output") + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) @@ -162,13 +222,49 @@ class PromptServer(): return web.Response(status=403) output_dir = full_output_dir - filename = request.rel_url.query["filename"] filename = os.path.basename(filename) file = os.path.join(output_dir, filename) if os.path.isfile(file): - return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) - + if 'channel' not in request.rel_url.query: + channel = 'rgba' + else: + channel = request.rel_url.query["channel"] + + if channel == 'rgb': + with Image.open(file) as img: + if img.mode == "RGBA": + r, g, b, a = img.split() + new_img = Image.merge('RGB', (r, g, b)) + else: + new_img = img.convert("RGB") + + buffer = BytesIO() + new_img.save(buffer, format='PNG') + buffer.seek(0) + + return web.Response(body=buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + + elif channel == 'a': + with Image.open(file) as img: + if img.mode == "RGBA": + _, _, _, a = img.split() + else: + a = Image.new('L', img.size, 255) + + # alpha img + alpha_img = Image.new('RGBA', img.size) + alpha_img.putalpha(a) + alpha_buffer = BytesIO() + alpha_img.save(alpha_buffer, format='PNG') + alpha_buffer.seek(0) + + return web.Response(body=alpha_buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + else: + return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) + return web.Response(status=404) @routes.get("/prompt") diff --git a/web/extensions/core/clipspace.js b/web/extensions/core/clipspace.js new file mode 100644 index 000000000..adb5877ea --- /dev/null +++ b/web/extensions/core/clipspace.js @@ -0,0 +1,166 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; + +export class ClipspaceDialog extends ComfyDialog { + static items = []; + static instance = null; + + static registerButton(name, contextPredicate, callback) { + const item = + $el("button", { + type: "button", + textContent: name, + contextPredicate: contextPredicate, + onclick: callback + }) + + ClipspaceDialog.items.push(item); + } + + static invalidatePreview() { + if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) { + const img_preview = document.getElementById("clipspace_preview"); + if(img_preview) { + img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + img_preview.style.maxHeight = "100%"; + img_preview.style.maxWidth = "100%"; + } + } + } + + static invalidate() { + if(ClipspaceDialog.instance) { + const self = ClipspaceDialog.instance; + // allow reconstruct controls when copying from non-image to image content. + const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]); + + if(self.element) { + // update + self.element.removeChild(self.element.firstChild); + self.element.appendChild(children); + } + else { + // new + self.element = $el("div.comfy-modal", { parent: document.body }, [children,]); + } + + if(self.element.children[0].children.length <= 1) { + self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."])); + } + + ClipspaceDialog.invalidatePreview(); + } + } + + constructor() { + super(); + } + + createButtons(self) { + const buttons = []; + + for(let idx in ClipspaceDialog.items) { + const item = ClipspaceDialog.items[idx]; + if(!item.contextPredicate || item.contextPredicate()) + buttons.push(ClipspaceDialog.items[idx]); + } + + buttons.push( + $el("button", { + type: "button", + textContent: "Close", + onclick: () => { this.close(); } + }) + ); + + return buttons; + } + + createImgSettings() { + if(ComfyApp.clipspace.imgs) { + const combo_items = []; + const imgs = ComfyApp.clipspace.imgs; + + for(let i=0; i < imgs.length; i++) { + combo_items.push($el("option", {value:i}, [`${i}`])); + } + + const combo1 = $el("select", + {id:"clipspace_img_selector", onchange:(event) => { + ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex; + ClipspaceDialog.invalidatePreview(); + } }, combo_items); + + const row1 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]), + $el("td", {}, [combo1]) + ]); + + + const combo2 = $el("select", + {id:"clipspace_img_paste_mode", onchange:(event) => { + ComfyApp.clipspace['img_paste_mode'] = event.target.value; + } }, + [ + $el("option", {value:'selected'}, 'selected'), + $el("option", {value:'all'}, 'all') + ]); + combo2.value = ComfyApp.clipspace['img_paste_mode']; + + const row2 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]), + $el("td", {}, [combo2]) + ]); + + const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'}, + [ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]); + + const row3 = + $el("tr", {}, [td]); + + return $el("table", {}, [row1, row2, row3]); + } + else { + return []; + } + } + + createImgPreview() { + if(ComfyApp.clipspace.imgs) { + return $el("img",{id:"clipspace_preview", ondragstart:() => false}); + } + else + return []; + } + + show() { + const img_preview = document.getElementById("clipspace_preview"); + ClipspaceDialog.invalidate(); + + this.element.style.display = "block"; + } +} + +app.registerExtension({ + name: "Comfy.Clipspace", + init(app) { + app.openClipspace = + function () { + if(!ClipspaceDialog.instance) { + ClipspaceDialog.instance = new ClipspaceDialog(app); + ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate; + } + + if(ComfyApp.clipspace) { + ClipspaceDialog.instance.show(); + } + else + app.ui.dialog.show("Clipspace is Empty!"); + }; + } +}); \ No newline at end of file diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js new file mode 100644 index 000000000..c55f841b6 --- /dev/null +++ b/web/extensions/core/maskeditor.js @@ -0,0 +1,589 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; +import { ClipspaceDialog } from "/extensions/core/clipspace.js"; + +// Helper function to convert a data URL to a Blob object +function dataURLToBlob(dataURL) { + const parts = dataURL.split(';base64,'); + const contentType = parts[0].split(':')[1]; + const byteString = atob(parts[1]); + const arrayBuffer = new ArrayBuffer(byteString.length); + const uint8Array = new Uint8Array(arrayBuffer); + for (let i = 0; i < byteString.length; i++) { + uint8Array[i] = byteString.charCodeAt(i); + } + return new Blob([arrayBuffer], { type: contentType }); +} + +function loadedImageToBlob(image) { + const canvas = document.createElement('canvas'); + + canvas.width = image.width; + canvas.height = image.height; + + const ctx = canvas.getContext('2d'); + + ctx.drawImage(image, 0, 0); + + const dataURL = canvas.toDataURL('image/png', 1); + const blob = dataURLToBlob(dataURL); + + return blob; +} + +async function uploadMask(filepath, formData) { + await fetch('/upload/mask', { + method: 'POST', + body: formData + }).then(response => {}).catch(error => { + console.error('Error:', error); + }); + + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = `view?filename=${filepath.filename}&type=${filepath.type}`; + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; + + ClipspaceDialog.invalidatePreview(); +} + +function prepareRGB(image, backupCanvas, backupCtx) { + // paste mask data into alpha channel + backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); + const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); +} + +class MaskEditorDialog extends ComfyDialog { + static instance = null; + constructor() { + super(); + this.element = $el("div.comfy-modal", { parent: document.body }, + [ $el("div.comfy-modal-content", + [...this.createButtons()]), + ]); + MaskEditorDialog.instance = this; + } + + createButtons() { + return []; + } + + clearMask(self) { + } + + createButton(name, callback) { + var button = document.createElement("button"); + button.innerText = name; + button.addEventListener("click", callback); + return button; + } + createLeftButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "left"; + button.style.marginRight = "4px"; + return button; + } + createRightButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "right"; + button.style.marginLeft = "4px"; + return button; + } + createLeftSlider(self, name, callback) { + const divElement = document.createElement('div'); + divElement.id = "maskeditor-slider"; + divElement.style.cssFloat = "left"; + divElement.style.fontFamily = "sans-serif"; + divElement.style.marginRight = "4px"; + divElement.style.color = "var(--input-text)"; + divElement.style.backgroundColor = "var(--comfy-input-bg)"; + divElement.style.borderRadius = "8px"; + divElement.style.borderColor = "var(--border-color)"; + divElement.style.borderStyle = "solid"; + divElement.style.fontSize = "15px"; + divElement.style.height = "21px"; + divElement.style.padding = "1px 6px"; + divElement.style.display = "flex"; + divElement.style.position = "relative"; + divElement.style.top = "2px"; + self.brush_slider_input = document.createElement('input'); + self.brush_slider_input.setAttribute('type', 'range'); + self.brush_slider_input.setAttribute('min', '1'); + self.brush_slider_input.setAttribute('max', '100'); + self.brush_slider_input.setAttribute('value', '10'); + const labelElement = document.createElement("label"); + labelElement.textContent = name; + + divElement.appendChild(labelElement); + divElement.appendChild(self.brush_slider_input); + + self.brush_slider_input.addEventListener("change", callback); + + return divElement; + } + + setlayout(imgCanvas, maskCanvas) { + const self = this; + + // If it is specified as relative, using it only as a hidden placeholder for padding is recommended + // to prevent anomalies where it exceeds a certain size and goes outside of the window. + var placeholder = document.createElement("div"); + placeholder.style.position = "relative"; + placeholder.style.height = "50px"; + + var bottom_panel = document.createElement("div"); + bottom_panel.style.position = "absolute"; + bottom_panel.style.bottom = "0px"; + bottom_panel.style.left = "20px"; + bottom_panel.style.right = "20px"; + bottom_panel.style.height = "50px"; + + var brush = document.createElement("div"); + brush.id = "brush"; + brush.style.backgroundColor = "transparent"; + brush.style.outline = "1px dashed black"; + brush.style.boxShadow = "0 0 0 1px white"; + brush.style.borderRadius = "50%"; + brush.style.MozBorderRadius = "50%"; + brush.style.WebkitBorderRadius = "50%"; + brush.style.position = "absolute"; + brush.style.zIndex = 100; + brush.style.pointerEvents = "none"; + this.brush = brush; + this.element.appendChild(imgCanvas); + this.element.appendChild(maskCanvas); + this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button + this.element.appendChild(bottom_panel); + document.body.appendChild(brush); + + var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { + self.brush_size = event.target.value; + self.updateBrushPreview(self, null, null); + }); + var clearButton = this.createLeftButton("Clear", + () => { + self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); + self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height); + }); + var cancelButton = this.createRightButton("Cancel", () => { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); + self.close(); + }); + var saveButton = this.createRightButton("Save", () => { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); + self.save(); + }); + + this.element.appendChild(imgCanvas); + this.element.appendChild(maskCanvas); + this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button + this.element.appendChild(bottom_panel); + + bottom_panel.appendChild(clearButton); + bottom_panel.appendChild(saveButton); + bottom_panel.appendChild(cancelButton); + bottom_panel.appendChild(brush_size_slider); + + this.element.style.display = "block"; + imgCanvas.style.position = "relative"; + imgCanvas.style.top = "200"; + imgCanvas.style.left = "0"; + + maskCanvas.style.position = "absolute"; + } + + show() { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = document.createElement('canvas'); + + imgCanvas.id = "imageCanvas"; + maskCanvas.id = "maskCanvas"; + backupCanvas.id = "backupCanvas"; + + this.setlayout(imgCanvas, maskCanvas); + + // prepare content + this.maskCanvas = maskCanvas; + this.backupCanvas = backupCanvas; + this.maskCtx = maskCanvas.getContext('2d'); + this.backupCtx = backupCanvas.getContext('2d'); + + this.setImages(imgCanvas, backupCanvas); + this.setEventHandler(maskCanvas); + } + + setImages(imgCanvas, backupCanvas) { + const imgCtx = imgCanvas.getContext('2d'); + const backupCtx = backupCanvas.getContext('2d'); + const maskCtx = this.maskCtx; + const maskCanvas = this.maskCanvas; + + // image load + const orig_image = new Image(); + window.addEventListener("resize", () => { + // repositioning + imgCanvas.width = window.innerWidth - 250; + imgCanvas.height = window.innerHeight - 200; + + // redraw image + let drawWidth = orig_image.width; + let drawHeight = orig_image.height; + if (orig_image.width > imgCanvas.width) { + drawWidth = imgCanvas.width; + drawHeight = (drawWidth / orig_image.width) * orig_image.height; + } + + if (drawHeight > imgCanvas.height) { + drawHeight = imgCanvas.height; + drawWidth = (drawHeight / orig_image.height) * orig_image.width; + } + + imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); + + // update mask + backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); + maskCanvas.width = drawWidth; + maskCanvas.height = drawHeight; + maskCanvas.style.top = imgCanvas.offsetTop + "px"; + maskCanvas.style.left = imgCanvas.offsetLeft + "px"; + maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); + }); + + const filepath = ComfyApp.clipspace.images; + + const touched_image = new Image(); + + touched_image.onload = function() { + backupCanvas.width = touched_image.width; + backupCanvas.height = touched_image.height; + + prepareRGB(touched_image, backupCanvas, backupCtx); + }; + + const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) + alpha_url.searchParams.delete('channel'); + alpha_url.searchParams.set('channel', 'a'); + touched_image.src = alpha_url; + + // original image load + orig_image.onload = function() { + window.dispatchEvent(new Event('resize')); + }; + + const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); + rgb_url.searchParams.delete('channel'); + rgb_url.searchParams.set('channel', 'rgb'); + orig_image.src = rgb_url; + this.image = orig_image; + }g + + + setEventHandler(maskCanvas) { + maskCanvas.addEventListener("contextmenu", (event) => { + event.preventDefault(); + }); + + const self = this; + maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); + maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); + document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); + maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); + maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); + document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); + } + + brush_size = 10; + drawing_mode = false; + lastx = -1; + lasty = -1; + lasttime = 0; + + static handleKeyDown(event) { + const self = MaskEditorDialog.instance; + if (event.key === ']') { + self.brush_size = Math.min(self.brush_size+2, 100); + } else if (event.key === '[') { + self.brush_size = Math.max(self.brush_size-2, 1); + } + + self.updateBrushPreview(self); + } + + static handlePointerUp(event) { + event.preventDefault(); + MaskEditorDialog.instance.drawing_mode = false; + } + + updateBrushPreview(self) { + const brush = self.brush; + + var centerX = self.cursorX; + var centerY = self.cursorY; + + brush.style.width = self.brush_size * 2 + "px"; + brush.style.height = self.brush_size * 2 + "px"; + brush.style.left = (centerX - self.brush_size) + "px"; + brush.style.top = (centerY - self.brush_size) + "px"; + } + + handleWheelEvent(self, event) { + if(event.deltaY < 0) + self.brush_size = Math.min(self.brush_size+2, 100); + else + self.brush_size = Math.max(self.brush_size-2, 1); + + self.brush_slider_input.value = self.brush_size; + + self.updateBrushPreview(self); + } + + draw_move(self, event) { + event.preventDefault(); + + this.cursorX = event.pageX; + this.cursorY = event.pageY; + + self.updateBrushPreview(self); + + if (event instanceof TouchEvent || event.buttons == 1) { + var diff = performance.now() - self.lasttime; + + const maskRect = self.maskCanvas.getBoundingClientRect(); + + var x = event.offsetX; + var y = event.offsetY + + if(event.offsetX == null) { + x = event.targetTouches[0].clientX - maskRect.left; + } + + if(event.offsetY == null) { + y = event.targetTouches[0].clientY - maskRect.top; + } + + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + else if(event instanceof TouchEvent && diff < 20){ + // The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents. + brush_size *= this.last_pressure; + } + else { + brush_size = this.brush_size; + } + + if(diff > 20 && !this.drawing_mode) + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + }); + else + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + + var dx = x - self.lastx; + var dy = y - self.lasty; + + var distance = Math.sqrt(dx * dx + dy * dy); + var directionX = dx / distance; + var directionY = dy / distance; + + for (var i = 0; i < distance; i+=5) { + var px = self.lastx + (directionX * i); + var py = self.lasty + (directionY * i); + self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + } + self.lastx = x; + self.lasty = y; + }); + + self.lasttime = performance.now(); + } + else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { + const maskRect = self.maskCanvas.getBoundingClientRect(); + const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; + const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + else if(event instanceof TouchEvent && diff < 20){ + brush_size *= this.last_pressure; + } + else { + brush_size = this.brush_size; + } + + if(diff > 20 && !drawing_mode) // cannot tracking drawing_mode for touch event + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.globalCompositeOperation = "destination-out"; + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + }); + else + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.globalCompositeOperation = "destination-out"; + + var dx = x - self.lastx; + var dy = y - self.lasty; + + var distance = Math.sqrt(dx * dx + dy * dy); + var directionX = dx / distance; + var directionY = dy / distance; + + for (var i = 0; i < distance; i+=5) { + var px = self.lastx + (directionX * i); + var py = self.lasty + (directionY * i); + self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + } + self.lastx = x; + self.lasty = y; + }); + + self.lasttime = performance.now(); + } + } + + handlePointerDown(self, event) { + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + + if ([0, 2, 5].includes(event.button)) { + self.drawing_mode = true; + + event.preventDefault(); + const maskRect = self.maskCanvas.getBoundingClientRect(); + const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; + const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + + self.maskCtx.beginPath(); + if (event.button == 0) { + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + } else { + self.maskCtx.globalCompositeOperation = "destination-out"; + } + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + self.lasttime = performance.now(); + } + } + + save() { + const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); + + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + backupCtx.drawImage(this.maskCanvas, + 0, 0, this.maskCanvas.width, this.maskCanvas.height, + 0, 0, this.backupCanvas.width, this.backupCanvas.height); + + // paste mask data into alpha channel + const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); + + const formData = new FormData(); + const filename = "clipspace-mask-" + performance.now() + ".png"; + + const item = + { + "filename": filename, + "subfolder": "", + "type": "clipspace", + }; + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[0] = item; + + if(ComfyApp.clipspace.widgets) { + const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); + + if(index >= 0) + ComfyApp.clipspace.widgets[index].value = item; + } + + const dataURL = this.backupCanvas.toDataURL(); + const blob = dataURLToBlob(dataURL); + + const original_blob = loadedImageToBlob(this.image); + + formData.append('image', blob, filename); + formData.append('original_image', original_blob); + formData.append('type', "clipspace"); + + uploadMask(item, formData); + this.close(); + } +} + +app.registerExtension({ + name: "Comfy.MaskEditor", + init(app) { + const callback = + function () { + let dlg = new MaskEditorDialog(app); + dlg.show(); + }; + + const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 + ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback); + } +}); \ No newline at end of file diff --git a/web/scripts/app.js b/web/scripts/app.js index 245605484..f4f7272db 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -25,6 +25,7 @@ export class ComfyApp { * @type {serialized node object} */ static clipspace = null; + static clipspace_invalidate_handler = null; constructor() { this.ui = new ComfyUI(this); @@ -143,22 +144,34 @@ export class ComfyApp { callback: (obj) => { var widgets = null; if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); + widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); } - let img = new Image(); var imgs = undefined; + var orig_imgs = undefined; if(this.imgs != undefined) { - img.src = this.imgs[0].src; - imgs = [img]; + imgs = []; + orig_imgs = []; + + for (let i = 0; i < this.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = this.imgs[i].src; + orig_imgs[i] = imgs[i]; + } } ComfyApp.clipspace = { 'widgets': widgets, 'imgs': imgs, - 'original_imgs': imgs, - 'images': this.images + 'original_imgs': orig_imgs, + 'images': this.images, + 'selectedIndex': 0, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action }; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } } }); @@ -167,48 +180,82 @@ export class ComfyApp { { content: "Paste (Clipspace)", callback: () => { - if(ComfyApp.clipspace != null) { - if(ComfyApp.clipspace.widgets != null && this.widgets != null) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop) { - prop.callback(value); - } - }); - } - + if(ComfyApp.clipspace) { // image paste - if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) { + if(ComfyApp.clipspace.imgs && this.imgs) { var filename = ""; if(this.images && ComfyApp.clipspace.images) { - this.images = ComfyApp.clipspace.images; + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + + } + else + app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images; } - if(ComfyApp.clipspace.images != undefined) { - const clip_image = ComfyApp.clipspace.images[0]; + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + this.imgs = [img]; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); if(index_in_clip >= 0) { - filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`; + const item = ComfyApp.clipspace.widgets[index_in_clip].value; + if(item.type) + filename = `${item.filename} [${item.type}]`; + else + filename = item.filename; } } - const index = this.widgets.findIndex(obj => obj.name === 'image'); - if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) { - this.imgs = ComfyApp.clipspace.imgs; + // for Load Image node. + if(this.widgets) { + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0 && filename != "") { + const postfix = ' [clipspace]'; + if(filename.endsWith(postfix) && this.widgets[index].options.base_dir == 'input') { + filename = "clipspace/" + filename.slice(0, filename.indexOf(postfix)); + } - this.widgets[index].value = filename; - if(this.widgets_values != undefined) { - this.widgets_values[index] = filename; + this.widgets[index].value = filename; + if(this.widgets_values != undefined) { + this.widgets_values[index] = filename; + } } } } - this.trigger('changed'); + + // ensure render after update widget_value + if(ComfyApp.clipspace.widgets && this.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.callback(value); + } + }); + } } + + app.graph.setDirtyCanvas(true); } } ); @@ -1275,12 +1322,17 @@ export class ComfyApp { for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] - if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { - widget.options.values = def["input"]["required"][widget.name][0]; + if(def["input"]["required"][widget.name][0] == "FILE_COMBO") { + console.log(widget.options.values = def["input"]["required"][widget.name][1].files); + widget.options.values = def["input"]["required"][widget.name][1].files; + } + else + widget.options.values = def["input"]["required"][widget.name][0]; if(!widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; + widget.callback(widget.value); } } } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 5accc9d86..77517aec1 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -581,6 +581,7 @@ export class ComfyUI { }), $el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }), $el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), + $el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }), $el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => { if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index cd471bc93..4a72246db 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -256,6 +256,20 @@ export const ComfyWidgets = { } return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; }, + FILE_COMBO(node, inputName, inputData) { + const base_dir = inputData[1].base_dir; + let defaultValue = inputData[1].files[0]; + + const files = [] + for(let i in inputData[1].files) { + files[i] = inputData[1].files[i]; + const postfix = ' [clipspace]'; + if(base_dir == 'input' && files[i].endsWith(postfix)) + files[i] = "clipspace/" + files[i].slice(0, files[i].indexOf(postfix)); + } + + return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { base_dir:base_dir, values: files }) }; + }, IMAGEUPLOAD(node, inputName, inputData, app) { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; From 850daf0416367ba39d10195540f5b735952f0ee7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 14:13:06 -0400 Subject: [PATCH 141/190] Masked editor changes. Add a way to upload to subfolders. Clean up code. Fix some issues. --- folder_paths.py | 9 ---- nodes.py | 8 ++-- server.py | 74 ++++++++++++------------------- web/extensions/core/maskeditor.js | 9 ++-- web/scripts/app.js | 66 ++++++++------------------- web/scripts/widgets.js | 52 +++++++++++++++------- 6 files changed, 93 insertions(+), 125 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 0acd22674..e5b89492c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -57,10 +57,6 @@ def get_input_directory(): global input_directory return input_directory -def get_clipspace_directory(): - global input_directory - return input_directory+"/clipspace" - #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): @@ -70,8 +66,6 @@ def get_directory_by_type(type_name): return get_temp_directory() if type_name == "input": return get_input_directory() - if type_name == "clipspace": - return get_clipspace_directory() return None @@ -87,9 +81,6 @@ def annotated_filepath(name): elif name.endswith("[temp]"): base_dir = get_temp_directory() name = name[:-7] - elif name.endswith("[clipspace]"): - base_dir = get_clipspace_directory() - name = name[:-12] else: return name, None diff --git a/nodes.py b/nodes.py index 1d9a5c872..699e60ae8 100644 --- a/nodes.py +++ b/nodes.py @@ -973,9 +973,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, )}, + {"image": (sorted(files), )}, } CATEGORY = "image" @@ -1015,9 +1015,9 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, ), + {"image": (sorted(files), ), "channel": (s._color_channels, ), } } diff --git a/server.py b/server.py index 48644d83a..3d02b2f7a 100644 --- a/server.py +++ b/server.py @@ -118,8 +118,6 @@ class PromptServer(): type_dir = folder_paths.get_input_directory() elif dir_type == "input": type_dir = folder_paths.get_input_directory() - elif dir_type == "clipspace": - type_dir = folder_paths.get_clipspace_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": @@ -127,73 +125,63 @@ class PromptServer(): return type_dir - @routes.post("/upload/image") - async def upload_image(request): - post = await request.post() + def image_upload(post, image_save_function=None): image = post.get("image") - upload_dir = get_dir_by_type(post.get("type")) - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) + image_upload_type = post.get("type") + upload_dir = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename if not filename: return web.Response(status=400) + subfolder = post.get("subfolder", "") + full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder)) + + if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir: + return web.Response(status=400) + + if not os.path.exists(full_output_folder): + os.makedirs(full_output_folder) + split = os.path.splitext(filename) + filepath = os.path.join(full_output_folder, filename) + i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): + while os.path.exists(filepath): filename = f"{split[0]} ({i}){split[1]}" i += 1 - filepath = os.path.join(upload_dir, filename) + if image_save_function is not None: + image_save_function(image, post, filepath) + else: + with open(filepath, "wb") as f: + f.write(image.file.read()) - with open(filepath, "wb") as f: - f.write(image.file.read()) - - return web.json_response({"name" : filename}) + return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) else: return web.Response(status=400) + @routes.post("/upload/image") + async def upload_image(request): + post = await request.post() + return image_upload(post) + @routes.post("/upload/mask") async def upload_mask(request): post = await request.post() - image = post.get("image") - original_image = post.get("original_image") - upload_dir = get_dir_by_type(post.get("type")) - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) - - if image and image.file: - filename = image.filename - if not filename: - return web.Response(status=400) - - split = os.path.splitext(filename) - i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): - filename = f"{split[0]} ({i}){split[1]}" - i += 1 - - filepath = os.path.join(upload_dir, filename) - - original_pil = Image.open(original_image.file).convert('RGBA') + def image_save_function(image, post, filepath): + original_pil = Image.open(post.get("original_image").file).convert('RGBA') mask_pil = Image.open(image.file).convert('RGBA') # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath) - return web.json_response({"name": filename}) - else: - return web.Response(status=400) - + return image_upload(post, image_save_function) @routes.get("/view") async def view_image(request): @@ -201,10 +189,6 @@ class PromptServer(): filename = request.rel_url.query["filename"] filename,output_dir = folder_paths.annotated_filepath(filename) - if request.rel_url.query.get("type", "input") and filename.startswith("clipspace/"): - output_dir = folder_paths.get_clipspace_directory() - filename = filename[10:] - # validation for security: prevent accessing arbitrary path if filename[0] == '/' or '..' in filename: return web.Response(status=400) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index c55f841b6..0ffa50c69 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) { }); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); - ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = `view?filename=${filepath.filename}&type=${filepath.type}`; + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString(); if(ComfyApp.clipspace.images) ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; @@ -546,8 +546,8 @@ class MaskEditorDialog extends ComfyDialog { const item = { "filename": filename, - "subfolder": "", - "type": "clipspace", + "subfolder": "clipspace", + "type": "input", }; if(ComfyApp.clipspace.images) @@ -567,7 +567,8 @@ class MaskEditorDialog extends ComfyDialog { formData.append('image', blob, filename); formData.append('original_image', original_blob); - formData.append('type', "clipspace"); + formData.append('type', "input"); + formData.append('subfolder', "clipspace"); uploadMask(item, formData); this.close(); diff --git a/web/scripts/app.js b/web/scripts/app.js index f4f7272db..c6c29e45b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -183,7 +183,6 @@ export class ComfyApp { if(ComfyApp.clipspace) { // image paste if(ComfyApp.clipspace.imgs && this.imgs) { - var filename = ""; if(this.images && ComfyApp.clipspace.images) { if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; @@ -209,49 +208,25 @@ export class ComfyApp { } } } - - if(ComfyApp.clipspace.images) { - const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; - if(clip_image.subfolder != '') - filename = `${clip_image.subfolder}/`; - filename += `${clip_image.filename} [${clip_image.type}]`; - } - else if(ComfyApp.clipspace.widgets) { - const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); - if(index_in_clip >= 0) { - const item = ComfyApp.clipspace.widgets[index_in_clip].value; - if(item.type) - filename = `${item.filename} [${item.type}]`; - else - filename = item.filename; - } - } - - // for Load Image node. - if(this.widgets) { - const index = this.widgets.findIndex(obj => obj.name === 'image'); - if(index >= 0 && filename != "") { - const postfix = ' [clipspace]'; - if(filename.endsWith(postfix) && this.widgets[index].options.base_dir == 'input') { - filename = "clipspace/" + filename.slice(0, filename.indexOf(postfix)); - } - - this.widgets[index].value = filename; - if(this.widgets_values != undefined) { - this.widgets_values[index] = filename; - } - } - } } - // ensure render after update widget_value - if(ComfyApp.clipspace.widgets && this.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.callback(value); - } - }); + if(this.widgets) { + if(ComfyApp.clipspace.images) { + const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0) { + this.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } } } @@ -1323,12 +1298,7 @@ export class ComfyApp { for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { - if(def["input"]["required"][widget.name][0] == "FILE_COMBO") { - console.log(widget.options.values = def["input"]["required"][widget.name][1].files); - widget.options.values = def["input"]["required"][widget.name][1].files; - } - else - widget.options.values = def["input"]["required"][widget.name][0]; + widget.options.values = def["input"]["required"][widget.name][0]; if(!widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 4a72246db..65edc0392 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -256,20 +256,6 @@ export const ComfyWidgets = { } return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; }, - FILE_COMBO(node, inputName, inputData) { - const base_dir = inputData[1].base_dir; - let defaultValue = inputData[1].files[0]; - - const files = [] - for(let i in inputData[1].files) { - files[i] = inputData[1].files[i]; - const postfix = ' [clipspace]'; - if(base_dir == 'input' && files[i].endsWith(postfix)) - files[i] = "clipspace/" + files[i].slice(0, files[i].indexOf(postfix)); - } - - return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { base_dir:base_dir, values: files }) }; - }, IMAGEUPLOAD(node, inputName, inputData, app) { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; @@ -280,10 +266,46 @@ export const ComfyWidgets = { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; - img.src = `/view?filename=${name}&type=input`; + let folder_separator = name.lastIndexOf("/"); + let subfolder = ""; + if (folder_separator > -1) { + subfolder = name.substring(0, folder_separator); + name = name.substring(folder_separator + 1); + } + img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; node.setSizeForImage?.(); } + var default_value = imageWidget.value; + Object.defineProperty(imageWidget, "value", { + set : function(value) { + this._real_value = value; + }, + + get : function() { + let value = ""; + if (this._real_value) { + value = this._real_value; + } else { + return default_value; + } + + if (value.filename) { + let real_value = value; + value = ""; + if (real_value.subfolder) { + value = real_value.subfolder + "/"; + } + + value += real_value.filename; + + if(real_value.type && real_value.type !== "input") + value += ` [${real_value.type}]`; + } + return value; + } + }); + // Add our own callback to the combo widget to render an image when it changes const cb = node.callback; imageWidget.callback = function () { From a7ebd5aa1278a63f2f14852dce59b43834f6b9d3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 15:52:33 -0400 Subject: [PATCH 142/190] Fix masked editor issue with firefox on windows. --- web/extensions/core/maskeditor.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 0ffa50c69..552059e86 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -368,7 +368,7 @@ class MaskEditorDialog extends ComfyDialog { self.updateBrushPreview(self); - if (event instanceof TouchEvent || event.buttons == 1) { + if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) { var diff = performance.now() - self.lasttime; const maskRect = self.maskCanvas.getBoundingClientRect(); @@ -389,7 +389,7 @@ class MaskEditorDialog extends ComfyDialog { brush_size *= event.pressure; this.last_pressure = event.pressure; } - else if(event instanceof TouchEvent && diff < 20){ + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ // The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents. brush_size *= this.last_pressure; } @@ -442,7 +442,7 @@ class MaskEditorDialog extends ComfyDialog { brush_size *= event.pressure; this.last_pressure = event.pressure; } - else if(event instanceof TouchEvent && diff < 20){ + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ brush_size *= this.last_pressure; } else { From a8705dbfe20ba86eaac5a669c61453775c796441 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 17:05:28 -0400 Subject: [PATCH 143/190] Speed up the mask save and fix refresh replacing copied image. --- server.py | 2 +- web/scripts/app.js | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 3d02b2f7a..c1226f304 100644 --- a/server.py +++ b/server.py @@ -179,7 +179,7 @@ class PromptServer(): # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath) + original_pil.save(filepath, compress_level=4) return image_upload(post, image_save_function) diff --git a/web/scripts/app.js b/web/scripts/app.js index c6c29e45b..2da1b5581 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1300,7 +1300,7 @@ export class ComfyApp { if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { widget.options.values = def["input"]["required"][widget.name][0]; - if(!widget.options.values.includes(widget.value)) { + if(widget.name != 'image' && !widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; widget.callback(widget.value); } From c6e34963e412e1960f73ad357d10c2b7bd1464e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 18:15:19 -0400 Subject: [PATCH 144/190] Make t2i adapter work with any latent resolution. --- comfy/t2i_adapter/adapter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index 0221fff83..87e3d859e 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -56,7 +56,12 @@ class Downsample(nn.Module): def forward(self, x): assert x.shape[1] == self.channels - return self.op(x) + if not self.use_conv: + padding = [x.shape[2] % 2, x.shape[3] % 2] + self.op.padding = padding + + x = self.op(x) + return x class ResnetBlock(nn.Module): From d43e45ce624b82dadbe98646329d2b0fbc17edcf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 10:29:58 -0400 Subject: [PATCH 145/190] Remove print. --- nodes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nodes.py b/nodes.py index 699e60ae8..760db24e1 100644 --- a/nodes.py +++ b/nodes.py @@ -443,7 +443,6 @@ class ControlNetApply: def apply_controlnet(self, conditioning, control_net, image, strength): c = [] control_hint = image.movedim(-1,1) - print(control_hint.shape) for t in conditioning: n = [t[0], t[1].copy()] c_net = control_net.copy().set_cond_hint(control_hint, strength) From 314e526c5ce428a3717207c5c36a42a5c895b6a5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 12:18:18 -0400 Subject: [PATCH 146/190] Not needed anymore because sampling works with any latent size. --- comfy/samplers.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index dcf93cca2..6417f2ed4 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -362,19 +362,8 @@ def resolve_cond_masks(conditions, h, w, device): else: box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) - # Make sure the height and width are divisible by 8 - if X % 8 != 0: - newx = X // 8 * 8 - W = W + (X - newx) - X = newx - if Y % 8 != 0: - newy = Y // 8 * 8 - H = H + (Y - newy) - Y = newy - if H % 8 != 0: - H = H + (8 - (H % 8)) - if W % 8 != 0: - W = W + (8 - (W % 8)) + H = max(8, H) + W = max(8, W) area = (int(H), int(W), int(Y), int(X)) modified['area'] = area From 02ca1c67f87e46e926aba325e73b2845d5244874 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 23:51:52 -0400 Subject: [PATCH 147/190] Don't print traceback when processing interrupted. --- execution.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index c19c10bc6..edf884611 100644 --- a/execution.py +++ b/execution.py @@ -194,7 +194,10 @@ class PromptExecutor: if valid: recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: - print(traceback.format_exc()) + if isinstance(e, comfy.model_management.InterruptProcessingException): + print("Processing interrupted") + else: + print(traceback.format_exc()) to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): From d6dee8af1df5e7dc80463b9e45bdce76767e4119 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 00:29:31 -0400 Subject: [PATCH 148/190] Only validate each input once. --- execution.py | 40 ++++++++++++++++++---------------------- main.py | 2 +- server.py | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/execution.py b/execution.py index edf884611..3953fde3a 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}): + def execute(self, prompt, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -172,27 +172,15 @@ class PromptExecutor: executed = set() try: to_execute = [] - for x in prompt: - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - to_execute += [(0, x)] + for x in list(execute_outputs): + to_execute += [(0, x)] while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - if class_.OUTPUT_NODE == True: - valid = False - try: - m = validate_inputs(prompt, x) - valid = m[0] - except: - valid = False - if valid: - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -219,8 +207,11 @@ class PromptExecutor: comfy.model_management.soft_empty_cache() -def validate_inputs(prompt, item): +def validate_inputs(prompt, item, validated): unique_id = item + if unique_id in validated: + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -241,8 +232,9 @@ def validate_inputs(prompt, item): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id) + r = validate_inputs(prompt, o_id, validated) if r[0] == False: + validated[o_id] = r return r else: if type_input == "INT": @@ -270,7 +262,10 @@ def validate_inputs(prompt, item): if isinstance(type_input, list): if val not in type_input: return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) - return (True, "") + + ret = (True, "") + validated[unique_id] = ret + return ret def validate_prompt(prompt): outputs = set() @@ -284,11 +279,12 @@ def validate_prompt(prompt): good_outputs = set() errors = [] + validated = {} for o in outputs: valid = False reason = "" try: - m = validate_inputs(prompt, o) + m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] except Exception as e: @@ -297,7 +293,7 @@ def validate_prompt(prompt): reason = "Parsing error" if valid == True: - good_outputs.add(x) + good_outputs.add(o) else: print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") @@ -307,7 +303,7 @@ def validate_prompt(prompt): errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) - return (True, "") + return (True, "", list(good_outputs)) class PromptQueue: diff --git a/main.py b/main.py index eb97a2fb8..d385df70a 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() - e.execute(item[-2], item[-1]) + e.execute(item[-3], item[-2], item[-1]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): diff --git a/server.py b/server.py index c1226f304..b6ac7d483 100644 --- a/server.py +++ b/server.py @@ -312,7 +312,7 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data)) + self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) else: resp_code = 400 out_string = valid[1] From 8e3d1cbf3b8488b319675f952e1a868aa78f1161 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 01:45:27 -0400 Subject: [PATCH 149/190] Fix bug when uploading image with the same name. --- server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server.py b/server.py index b6ac7d483..911f6a614 100644 --- a/server.py +++ b/server.py @@ -151,6 +151,7 @@ class PromptServer(): i = 1 while os.path.exists(filepath): filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) i += 1 if image_save_function is not None: From 51583164ef08d2173eb93eefa36bc50429cfe7c6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 10:03:30 -0400 Subject: [PATCH 150/190] Make MaskToImage support masks with a batch size. --- comfy_extras/nodes_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 131cd6a9f..9916f3b21 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -72,7 +72,7 @@ class MaskToImage: FUNCTION = "mask_to_image" def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) + result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) return (result,) class ImageToMask: From f7c0f75d1fb1c6e3657f69247eace796882c62da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 13:58:19 -0400 Subject: [PATCH 151/190] Auto batching improvements. Try batching when cond sizes don't match with smart padding. --- comfy/samplers.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6417f2ed4..aa44fa82d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,10 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +import math + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers #Returns predicted noise @@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if c1.keys() != c2.keys(): return False if 'c_crossattn' in c1: - if c1['c_crossattn'].shape != c2['c_crossattn'].shape: - return False + s1 = c1['c_crossattn'].shape + s2 = c2['c_crossattn'].shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False if 'c_concat' in c1: if c1['c_concat'].shape != c2['c_concat'].shape: return False @@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c_crossattn = [] c_concat = [] c_adm = [] + crossattn_max_len = 0 for x in c_list: if 'c_crossattn' in x: - c_crossattn.append(x['c_crossattn']) + c = x['c_crossattn'] + if crossattn_max_len == 0: + crossattn_max_len = c.shape[1] + else: + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + c_crossattn.append(c) if 'c_concat' in x: c_concat.append(x['c_concat']) if 'c_adm' in x: c_adm.append(x['c_adm']) out = {} - if len(c_crossattn) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn)] + c_crossattn_out = [] + for c in c_crossattn: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + c_crossattn_out.append(c) + + if len(c_crossattn_out) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn_out)] if len(c_concat) > 0: out['c_concat'] = [torch.cat(c_concat)] if len(c_adm) > 0: From 602095f614276dd52fad718c223e0be17d12b11e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 15:49:49 -0400 Subject: [PATCH 152/190] Send execution_error message on websocket on execution exception. --- execution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 3953fde3a..7ee038975 100644 --- a/execution.py +++ b/execution.py @@ -185,7 +185,11 @@ class PromptExecutor: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") else: - print(traceback.format_exc()) + message = str(traceback.format_exc()) + print(message) + if self.server.client_id is not None: + self.server.send_sync("execution_error", { "message": message }, self.server.client_id) + to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): From 3a7c3acc72435f312a8f050d8ad3a1c902d9cff4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 15:59:24 -0400 Subject: [PATCH 153/190] Send websocket message with list of cached nodes right before execution. --- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 7ee038975..7d18d3b65 100644 --- a/execution.py +++ b/execution.py @@ -169,6 +169,8 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + if self.server.client_id is not None: + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) executed = set() try: to_execute = [] From 974958ff81d9af92b01490bcc99dfc93f8bb5d30 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 16:41:43 -0400 Subject: [PATCH 154/190] Make the prompt_id a uuid and return it when queueing the prompt. --- server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 911f6a614..6965ff3c1 100644 --- a/server.py +++ b/server.py @@ -81,7 +81,7 @@ class PromptServer(): # Reusing existing session, remove old self.sockets.pop(sid, None) else: - sid = uuid.uuid4().hex + sid = uuid.uuid4().hex self.sockets[sid] = ws @@ -313,7 +313,9 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) + prompt_id = str(uuid.uuid4()) + self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) + return web.json_response({"prompt_id": prompt_id}) else: resp_code = 400 out_string = valid[1] From dfc74c19d944b4a4503e22297592fa3a537d3092 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 May 2023 01:22:40 -0400 Subject: [PATCH 155/190] Add the prompt_id to some websocket messages. --- execution.py | 8 ++++---- main.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index 7d18d3b65..0ac4d462c 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}, execute_outputs=[]): + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -170,7 +170,7 @@ class PromptExecutor: current_outputs = set(self.outputs.keys()) if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() try: to_execute = [] @@ -190,7 +190,7 @@ class PromptExecutor: message = str(traceback.format_exc()) print(message) if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message }, self.server.client_id) + self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) to_delete = [] for o in self.outputs: @@ -207,7 +207,7 @@ class PromptExecutor: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None }, self.server.client_id) + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) gc.collect() comfy.model_management.soft_empty_cache() diff --git a/main.py b/main.py index d385df70a..00cbf3c4a 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() - e.execute(item[-3], item[-2], item[-1]) + e.execute(item[2], item[1], item[3], item[4]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): From 8ea165dd1ef877f58f3710f31ce43f27e0f739ab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 May 2023 14:15:13 -0400 Subject: [PATCH 156/190] Add a way to overwrite images when uploading. --- server.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server.py b/server.py index 6965ff3c1..a2bb26ad9 100644 --- a/server.py +++ b/server.py @@ -127,6 +127,7 @@ class PromptServer(): def image_upload(post, image_save_function=None): image = post.get("image") + overwrite = post.get("overwrite") image_upload_type = post.get("type") upload_dir = get_dir_by_type(image_upload_type) @@ -148,11 +149,14 @@ class PromptServer(): split = os.path.splitext(filename) filepath = os.path.join(full_output_folder, filename) - i = 1 - while os.path.exists(filepath): - filename = f"{split[0]} ({i}){split[1]}" - filepath = os.path.join(full_output_folder, filename) - i += 1 + if overwrite is not None and (overwrite == "true" or overwrite == "1"): + pass + else: + i = 1 + while os.path.exists(filepath): + filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) + i += 1 if image_save_function is not None: image_save_function(image, post, filepath) From 8a4ff5e34cc53252a9ff23e796904100d75bea55 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 12 May 2023 20:58:29 +0100 Subject: [PATCH 157/190] allow static files to be symlinks --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index a2bb26ad9..ef858a98a 100644 --- a/server.py +++ b/server.py @@ -362,7 +362,7 @@ class PromptServer(): def add_routes(self): self.app.add_routes(self.routes) self.app.add_routes([ - web.static('/', self.web_root), + web.static('/', self.web_root, follow_symlinks=True), ]) def get_queue_info(self): From d9e088ddfd97663abbb933c77f79d2a6c6127851 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 12 May 2023 23:49:09 +0200 Subject: [PATCH 158/190] minor changes for tiled sampler --- comfy/ldm/modules/tomesd.py | 2 +- comfy/sd.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py index 6a13b80c9..bb971e88f 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy/ldm/modules/tomesd.py @@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, """ B, N, _ = metric.shape - if r <= 0: + if r <= 0 or w == 1 or h == 1: return do_nothing, do_nothing gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather diff --git a/comfy/sd.py b/comfy/sd.py index 3543bdb77..0200f7742 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -581,10 +581,7 @@ class VAE: samples = samples.cpu() return samples -def resize_image_to(tensor, target_latent_tensor, batched_number): - tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center") - target_batch_size = target_latent_tensor.shape[0] - +def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] print(current_batch_size, target_batch_size) if current_batch_size == 1: @@ -623,7 +620,9 @@ class ControlNet: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_model.dtype == torch.float16: precision_scope = torch.autocast @@ -794,10 +793,14 @@ class T2IAdapter: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint + self.control_input = None self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: self.t2i_model.to(self.device) self.control_input = self.t2i_model(self.cond_hint) self.t2i_model.cpu() From 19c014f4292863444a3d677d504ad58623395a58 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 12 May 2023 23:57:40 +0200 Subject: [PATCH 159/190] comment out annoying print statement --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 0200f7742..c6be900ad 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -583,7 +583,7 @@ class VAE: def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - print(current_batch_size, target_batch_size) + #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor From c5c0ea666f8456b5a788092bad88528bbf34f559 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 12 May 2023 20:34:48 -0400 Subject: [PATCH 160/190] noise_mask in latent should be in a single format. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 760db24e1..c2201dafc 100644 --- a/nodes.py +++ b/nodes.py @@ -795,7 +795,7 @@ class SetLatentNoiseMask: def set_mask(self, samples, mask): s = samples.copy() - s["noise_mask"] = mask + s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): From 997dd1b1312a00cbedeafaf916e49f294a73a431 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 02:07:49 -0400 Subject: [PATCH 161/190] Fix queue delete. --- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index a2bb26ad9..8435d091b 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: - delete_func = lambda a: a[1] == int(id_to_delete) + delete_func = lambda a: a[1] == id_to_delete self.prompt_queue.delete_queue_item(delete_func) - + return web.Response(status=200) @routes.post("/interrupt") From 1201d2eae5820bb8124beb22b712d743415fd47d Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 13 May 2023 17:15:45 +0200 Subject: [PATCH 162/190] Make nodes map over input lists (#579) * allow nodes to map over lists * make work with IS_CHANGED and VALIDATE_INPUTS * give list outputs distinct socket shape * add rebatch node * add batch index logic * add repeat latent batch * deal with noise mask edge cases in latentfrombatch --- comfy/sample.py | 17 ++++-- comfy_extras/nodes_rebatch.py | 108 ++++++++++++++++++++++++++++++++++ execution.py | 90 +++++++++++++++++++++++----- nodes.py | 57 +++++++++++++++--- server.py | 1 + web/scripts/app.js | 3 +- 6 files changed, 250 insertions(+), 26 deletions(-) create mode 100644 comfy_extras/nodes_rebatch.py diff --git a/comfy/sample.py b/comfy/sample.py index bd38585ac..284efca61 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,17 +2,26 @@ import torch import comfy.model_management import comfy.samplers import math +import numpy as np -def prepare_noise(latent_image, seed, skip=0): +def prepare_noise(latent_image, seed, noise_inds=None): """ creates random noise given a latent image and a seed. optional arg skip can be used to skip and discard x number of noise generations for a given seed """ generator = torch.manual_seed(seed) - for _ in range(skip): + if noise_inds is None: + return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + noises = [] + for i in range(unique_inds[-1]+1): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - return noise + if i in unique_inds: + noises.append(noise) + noises = [noises[i] for i in inverse] + noises = torch.cat(noises, axis=0) + return noises def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py new file mode 100644 index 000000000..0a9daf272 --- /dev/null +++ b/comfy_extras/nodes_rebatch.py @@ -0,0 +1,108 @@ +import torch + +class LatentRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "latents": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "latent/batch" + + @staticmethod + def get_batch(latents, list_ind, offset): + '''prepare a batch out of the list of latents''' + samples = latents[list_ind]['samples'] + shape = samples.shape + mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') + if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]: + torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") + if mask.shape[0] < samples.shape[0]: + mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] + if 'batch_index' in latents[list_ind]: + batch_inds = latents[list_ind]['batch_index'] + else: + batch_inds = [x+offset for x in range(shape[0])] + return samples, mask, batch_inds + + @staticmethod + def get_slices(indexable, num, batch_size): + '''divides an indexable object into num slices of length batch_size, and a remainder''' + slices = [] + for i in range(num): + slices.append(indexable[i*batch_size:(i+1)*batch_size]) + if num * batch_size < len(indexable): + return slices, indexable[num * batch_size:] + else: + return slices, None + + @staticmethod + def slice_batch(batch, num, batch_size): + result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] + return list(zip(*result)) + + @staticmethod + def cat_batch(batch1, batch2): + if batch1[0] is None: + return batch2 + result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] + return result + + def rebatch(self, latents, batch_size): + batch_size = batch_size[0] + + output_list = [] + current_batch = (None, None, None) + processed = 0 + + for i in range(len(latents)): + # fetch new entry of list + #samples, masks, indices = self.get_batch(latents, i) + next_batch = self.get_batch(latents, i, processed) + processed += len(next_batch[2]) + # set to current if current is None + if current_batch[0] is None: + current_batch = next_batch + # add previous to list if dimensions do not match + elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + current_batch = next_batch + # cat if everything checks out + else: + current_batch = self.cat_batch(current_batch, next_batch) + + # add to list if dimensions gone above target batch size + if current_batch[0].shape[0] > batch_size: + num = current_batch[0].shape[0] // batch_size + sliced, remainder = self.slice_batch(current_batch, num, batch_size) + + for i in range(num): + output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) + + current_batch = remainder + + #add remainder + if current_batch[0] is not None: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + + #get rid of empty masks + for s in output_list: + if s['noise_mask'].mean() == 1.0: + del s['noise_mask'] + + return (output_list,) + +NODE_CLASS_MAPPINGS = { + "RebatchLatents": LatentRebatch, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RebatchLatents": "Rebatch Latents", +} \ No newline at end of file diff --git a/execution.py b/execution.py index 0ac4d462c..cf2e5ea71 100644 --- a/execution.py +++ b/execution.py @@ -26,20 +26,81 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = obj else: if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = input_data + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": - input_data_all[x] = prompt + input_data_all[x] = [prompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: - input_data_all[x] = extra_data['extra_pnginfo'] + input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": - input_data_all[x] = unique_id + input_data_all[x] = [unique_id] return input_data_all +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): + # check if node wants the lists + intput_is_list = False + if hasattr(obj, "INPUT_IS_LIST"): + intput_is_list = obj.INPUT_IS_LIST + + max_len_input = max([len(x) for x in input_data_all.values()]) + + # get a slice of inputs, repeat last input when list isn't long enough + def slice_dict(d, i): + d_new = dict() + for k,v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + + results = [] + if intput_is_list: + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**input_data_all)) + else: + for i in range(max_len_input): + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + return results + +def get_output_data(obj, input_data_all): + + results = [] + uis = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) + else: + results.append(r) + + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + ui = dict() + if len(uis) > 0: + ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + return output, ui + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -63,13 +124,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute server.send_sync("executing", { "node": unique_id }, server.client_id) obj = class_def() - nodes.before_node_execution() - outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) - if "ui" in outputs[unique_id]: + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) - if "result" in outputs[unique_id]: - outputs[unique_id] = outputs[unique_id]["result"] + server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -105,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: - is_changed = class_def.IS_CHANGED(**input_data_all) + #is_changed = class_def.IS_CHANGED(**input_data_all) + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: to_delete = True @@ -261,9 +321,11 @@ def validate_inputs(prompt, item, validated): if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) - ret = obj_class.VALIDATE_INPUTS(**input_data_all) - if ret != True: - return (False, "{}, {}".format(class_type, ret)) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for r in ret: + if r != True: + return (False, "{}, {}".format(class_type, r)) else: if isinstance(type_input, list): if val not in type_input: diff --git a/nodes.py b/nodes.py index c2201dafc..509dc0697 100644 --- a/nodes.py +++ b/nodes.py @@ -629,18 +629,57 @@ class LatentFromBatch: def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), }} RETURN_TYPES = ("LATENT",) - FUNCTION = "rotate" + FUNCTION = "frombatch" - CATEGORY = "latent" + CATEGORY = "latent/batch" - def rotate(self, samples, batch_index): + def frombatch(self, samples, batch_index, length): s = samples.copy() s_in = samples["samples"] batch_index = min(s_in.shape[0] - 1, batch_index) - s["samples"] = s_in[batch_index:batch_index + 1].clone() - s["batch_index"] = batch_index + length = min(s_in.shape[0] - batch_index, length) + s["samples"] = s_in[batch_index:batch_index + length].clone() + if "noise_mask" in samples: + masks = samples["noise_mask"] + if masks.shape[0] == 1: + s["noise_mask"] = masks.clone() + else: + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = masks[batch_index:batch_index + length].clone() + if "batch_index" not in s: + s["batch_index"] = [x for x in range(batch_index, batch_index+length)] + else: + s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] + return (s,) + +class RepeatLatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "repeat" + + CATEGORY = "latent/batch" + + def repeat(self, samples, amount): + s = samples.copy() + s_in = samples["samples"] + + s["samples"] = s_in.repeat((amount, 1,1,1)) + if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: + masks = samples["noise_mask"] + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + if "batch_index" in s: + offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 + s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) class LatentUpscale: @@ -805,8 +844,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - skip = latent["batch_index"] if "batch_index" in latent else 0 - noise = comfy.sample.prepare_noise(latent_image, seed, skip) + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) noise_mask = None if "noise_mask" in latent: @@ -1170,6 +1209,7 @@ NODE_CLASS_MAPPINGS = { "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentFromBatch": LatentFromBatch, + "RepeatLatentBatch": RepeatLatentBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, @@ -1244,6 +1284,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", + "LatentFromBatch" : "Latent From Batch", + "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", "PreviewImage": "Preview Image", @@ -1299,3 +1341,4 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) diff --git a/server.py b/server.py index 8435d091b..cb66cc618 100644 --- a/server.py +++ b/server.py @@ -268,6 +268,7 @@ class PromptServer(): info = {} info['input'] = obj_class.INPUT_TYPES() info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = x info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x diff --git a/web/scripts/app.js b/web/scripts/app.js index 2da1b5581..1a4a18b94 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -976,7 +976,8 @@ export class ComfyApp { for (const o in nodeData["output"]) { const output = nodeData["output"][o]; const outputName = nodeData["output_name"][o] || output; - this.addOutput(outputName, output); + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + this.addOutput(outputName, output, { shape: outputShape }); } const s = this.computeSize(); From 44f9f9baf170ddf27891b240002300d8aa09fb2a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 11:17:16 -0400 Subject: [PATCH 163/190] Add the prompt id to some websocket messages. --- execution.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index cf2e5ea71..b9548229c 100644 --- a/execution.py +++ b/execution.py @@ -101,7 +101,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -116,19 +116,19 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) obj = class_def() output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id) + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -215,6 +215,9 @@ class PromptExecutor: else: self.server.client_id = None + if self.server.client_id is not None: + self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + with torch.inference_mode(): #delete cached outputs if nodes don't exist for them to_delete = [] @@ -242,7 +245,7 @@ class PromptExecutor: to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") From cb4b8223981ec9e090ebf44205f5ce16d72f01cb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 11:54:45 -0400 Subject: [PATCH 164/190] Print custom nodes that take too much time to import. --- nodes.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nodes.py b/nodes.py index 509dc0697..bc7968308 100644 --- a/nodes.py +++ b/nodes.py @@ -6,6 +6,7 @@ import json import hashlib import traceback import math +import time from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -1325,6 +1326,7 @@ def load_custom_node(module_path): def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") + node_import_times = [] for custom_node_path in node_paths: possible_modules = os.listdir(custom_node_path) if "__pycache__" in possible_modules: @@ -1333,7 +1335,16 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + time_before = time.time() load_custom_node(module_path) + node_import_times.append((time.time() - time_before, module_path)) + + slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) + if len(slow_nodes) > 0: + print("\nDetected some custom nodes that were slow to import, if this is one of yours please improve it if you can:") + for n in sorted(slow_nodes): + print("{:6.1f} seconds to import:".format(n[0]), n[1]) + print() def init_custom_nodes(): load_custom_nodes() From cf439709b6b3ffae5ad15a9f7e59fedc214d5f1c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 12:50:21 -0400 Subject: [PATCH 165/190] Load nodes in comfy_extras before custom nodes. Change the slow import message. --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index bc7968308..956b739d9 100644 --- a/nodes.py +++ b/nodes.py @@ -1341,15 +1341,15 @@ def load_custom_nodes(): slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) if len(slow_nodes) > 0: - print("\nDetected some custom nodes that were slow to import, if this is one of yours please improve it if you can:") + print("\nDetected some custom nodes that were slow to import:") for n in sorted(slow_nodes): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() def init_custom_nodes(): - load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_nodes() From 92bf1cb61efcab45961d1119cb7ec7a076caf24e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:05:52 -0400 Subject: [PATCH 166/190] Change message. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 956b739d9..28215127c 100644 --- a/nodes.py +++ b/nodes.py @@ -1341,7 +1341,7 @@ def load_custom_nodes(): slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) if len(slow_nodes) > 0: - print("\nDetected some custom nodes that were slow to import:") + print("\nImport times for custom nodes:") for n in sorted(slow_nodes): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() From 2ac744f6628d107b3534177eeca5ef06f6668609 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:15:31 -0400 Subject: [PATCH 167/190] Print all custom node import times. --- nodes.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 28215127c..f3b7da1a9 100644 --- a/nodes.py +++ b/nodes.py @@ -1339,10 +1339,9 @@ def load_custom_nodes(): load_custom_node(module_path) node_import_times.append((time.time() - time_before, module_path)) - slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) - if len(slow_nodes) > 0: + if len(node_import_times) > 0: print("\nImport times for custom nodes:") - for n in sorted(slow_nodes): + for n in sorted(node_import_times): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() From db4d3a8494a4a7dbb6f911ae126a92abec6bf91b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:23:42 -0400 Subject: [PATCH 168/190] Print if custom nodes imported successfully or not. --- nodes.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index f3b7da1a9..63d9adc3d 100644 --- a/nodes.py +++ b/nodes.py @@ -1318,11 +1318,14 @@ def load_custom_node(module_path): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) + return True else: print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + return False except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom nodes:", e) + return False def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") @@ -1336,13 +1339,17 @@ def load_custom_nodes(): module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue time_before = time.time() - load_custom_node(module_path) - node_import_times.append((time.time() - time_before, module_path)) + success = load_custom_node(module_path) + node_import_times.append((time.time() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") for n in sorted(node_import_times): - print("{:6.1f} seconds to import:".format(n[0]), n[1]) + if n[2]: + import_message = "" + else: + import_message = " (IMPORT FAILED)" + print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) print() def init_custom_nodes(): From b0505eb7ab8af1986dabd97c23fae83a0539303d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 15:31:22 -0400 Subject: [PATCH 169/190] Return right type when none specified in upload route. Switch time.time to time.perf_counter for custom node import times. --- nodes.py | 4 ++-- server.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/nodes.py b/nodes.py index 63d9adc3d..c4aff1012 100644 --- a/nodes.py +++ b/nodes.py @@ -1338,9 +1338,9 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - time_before = time.time() + time_before = time.perf_counter() success = load_custom_node(module_path) - node_import_times.append((time.time() - time_before, module_path, success)) + node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") diff --git a/server.py b/server.py index d1079dd83..ba4dcba03 100644 --- a/server.py +++ b/server.py @@ -115,22 +115,23 @@ class PromptServer(): def get_dir_by_type(dir_type): if dir_type is None: - type_dir = folder_paths.get_input_directory() - elif dir_type == "input": + dir_type = "input" + + if dir_type == "input": type_dir = folder_paths.get_input_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": type_dir = folder_paths.get_output_directory() - return type_dir + return type_dir, dir_type def image_upload(post, image_save_function=None): image = post.get("image") overwrite = post.get("overwrite") image_upload_type = post.get("type") - upload_dir = get_dir_by_type(image_upload_type) + upload_dir, image_upload_type = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename From 3a1f47764d76bb9878b55e82657044b3faceda9c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 17:11:27 -0400 Subject: [PATCH 170/190] Print the torch device that is used on startup. --- comfy/model_management.py | 42 ++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 39df8d9a7..c15323219 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -127,6 +127,32 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.cuda.current_device() + +def get_torch_device_name(device): + if hasattr(device, 'type'): + return "{}".format(device.type) + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + +try: + print("Using device:", get_torch_device_name(get_torch_device())) +except: + print("Could not pick default device.") + current_loaded_model = None current_gpu_controlnets = [] @@ -233,22 +259,6 @@ def unload_if_low_vram(model): return model.cpu() return model -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type From e7b9d2c02cffd59fecca4ee617137ea38641078a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 01:30:58 -0400 Subject: [PATCH 171/190] /prompt endpoint error is now in json format. --- server.py | 7 +++---- web/scripts/api.js | 2 +- web/scripts/app.js | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/server.py b/server.py index ba4dcba03..f52117f10 100644 --- a/server.py +++ b/server.py @@ -323,12 +323,11 @@ class PromptServer(): self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) return web.json_response({"prompt_id": prompt_id}) else: - resp_code = 400 - out_string = valid[1] print("invalid prompt:", valid[1]) + return web.json_response({"error": valid[1]}, status=400) + else: + return web.json_response({"error": "no prompt"}, status=400) - return web.Response(body=out_string, status=resp_code) - @routes.post("/queue") async def post_queue(request): json_data = await request.json() diff --git a/web/scripts/api.js b/web/scripts/api.js index d29faa5ba..4f061c358 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -163,7 +163,7 @@ class ComfyApi extends EventTarget { if (res.status !== 200) { throw { - response: await res.text(), + response: await res.json(), }; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index 1a4a18b94..00d3c9746 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1222,7 +1222,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response || error.toString()); + this.ui.dialog.show(error.response.error || error.toString()); break; } From 9bf67c4c5a5c8b8d1efc2d4ce7e7ab1eccce1fa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 01:34:25 -0400 Subject: [PATCH 172/190] Print prompt execution time. --- execution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/execution.py b/execution.py index b9548229c..dd88029bc 100644 --- a/execution.py +++ b/execution.py @@ -6,6 +6,7 @@ import threading import heapq import traceback import gc +import time import torch import nodes @@ -215,6 +216,7 @@ class PromptExecutor: else: self.server.client_id = None + execution_start_time = time.perf_counter() if self.server.client_id is not None: self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) @@ -272,6 +274,7 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() comfy.model_management.soft_empty_cache() From d926f65f56217e7828ad27ec5b646c74398593c4 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sun, 14 May 2023 23:21:22 +0900 Subject: [PATCH 173/190] Feature/maskeditor context menu (#649) * add "Open in MaskEditor" to context menu * change save button name to 'Save to node' if open in node. clear clipspace_return_node after auto paste * * leak patch: prevent infinite duplication of MaskEditorDialog instance on every dialog open * prevent conflict of multiple opening of MaskEditorDialog * name of save button fix * patch: brushPreview hiding by dialog * consider close by 'esc' key on maskeditor. * bugfix about last patch * patch: invalid close detection * 'enter' key as save action * * batch support enhance - pick index based on imageIndex on copy action * paste fix on batch image node * typo --------- Co-authored-by: Lt.Dr.Data --- web/extensions/core/maskeditor.js | 120 ++++++++++++---- web/scripts/app.js | 226 +++++++++++++++++------------- 2 files changed, 221 insertions(+), 125 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 552059e86..4b0c12747 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -72,40 +72,50 @@ function prepareRGB(image, backupCanvas, backupCtx) { class MaskEditorDialog extends ComfyDialog { static instance = null; + + static getInstance() { + if(!MaskEditorDialog.instance) { + MaskEditorDialog.instance = new MaskEditorDialog(app); + } + + return MaskEditorDialog.instance; + } + + is_layout_created = false; + constructor() { super(); this.element = $el("div.comfy-modal", { parent: document.body }, [ $el("div.comfy-modal-content", [...this.createButtons()]), ]); - MaskEditorDialog.instance = this; } createButtons() { return []; } - clearMask(self) { - } - createButton(name, callback) { var button = document.createElement("button"); button.innerText = name; button.addEventListener("click", callback); return button; } + createLeftButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "left"; button.style.marginRight = "4px"; return button; } + createRightButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "right"; button.style.marginLeft = "4px"; return button; } + createLeftSlider(self, name, callback) { const divElement = document.createElement('div'); divElement.id = "maskeditor-slider"; @@ -164,7 +174,7 @@ class MaskEditorDialog extends ComfyDialog { brush.style.MozBorderRadius = "50%"; brush.style.WebkitBorderRadius = "50%"; brush.style.position = "absolute"; - brush.style.zIndex = 100; + brush.style.zIndex = 8889; brush.style.pointerEvents = "none"; this.brush = brush; this.element.appendChild(imgCanvas); @@ -187,7 +197,8 @@ class MaskEditorDialog extends ComfyDialog { document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.close(); }); - var saveButton = this.createRightButton("Save", () => { + + this.saveButton = this.createRightButton("Save", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.save(); @@ -199,11 +210,10 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(bottom_panel); bottom_panel.appendChild(clearButton); - bottom_panel.appendChild(saveButton); + bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); bottom_panel.appendChild(brush_size_slider); - this.element.style.display = "block"; imgCanvas.style.position = "relative"; imgCanvas.style.top = "200"; imgCanvas.style.left = "0"; @@ -212,25 +222,63 @@ class MaskEditorDialog extends ComfyDialog { } show() { - // layout - const imgCanvas = document.createElement('canvas'); - const maskCanvas = document.createElement('canvas'); - const backupCanvas = document.createElement('canvas'); + if(!this.is_layout_created) { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = document.createElement('canvas'); - imgCanvas.id = "imageCanvas"; - maskCanvas.id = "maskCanvas"; - backupCanvas.id = "backupCanvas"; + imgCanvas.id = "imageCanvas"; + maskCanvas.id = "maskCanvas"; + backupCanvas.id = "backupCanvas"; - this.setlayout(imgCanvas, maskCanvas); + this.setlayout(imgCanvas, maskCanvas); - // prepare content - this.maskCanvas = maskCanvas; - this.backupCanvas = backupCanvas; - this.maskCtx = maskCanvas.getContext('2d'); - this.backupCtx = backupCanvas.getContext('2d'); + // prepare content + this.imgCanvas = imgCanvas; + this.maskCanvas = maskCanvas; + this.backupCanvas = backupCanvas; + this.maskCtx = maskCanvas.getContext('2d'); + this.backupCtx = backupCanvas.getContext('2d'); - this.setImages(imgCanvas, backupCanvas); - this.setEventHandler(maskCanvas); + this.setEventHandler(maskCanvas); + + this.is_layout_created = true; + + // replacement of onClose hook since close is not real close + const self = this; + const observer = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if (mutation.type === 'attributes' && mutation.attributeName === 'style') { + if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + ComfyApp.onClipspaceEditorClosed(); + } + + self.last_display_style = self.element.style.display; + } + }); + }); + + const config = { attributes: true }; + observer.observe(this.element, config); + } + + this.setImages(this.imgCanvas, this.backupCanvas); + + if(ComfyApp.clipspace_return_node) { + this.saveButton.innerText = "Save to node"; + } + else { + this.saveButton.innerText = "Save"; + } + this.saveButton.disabled = false; + + this.element.style.display = "block"; + this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. + } + + isOpened() { + return this.element.style.display == "block"; } setImages(imgCanvas, backupCanvas) { @@ -239,6 +287,10 @@ class MaskEditorDialog extends ComfyDialog { const maskCtx = this.maskCtx; const maskCanvas = this.maskCanvas; + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); + maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); + // image load const orig_image = new Image(); window.addEventListener("resize", () => { @@ -296,8 +348,7 @@ class MaskEditorDialog extends ComfyDialog { rgb_url.searchParams.set('channel', 'rgb'); orig_image.src = rgb_url; this.image = orig_image; - }g - + } setEventHandler(maskCanvas) { maskCanvas.addEventListener("contextmenu", (event) => { @@ -327,6 +378,8 @@ class MaskEditorDialog extends ComfyDialog { self.brush_size = Math.min(self.brush_size+2, 100); } else if (event.key === '[') { self.brush_size = Math.max(self.brush_size-2, 1); + } else if(event.key === 'Enter') { + self.save(); } self.updateBrushPreview(self); @@ -514,7 +567,7 @@ class MaskEditorDialog extends ComfyDialog { } } - save() { + async save() { const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); @@ -570,7 +623,10 @@ class MaskEditorDialog extends ComfyDialog { formData.append('type', "input"); formData.append('subfolder', "clipspace"); - uploadMask(item, formData); + this.saveButton.innerText = "Saving..."; + this.saveButton.disabled = true; + await uploadMask(item, formData); + ComfyApp.onClipspaceEditorSave(); this.close(); } } @@ -578,13 +634,15 @@ class MaskEditorDialog extends ComfyDialog { app.registerExtension({ name: "Comfy.MaskEditor", init(app) { - const callback = + ComfyApp.open_maskeditor = function () { - let dlg = new MaskEditorDialog(app); - dlg.show(); + const dlg = MaskEditorDialog.getInstance(); + if(!dlg.isOpened()) { + dlg.show(); + } }; const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 - ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback); + ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor); } }); \ No newline at end of file diff --git a/web/scripts/app.js b/web/scripts/app.js index 00d3c9746..87c5e30ca 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -26,6 +26,8 @@ export class ComfyApp { */ static clipspace = null; static clipspace_invalidate_handler = null; + static open_maskeditor = null; + static clipspace_return_node = null; constructor() { this.ui = new ComfyUI(this); @@ -49,6 +51,114 @@ export class ComfyApp { this.shiftDown = false; } + static isImageNode(node) { + return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); + } + + static onClipspaceEditorSave() { + if(ComfyApp.clipspace_return_node) { + ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node); + } + } + + static onClipspaceEditorClosed() { + ComfyApp.clipspace_return_node = null; + } + + static copyToClipspace(node) { + var widgets = null; + if(node.widgets) { + widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + var imgs = undefined; + var orig_imgs = undefined; + if(node.imgs != undefined) { + imgs = []; + orig_imgs = []; + + for (let i = 0; i < node.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = node.imgs[i].src; + orig_imgs[i] = imgs[i]; + } + } + + var selectedIndex = 0; + if(node.imageIndex) { + selectedIndex = node.imageIndex; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': orig_imgs, + 'images': node.images, + 'selectedIndex': selectedIndex, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action + }; + + ComfyApp.clipspace_return_node = null; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } + } + + static pasteFromClipspace(node) { + if(ComfyApp.clipspace) { + // image paste + if(ComfyApp.clipspace.imgs && node.imgs) { + if(node.images && ComfyApp.clipspace.images) { + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + } + else + app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + node.imgs = [img]; + node.imageIndex = 0; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); + if(index >= 0) { + node.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } + } + + app.graph.setDirtyCanvas(true); + } + } + /** * Invoke an extension callback * @param {keyof ComfyExtension} method The extension callback to execute @@ -138,102 +248,30 @@ export class ComfyApp { } } - options.push( - { - content: "Copy (Clipspace)", - callback: (obj) => { - var widgets = null; - if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); - } - - var imgs = undefined; - var orig_imgs = undefined; - if(this.imgs != undefined) { - imgs = []; - orig_imgs = []; + // prevent conflict of clipspace content + if(!ComfyApp.clipspace_return_node) { + options.push({ + content: "Copy (Clipspace)", + callback: (obj) => { ComfyApp.copyToClipspace(this); } + }); - for (let i = 0; i < this.imgs.length; i++) { - imgs[i] = new Image(); - imgs[i].src = this.imgs[i].src; - orig_imgs[i] = imgs[i]; + if(ComfyApp.clipspace != null) { + options.push({ + content: "Paste (Clipspace)", + callback: () => { ComfyApp.pasteFromClipspace(this); } + }); + } + + if(ComfyApp.isImageNode(this)) { + options.push({ + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); } - } - - ComfyApp.clipspace = { - 'widgets': widgets, - 'imgs': imgs, - 'original_imgs': orig_imgs, - 'images': this.images, - 'selectedIndex': 0, - 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action - }; - - if(ComfyApp.clipspace_invalidate_handler) { - ComfyApp.clipspace_invalidate_handler(); - } - } - }); - - if(ComfyApp.clipspace != null) { - options.push( - { - content: "Paste (Clipspace)", - callback: () => { - if(ComfyApp.clipspace) { - // image paste - if(ComfyApp.clipspace.imgs && this.imgs) { - if(this.images && ComfyApp.clipspace.images) { - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; - - } - else - app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images; - } - - if(ComfyApp.clipspace.imgs) { - // deep-copy to cut link with clipspace - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - const img = new Image(); - img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; - this.imgs = [img]; - } - else { - const imgs = []; - for(let i=0; i obj.name === 'image'); - if(index >= 0) { - this.widgets[index].value = clip_image; - } - } - if(ComfyApp.clipspace.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.value = value; - prop.callback(value); - } - }); - } - } - } - - app.graph.setDirtyCanvas(true); - } - } - ); + }); + } } }; } From acff543d669dba9b03fb500a10010f2da8739ff3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 12:50:21 -0400 Subject: [PATCH 174/190] Remove useless code. --- nodes.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/nodes.py b/nodes.py index c4aff1012..bc23e5c17 100644 --- a/nodes.py +++ b/nodes.py @@ -146,9 +146,6 @@ class ConditioningSetMask: return (c, ) class VAEDecode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -161,9 +158,6 @@ class VAEDecode: return (vae.decode(samples["samples"]), ) class VAEDecodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -176,9 +170,6 @@ class VAEDecodeTiled: return (vae.decode_tiled(samples["samples"]), ) class VAEEncode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -203,9 +194,6 @@ class VAEEncode: return ({"samples":t}, ) class VAEEncodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -220,9 +208,6 @@ class VAEEncodeTiled: return ({"samples":t}, ) class VAEEncodeForInpaint: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} From 587f89fe5a8e2bcb389fb4919dc33c330320fa41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 15:10:40 -0400 Subject: [PATCH 175/190] Enable safe loading for upscale models. --- comfy_extras/nodes_upscale_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ab5b0ccfc..f9252ea0b 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -17,7 +17,7 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) out = model_loading.load_state_dict(sd).eval() return (out, ) From 84ea21c815d426000c233e0c7b8c542764335cc8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 17:02:40 -0400 Subject: [PATCH 176/190] Update litegraph from upstream. --- web/lib/litegraph.core.js | 145 +++++++++++++++++++++++++++++++++++--- 1 file changed, 137 insertions(+), 8 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 2bc6af0c3..6c81c3ffd 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action) //when clicked on top of a node //and it is not interactive - if (node && this.allow_interaction && !skip_action && !this.read_only) { + if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) { if (!this.live_mode && !node.flags.pinned) { this.bringToFront(node); } //if it wasn't selected? //not dragging mouse to connect two slots - if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { + if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) @@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action) } //double clicking - if (is_double_click && this.selected_nodes[node.id]) { + if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) { //double click node if (node.onDblClick) { node.onDblClick( e, pos, this ); @@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action) this.dirty_canvas = true; } + //get node over + var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); + if (this.dragging_rectangle) { this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0]; @@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action) this.ds.offset[1] += delta[1] / this.ds.scale; this.dirty_canvas = true; this.dirty_bgcanvas = true; - } else if (this.allow_interaction && !this.read_only) { + } else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) { if (this.connecting_node) { this.dirty_canvas = true; } - //get node over - var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); - //remove mouseover flag for (var i = 0, l = this.graph._nodes.length; i < l; ++i) { if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) { @@ -9911,7 +9911,7 @@ LGraphNode.prototype.executeAction = function(action) event, active_widget ) { - if (!node.widgets || !node.widgets.length) { + if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) { return null; } @@ -10300,6 +10300,119 @@ LGraphNode.prototype.executeAction = function(action) canvas.graph.add(group); }; + /** + * Determines the furthest nodes in each direction + * @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.getBoundaryNodes = function(nodes) { + let top = null; + let right = null; + let bottom = null; + let left = null; + for (const nID in nodes) { + const node = nodes[nID]; + const [x, y] = node.pos; + const [width, height] = node.size; + + if (top === null || y < top.pos[1]) { + top = node; + } + if (right === null || x + width > right.pos[0] + right.size[0]) { + right = node; + } + if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) { + bottom = node; + } + if (left === null || x < left.pos[0]) { + left = node; + } + } + + return { + "top": top, + "right": right, + "bottom": bottom, + "left": left + }; + } + /** + * Determines the furthest nodes in each direction for the currently selected nodes + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.prototype.boundaryNodesForSelection = function() { + return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes)); + } + + /** + * + * @param {LGraphNode[]} nodes a list of nodes + * @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes + * @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction) + */ + LGraphCanvas.alignNodes = function (nodes, direction, align_to) { + if (!nodes) { + return; + } + + const canvas = LGraphCanvas.active_canvas; + let boundaryNodes = [] + if (align_to === undefined) { + boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes) + } else { + boundaryNodes = { + "top": align_to, + "right": align_to, + "bottom": align_to, + "left": align_to + } + } + + for (const [_, node] of Object.entries(canvas.selected_nodes)) { + switch (direction) { + case "right": + node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0]; + break; + case "left": + node.pos[0] = boundaryNodes["left"].pos[0]; + break; + case "top": + node.pos[1] = boundaryNodes["top"].pos[1]; + break; + case "bottom": + node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1]; + break; + } + } + + canvas.dirty_canvas = true; + canvas.dirty_bgcanvas = true; + }; + + LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node); + } + } + + LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase()); + } + } + LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) { var canvas = LGraphCanvas.active_canvas; @@ -12900,6 +13013,14 @@ LGraphNode.prototype.executeAction = function(action) options.push({ content: "Options", callback: that.showShowGraphOptionsPanel }); }*/ + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align", + has_submenu: true, + callback: LGraphCanvas.onGroupAlign, + }) + } + if (this._graph_stack && this._graph_stack.length > 0) { options.push(null, { content: "Close subgraph", @@ -13014,6 +13135,14 @@ LGraphNode.prototype.executeAction = function(action) callback: LGraphCanvas.onMenuNodeToSubgraph }); + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align Selected To", + has_submenu: true, + callback: LGraphCanvas.onNodeAlign, + }) + } + options.push(null, { content: "Remove", disabled: !(node.removable !== false && !node.block_delete ), From 1dd846a7bad8cfab679a0976e201c722871c6917 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 00:27:28 -0400 Subject: [PATCH 177/190] Fix outputs gone from history. --- execution.py | 16 +++++++++++----- main.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/execution.py b/execution.py index dd88029bc..0e2cc15c1 100644 --- a/execution.py +++ b/execution.py @@ -102,7 +102,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -117,7 +117,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: @@ -128,6 +128,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) @@ -205,6 +206,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): self.outputs = {} + self.outputs_ui = {} self.old_prompt = {} self.server = server @@ -234,6 +236,11 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + for x in list(self.outputs_ui.keys()): + if x not in current_outputs: + d = self.outputs_ui.pop(x) + del d + if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() @@ -247,7 +254,7 @@ class PromptExecutor: to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -413,8 +420,7 @@ class PromptQueue: prompt = self.currently_running.pop(item_id) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: - if "ui" in outputs[o]: - self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"] + self.history[prompt[1]]["outputs"][o] = outputs[o] self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 00cbf3c4a..50d3b9a62 100644 --- a/main.py +++ b/main.py @@ -34,7 +34,7 @@ def prompt_worker(q, server): while True: item, item_id = q.get() e.execute(item[2], item[1], item[3], item[4]) - q.task_done(item_id, e.outputs) + q.task_done(item_id, e.outputs_ui) async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) From ef815ba1e24eef45041adec8a55ecd628b20476f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 00:29:56 -0400 Subject: [PATCH 178/190] Switch default scheduler to normal. --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index aa44fa82d..fccf254ec 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,7 +495,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] From c02a554bcf6ef50f8e252c89dc0a56c08d4955c0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 03:25:24 -0400 Subject: [PATCH 179/190] Make DiffusersLoader work with subfolders. --- nodes.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index bc23e5c17..797ad6c9c 100644 --- a/nodes.py +++ b/nodes.py @@ -282,7 +282,10 @@ class DiffusersLoader: paths = [] for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths += next(os.walk(search_path))[1] + for root, subdir, files in os.walk(search_path, followlinks=True): + if "model_index.json" in files: + paths.append(os.path.relpath(root, start=search_path)) + return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -292,9 +295,9 @@ class DiffusersLoader: def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths = next(os.walk(search_path))[1] - if model_path in paths: - model_path = os.path.join(search_path, model_path) + path = os.path.join(search_path, model_path) + if os.path.exists(path): + model_path = path break return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) From 2ec6d1c6e364ab92e3d8149a83873ac47c797248 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 03:31:03 -0400 Subject: [PATCH 180/190] Don't import custom nodes when the folder ends with .disabled --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index 797ad6c9c..e8b36c24a 100644 --- a/nodes.py +++ b/nodes.py @@ -1326,6 +1326,7 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + if module_path.endswith(".disabled"): continue time_before = time.perf_counter() success = load_custom_node(module_path) node_import_times.append((time.perf_counter() - time_before, module_path, success)) From 5f7968f1fafb2cf5d15fe049fc53265ad0fc6696 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 01:12:44 -0400 Subject: [PATCH 181/190] Print the endpoint ip for localtunnel in the colab notebook. --- notebooks/comfyui_colab.ipynb | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index fecfa6707..c5a209eec 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -175,6 +175,8 @@ "import threading\n", "import time\n", "import socket\n", + "import urllib.request\n", + "\n", "def iframe_thread(port):\n", " while True:\n", " time.sleep(0.5)\n", @@ -183,7 +185,9 @@ " if result == 0:\n", " break\n", " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n", + " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n", + "\n", + " print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n", " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", " for line in p.stdout:\n", " print(line.decode(), end='')\n", From 13d94caf49b21bd129ec867b04641973e3a102da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 03:18:11 -0400 Subject: [PATCH 182/190] Add control_after_generate to combo primitive. --- web/extensions/core/widgetInputs.js | 2 +- web/scripts/widgets.js | 80 +++++++++++++++++++---------- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index df7d8f071..4fe0a6013 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -300,7 +300,7 @@ app.registerExtension({ } } - if (widget.type === "number") { + if (widget.type === "number" || widget.type === "combo") { addValueControlWidget(this, widget, "fixed"); } diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 65edc0392..3d1acc53e 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,35 +19,61 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - let min = targetWidget.options.min; - let max = targetWidget.options.max; - // limit to something that javascript can handle - max = Math.min(1125899906842624, max); - min = Math.max(-1125899906842624, min); - let range = (max - min) / (targetWidget.options.step / 10); + console.log(targetWidget); + if (targetWidget.type == "combo" && v !== "fixed") { + let current_index = targetWidget.options.values.indexOf(targetWidget.value); + let current_length = targetWidget.options.values.length; - //adjust values based on valueControl Behaviour - switch (v) { - case "fixed": - break; - case "increment": - targetWidget.value += targetWidget.options.step / 10; - break; - case "decrement": - targetWidget.value -= targetWidget.options.step / 10; - break; - case "randomize": - targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; - default: - break; + switch (v) { + case "increment": + current_index += 1; + break; + case "decrement": + current_index -= 1; + break; + case "randomize": + current_index = Math.floor(Math.random() * current_length); + default: + break; + } + current_index = Math.max(0, current_index); + current_index = Math.min(current_length - 1, current_index); + if (current_index >= 0) { + let value = targetWidget.options.values[current_index]; + targetWidget.value = value; + targetWidget.callback(value); + } + } else { //number + let min = targetWidget.options.min; + let max = targetWidget.options.max; + // limit to something that javascript can handle + max = Math.min(1125899906842624, max); + min = Math.max(-1125899906842624, min); + let range = (max - min) / (targetWidget.options.step / 10); + + //adjust values based on valueControl Behaviour + switch (v) { + case "fixed": + break; + case "increment": + targetWidget.value += targetWidget.options.step / 10; + break; + case "decrement": + targetWidget.value -= targetWidget.options.step / 10; + break; + case "randomize": + targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; + default: + break; + } + /*check if values are over or under their respective + * ranges and set them to min or max.*/ + if (targetWidget.value < min) + targetWidget.value = min; + + if (targetWidget.value > max) + targetWidget.value = max; } - /*check if values are over or under their respective - * ranges and set them to min or max.*/ - if (targetWidget.value < min) - targetWidget.value = min; - - if (targetWidget.value > max) - targetWidget.value = max; } return valueControl; }; From 7ada9e7d85f93495aa5006468a45220932f5e988 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Tue, 16 May 2023 22:55:00 +0900 Subject: [PATCH 183/190] allows touch drag --- web/scripts/app.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 87c5e30ca..ef3b44c83 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -902,7 +902,9 @@ export class ComfyApp { await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM - const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" })); + const mainCanvas = document.createElement("canvas") + mainCanvas.style.touchAction = "none" + const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" })); canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); From 11e7168d56e0987e52d0afb620189f08bda2b454 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 11:55:16 -0400 Subject: [PATCH 184/190] Remove print. --- web/scripts/widgets.js | 1 - 1 file changed, 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 3d1acc53e..94988d0f2 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,7 +19,6 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - console.log(targetWidget); if (targetWidget.type == "combo" && v !== "fixed") { let current_index = targetWidget.options.values.indexOf(targetWidget.value); let current_length = targetWidget.options.values.length; From 4088e61aa6b8943e28ee243c0b1265c41974ef67 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 15:35:07 -0400 Subject: [PATCH 185/190] Update litegraph from upstream. --- web/lib/litegraph.core.js | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 6c81c3ffd..95f4a2735 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action) if (show_text) { ctx.textAlign = "center"; ctx.fillStyle = text_color; - ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7); + ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7); } break; case "toggle": @@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); if (show_text) { ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = w.value ? text_color : secondary_text_color; ctx.textAlign = "right"; @@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.textAlign = "center"; ctx.fillStyle = text_color; ctx.fillText( - w.name + " " + Number(w.value).toFixed(3), + w.label || w.name + " " + Number(w.value).toFixed(3), widget_width * 0.5, y + H * 0.7 ); @@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); } ctx.fillStyle = secondary_text_color; - ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7); + ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7); ctx.fillStyle = text_color; ctx.textAlign = "right"; if (w.type == "number") { @@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action) //ctx.stroke(); ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = text_color; ctx.textAlign = "right"; From e7f2816c6f1da22e2018cf088bd45110ff265c79 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Thu, 18 May 2023 12:40:28 +0900 Subject: [PATCH 186/190] feat:Latent Save/Load (#662) * wip * latent dir * fix * fix * now working * mark todo * remove server.py changes to separate PRt --------- Co-authored-by: Lt.Dr.Data --- input/latents/_input_latents_will_be_put_here | 0 nodes.py | 90 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 input/latents/_input_latents_will_be_put_here diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index e8b36c24a..a2c7713aa 100644 --- a/nodes.py +++ b/nodes.py @@ -29,6 +29,8 @@ import importlib import folder_paths +import safetensors.torch as sft + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -246,6 +248,91 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class SaveLatent: + def __init__(self): + self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") + self.type = "output" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(self.output_dir, subfolder) + + if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: + print("Saving latent outside the 'input/latents' folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + + # support save metadata for latent sharing + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"workflow": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + file = f"{filename}_{counter:05}_.latent" + file = os.path.join(full_output_folder, file) + + sft.save_file(samples, file, metadata=metadata) + + return {} + + +class LoadLatent: + input_dir = os.path.join(folder_paths.get_input_directory(), "latents") + + @classmethod + def INPUT_TYPES(s): + files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + return {"required": {"latent": [sorted(files), ]}, } + + CATEGORY = "_for_testing" + + RETURN_TYPES = ("LATENT", ) + FUNCTION = "load" + + def load(self, latent): + file = folder_paths.get_annotated_filepath(latent, self.input_dir) + + latent = sft.load_file(file, device="cpu") + + return (latent, ) + + class CheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -1235,6 +1322,9 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, + + "LoadLatent": LoadLatent, + "SaveLatent": SaveLatent } NODE_DISPLAY_NAME_MAPPINGS = { From a7375103b9c80bb7607f85faa4afbf11ab5a5685 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 May 2023 23:04:40 -0400 Subject: [PATCH 187/190] Some small changes to Load/SaveLatent. --- nodes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index a2c7713aa..7255621d7 100644 --- a/nodes.py +++ b/nodes.py @@ -11,6 +11,7 @@ import time from PIL import Image from PIL.PngImagePlugin import PngInfo import numpy as np +import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -29,7 +30,6 @@ import importlib import folder_paths -import safetensors.torch as sft def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -307,7 +307,10 @@ class SaveLatent: file = f"{filename}_{counter:05}_.latent" file = os.path.join(full_output_folder, file) - sft.save_file(samples, file, metadata=metadata) + output = {} + output["latent_tensor"] = samples["samples"] + + safetensors.torch.save_file(output, file, metadata=metadata) return {} @@ -328,9 +331,10 @@ class LoadLatent: def load(self, latent): file = folder_paths.get_annotated_filepath(latent, self.input_dir) - latent = sft.load_file(file, device="cpu") + latent = safetensors.torch.load_file(file, device="cpu") + samples = {"samples": latent["latent_tensor"]} - return (latent, ) + return (samples, ) class CheckpointLoader: From faf899ad5ae32f770f0dae6a9df457e81d2b5c38 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 May 2023 23:43:59 -0400 Subject: [PATCH 188/190] LoadLatent and SaveLatent should behave like the LoadImage and SaveImage. --- folder_paths.py | 33 +++++++ input/latents/_input_latents_will_be_put_here | 0 nodes.py | 90 +++++-------------- 3 files changed, 55 insertions(+), 68 deletions(-) delete mode 100644 input/latents/_input_latents_will_be_put_here diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..28f117824 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -147,4 +147,37 @@ def get_filename_list(folder_name): output_list.update(filter_files_extensions(recursive_search(x), folders[1])) return sorted(list(output_list)) +def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + def compute_vars(input, image_width, image_height): + input = input.replace("%width%", str(image_width)) + input = input.replace("%height%", str(image_height)) + return input + + filename_prefix = compute_vars(filename_prefix, image_width, image_height) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(output_dir, subfolder) + + if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: + print("Saving image outside the output folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + return full_output_folder, filename, counter, subfolder, filename_prefix diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here deleted file mode 100644 index e69de29bb..000000000 diff --git a/nodes.py b/nodes.py index 7255621d7..7b450df38 100644 --- a/nodes.py +++ b/nodes.py @@ -251,13 +251,12 @@ class VAEEncodeForInpaint: class SaveLatent: def __init__(self): - self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") - self.type = "output" + self.output_dir = folder_paths.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } RETURN_TYPES = () @@ -268,31 +267,7 @@ class SaveLatent: CATEGORY = "_for_testing" def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving latent outside the 'input/latents' folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) # support save metadata for latent sharing prompt_info = "" @@ -316,11 +291,10 @@ class SaveLatent: class LoadLatent: - input_dir = os.path.join(folder_paths.get_input_directory(), "latents") - @classmethod def INPUT_TYPES(s): - files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] return {"required": {"latent": [sorted(files), ]}, } CATEGORY = "_for_testing" @@ -329,13 +303,25 @@ class LoadLatent: FUNCTION = "load" def load(self, latent): - file = folder_paths.get_annotated_filepath(latent, self.input_dir) - - latent = safetensors.torch.load_file(file, device="cpu") + latent_path = folder_paths.get_annotated_filepath(latent) + latent = safetensors.torch.load_file(latent_path, device="cpu") samples = {"samples": latent["latent_tensor"]} - return (samples, ) + @classmethod + def IS_CHANGED(s, latent): + image_path = folder_paths.get_annotated_filepath(latent) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, latent): + if not folder_paths.exists_annotated_filepath(latent): + return "Invalid latent file: {}".format(latent) + return True + class CheckpointLoader: @classmethod @@ -1020,39 +1006,7 @@ class SaveImage: CATEGORY = "image" def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - def compute_vars(input): - input = input.replace("%width%", str(images[0].shape[1])) - input = input.replace("%height%", str(images[0].shape[0])) - return input - - filename_prefix = compute_vars(filename_prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving image outside the output folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 - + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() for image in images: i = 255. * image.cpu().numpy() From 62a371e12b4763bf6f9aeb42ff4928138df6ae26 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 18 May 2023 02:41:21 -0400 Subject: [PATCH 189/190] Load workflow from latent file. --- nodes.py | 2 +- web/scripts/app.js | 7 ++++++- web/scripts/pnginfo.js | 16 ++++++++++++++++ web/scripts/ui.js | 2 +- 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 7b450df38..3c61cd2ec 100644 --- a/nodes.py +++ b/nodes.py @@ -274,7 +274,7 @@ class SaveLatent: if prompt is not None: prompt_info = json.dumps(prompt) - metadata = {"workflow": prompt_info} + metadata = {"prompt": prompt_info} if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) diff --git a/web/scripts/app.js b/web/scripts/app.js index ef3b44c83..97b7c8d31 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata, importA1111 } from "./pnginfo.js"; +import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension @@ -1308,6 +1308,11 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); + } else if (file.name?.endsWith(".latent")) { + const info = await getLatentMetadata(file); + if (info.workflow) { + this.loadGraphData(JSON.parse(info.workflow)); + } } } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 209b562a6..8ddb7a1c5 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -47,6 +47,22 @@ export function getPngMetadata(file) { }); } +export function getLatentMetadata(file) { + return new Promise((r) => { + const reader = new FileReader(); + reader.onload = (event) => { + const safetensorsData = new Uint8Array(event.target.result); + const dataView = new DataView(safetensorsData.buffer); + let header_size = dataView.getUint32(0, true); + let offset = 8; + let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + r(header.__metadata__); + }; + + reader.readAsArrayBuffer(file); + }); +} + export async function importA1111(graph, parameters) { const p = parameters.lastIndexOf("\nSteps:"); if (p > -1) { diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 77517aec1..2c9043d00 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -465,7 +465,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png", + accept: ".json,image/png,.latent", style: { display: "none" }, parent: document.body, onchange: () => { From 8bbd9815a976ef43e2665d45c5afb4a21c06c831 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 May 2023 02:15:32 -0400 Subject: [PATCH 190/190] Support loading fp16 latent files. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 3c61cd2ec..878e0b955 100644 --- a/nodes.py +++ b/nodes.py @@ -305,7 +305,7 @@ class LoadLatent: def load(self, latent): latent_path = folder_paths.get_annotated_filepath(latent) latent = safetensors.torch.load_file(latent_path, device="cpu") - samples = {"samples": latent["latent_tensor"]} + samples = {"samples": latent["latent_tensor"].float()} return (samples, ) @classmethod