From 20ae48515e423c102a8845079cf05d530ffdb84d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 26 Mar 2023 15:01:34 +0100 Subject: [PATCH 01/83] Add setting to save menu position Add anchoring to side when resizing Fix losing menu when resizing --- web/scripts/ui.js | 170 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 136 insertions(+), 34 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 94f3c528a..d92e2cfa7 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -35,21 +35,92 @@ function $el(tag, propsOrChildren, children) { return element; } -function dragElement(dragEl) { +function dragElement(dragEl, settings) { var posDiffX = 0, posDiffY = 0, posStartX = 0, posStartY = 0, newPosX = 0, newPosY = 0; - if (dragEl.getElementsByClassName('drag-handle')[0]) { + if (dragEl.getElementsByClassName("drag-handle")[0]) { // if present, the handle is where you move the DIV from: - dragEl.getElementsByClassName('drag-handle')[0].onmousedown = dragMouseDown; + dragEl.getElementsByClassName("drag-handle")[0].onmousedown = dragMouseDown; } else { // otherwise, move the DIV from anywhere inside the DIV: dragEl.onmousedown = dragMouseDown; } + function ensureInBounds() { + newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft)); + newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop)); + + console.log(newPosX, newPosY) + + positionElement(); + } + + function positionElement() { + const halfWidth = document.body.clientWidth / 2; + const halfHeight = document.body.clientHeight / 2; + + const anchorRight = newPosX + dragEl.clientWidth / 2 > halfWidth; + const anchorBottom = newPosY + dragEl.clientHeight / 2 > halfHeight; + + // set the element's new position: + if (anchorRight) { + dragEl.style.left = "unset"; + dragEl.style.right = document.body.clientWidth - newPosX - dragEl.clientWidth + "px"; + } else { + dragEl.style.left = newPosX + "px"; + dragEl.style.right = "unset"; + } + if (anchorBottom) { + dragEl.style.top = "unset"; + dragEl.style.bottom = document.body.clientHeight - newPosY - dragEl.clientHeight + "px"; + } else { + dragEl.style.top = newPosY + "px"; + dragEl.style.bottom = "unset"; + } + + if (savePos) { + localStorage.setItem( + "Comfy.MenuPosition", + JSON.stringify({ + left: dragEl.style.left, + right: dragEl.style.right, + top: dragEl.style.top, + bottom: dragEl.style.bottom, + }) + ); + } + } + + function restorePos() { + let pos = localStorage.getItem("Comfy.MenuPosition"); + if (pos) { + pos = JSON.parse(pos); + dragEl.style.left = pos.left; + dragEl.style.right = pos.right; + dragEl.style.top = pos.top; + dragEl.style.bottom = pos.bottom; + ensureInBounds(); + } + } + + let savePos = undefined; + settings.addSetting({ + id: "Comfy.MenuPosition", + name: "Save menu position", + type: "boolean", + defaultValue: savePos, + onChange(value) { + if (savePos === undefined && value) { + restorePos(); + } + savePos = value; + }, + }); + function dragMouseDown(e) { e = e || window.event; e.preventDefault(); @@ -64,18 +135,27 @@ function dragElement(dragEl) { function elementDrag(e) { e = e || window.event; e.preventDefault(); + + dragEl.classList.add("comfy-menu-manual-pos"); + // calculate the new cursor position: posDiffX = e.clientX - posStartX; posDiffY = e.clientY - posStartY; posStartX = e.clientX; posStartY = e.clientY; - newPosX = Math.min((document.body.clientWidth - dragEl.clientWidth), Math.max(0, (dragEl.offsetLeft + posDiffX))); - newPosY = Math.min((document.body.clientHeight - dragEl.clientHeight), Math.max(0, (dragEl.offsetTop + posDiffY))); - // set the element's new position: - dragEl.style.top = newPosY + "px"; - dragEl.style.left = newPosX + "px"; + + newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft + posDiffX)); + newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop + posDiffY)); + + positionElement(); } + window.addEventListener("resize", () => { + if (dragEl.classList.contains("comfy-menu-manual-pos")) { + ensureInBounds(); + } + }); + function closeDragElement() { // stop moving when mouse button is released: document.onmouseup = null; @@ -305,34 +385,52 @@ export class ComfyUI { $el("span", { $: (q) => (this.queueSize = q) }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), ]), - $el("button.comfy-queue-btn", { textContent: "Queue Prompt", onclick: () => app.queuePrompt(0, this.batchCount) }), + $el("button.comfy-queue-btn", { + textContent: "Queue Prompt", + onclick: () => app.queuePrompt(0, this.batchCount), + }), $el("div", {}, [ - $el("label", { innerHTML: "Extra options"}, [ - $el("input", { type: "checkbox", - onchange: (i) => { - document.getElementById('extraOptions').style.display = i.srcElement.checked ? "block" : "none"; - this.batchCount = i.srcElement.checked ? document.getElementById('batchCountInputRange').value : 1; - document.getElementById('autoQueueCheckbox').checked = false; - } - }) - ]) - ]), - $el("div", { id: "extraOptions", style: { width: "100%", display: "none" }}, [ - $el("label", { innerHTML: "Batch count" }, [ - $el("input", { id: "batchCountInputNumber", type: "number", value: this.batchCount, min: "1", style: { width: "35%", "margin-left": "0.4em" }, - oninput: (i) => { - this.batchCount = i.target.value; - document.getElementById('batchCountInputRange').value = this.batchCount; - } + $el("label", { innerHTML: "Extra options" }, [ + $el("input", { + type: "checkbox", + onchange: (i) => { + document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none"; + this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1; + document.getElementById("autoQueueCheckbox").checked = false; + }, }), - $el("input", { id: "batchCountInputRange", type: "range", min: "1", max: "100", value: this.batchCount, + ]), + ]), + $el("div", { id: "extraOptions", style: { width: "100%", display: "none" } }, [ + $el("label", { innerHTML: "Batch count" }, [ + $el("input", { + id: "batchCountInputNumber", + type: "number", + value: this.batchCount, + min: "1", + style: { width: "35%", "margin-left": "0.4em" }, + oninput: (i) => { + this.batchCount = i.target.value; + document.getElementById("batchCountInputRange").value = this.batchCount; + }, + }), + $el("input", { + id: "batchCountInputRange", + type: "range", + min: "1", + max: "100", + value: this.batchCount, oninput: (i) => { this.batchCount = i.srcElement.value; - document.getElementById('batchCountInputNumber').value = i.srcElement.value; - } + document.getElementById("batchCountInputNumber").value = i.srcElement.value; + }, + }), + $el("input", { + id: "autoQueueCheckbox", + type: "checkbox", + checked: false, + title: "automatically queue prompt when the queue size hits 0", }), - $el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, title: "automatically queue prompt when the queue size hits 0", - }) ]), ]), $el("div.comfy-menu-btns", [ @@ -380,7 +478,7 @@ export class ComfyUI { $el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }), ]); - dragElement(this.menuContainer); + dragElement(this.menuContainer, this.settings); this.setStatus({ exec_info: { queue_remaining: "X" } }); } @@ -388,10 +486,14 @@ export class ComfyUI { setStatus(status) { this.queueSize.textContent = "Queue size: " + (status ? status.exec_info.queue_remaining : "ERR"); if (status) { - if (this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && document.getElementById('autoQueueCheckbox').checked) { + if ( + this.lastQueueSize != 0 && + status.exec_info.queue_remaining == 0 && + document.getElementById("autoQueueCheckbox").checked + ) { app.queuePrompt(0, this.batchCount); } - this.lastQueueSize = status.exec_info.queue_remaining + this.lastQueueSize = status.exec_info.queue_remaining; } } } From 716d8e746af6f7ce24f18f7641f055713fb32a42 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 26 Mar 2023 15:03:57 +0100 Subject: [PATCH 02/83] Remove log --- web/scripts/ui.js | 2 -- 1 file changed, 2 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index d92e2cfa7..117f4369e 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -54,8 +54,6 @@ function dragElement(dragEl, settings) { newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft)); newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop)); - console.log(newPosX, newPosY) - positionElement(); } From 0b1e85fbea14b3a9ed6269b53ec921dc4eb02668 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 26 Mar 2023 15:10:38 +0100 Subject: [PATCH 03/83] Add manual flag when restoring pos --- web/scripts/ui.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 117f4369e..404aae26d 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -101,6 +101,7 @@ function dragElement(dragEl, settings) { dragEl.style.right = pos.right; dragEl.style.top = pos.top; dragEl.style.bottom = pos.bottom; + dragEl.classList.add("comfy-menu-manual-pos"); ensureInBounds(); } } From 04b42bad87c2ec91a247f63378cc97718ebd9dbc Mon Sep 17 00:00:00 2001 From: hnmr293 Date: Thu, 30 Mar 2023 21:50:35 +0900 Subject: [PATCH 04/83] allow converting optional widgets to inputs --- web/extensions/core/widgetInputs.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index ff9227d28..7e6688261 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -101,7 +101,7 @@ app.registerExtension({ callback: () => convertToWidget(this, w), }); } else { - const config = nodeData?.input?.required[w.name] || [w.type, w.options || {}]; + const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}]; if (isConvertableWidget(w, config)) { toInput.push({ content: `Convert ${w.name} to input`, From c93dc2fb89d83ef04837d6a4a712870a2a13eac7 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 30 Mar 2023 20:14:01 +0100 Subject: [PATCH 05/83] Remove bottom anchor --- web/scripts/ui.js | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index fc37fd3dd..8c7f096d1 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -50,19 +50,24 @@ function dragElement(dragEl, settings) { dragEl.onmousedown = dragMouseDown; } + // When the element resizes (e.g. view queue) ensure it is still in the windows bounds + const resizeObserver = new ResizeObserver(() => { + ensureInBounds(); + }).observe(dragEl); + function ensureInBounds() { + if (dragEl.classList.contains("comfy-menu-manual-pos")) { newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft)); newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop)); positionElement(); } + } function positionElement() { const halfWidth = document.body.clientWidth / 2; - const halfHeight = document.body.clientHeight / 2; const anchorRight = newPosX + dragEl.clientWidth / 2 > halfWidth; - const anchorBottom = newPosY + dragEl.clientHeight / 2 > halfHeight; // set the element's new position: if (anchorRight) { @@ -72,22 +77,15 @@ function dragElement(dragEl, settings) { dragEl.style.left = newPosX + "px"; dragEl.style.right = "unset"; } - if (anchorBottom) { - dragEl.style.top = "unset"; - dragEl.style.bottom = document.body.clientHeight - newPosY - dragEl.clientHeight + "px"; - } else { dragEl.style.top = newPosY + "px"; dragEl.style.bottom = "unset"; - } if (savePos) { localStorage.setItem( "Comfy.MenuPosition", JSON.stringify({ - left: dragEl.style.left, - right: dragEl.style.right, - top: dragEl.style.top, - bottom: dragEl.style.bottom, + x: dragEl.offsetLeft, + y: dragEl.offsetTop, }) ); } @@ -97,11 +95,9 @@ function dragElement(dragEl, settings) { let pos = localStorage.getItem("Comfy.MenuPosition"); if (pos) { pos = JSON.parse(pos); - dragEl.style.left = pos.left; - dragEl.style.right = pos.right; - dragEl.style.top = pos.top; - dragEl.style.bottom = pos.bottom; - dragEl.classList.add("comfy-menu-manual-pos"); + newPosX = pos.x; + newPosY = pos.y; + positionElement(); ensureInBounds(); } } @@ -150,9 +146,7 @@ function dragElement(dragEl, settings) { } window.addEventListener("resize", () => { - if (dragEl.classList.contains("comfy-menu-manual-pos")) { ensureInBounds(); - } }); function closeDragElement() { From 3a5bcdf8b9a141f7e629ac3a7174f20af3aac5a1 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 30 Mar 2023 20:15:12 +0100 Subject: [PATCH 06/83] Formatting --- web/scripts/ui.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 8c7f096d1..194d8e2dd 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -57,11 +57,11 @@ function dragElement(dragEl, settings) { function ensureInBounds() { if (dragEl.classList.contains("comfy-menu-manual-pos")) { - newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft)); - newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop)); + newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft)); + newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop)); - positionElement(); - } + positionElement(); + } } function positionElement() { From 722801ed2da85478863a1fb9950450897eb7b0b6 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 30 Mar 2023 20:15:48 +0100 Subject: [PATCH 07/83] Formatting --- web/scripts/ui.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 194d8e2dd..587f4e529 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -66,7 +66,6 @@ function dragElement(dragEl, settings) { function positionElement() { const halfWidth = document.body.clientWidth / 2; - const anchorRight = newPosX + dragEl.clientWidth / 2 > halfWidth; // set the element's new position: @@ -77,8 +76,9 @@ function dragElement(dragEl, settings) { dragEl.style.left = newPosX + "px"; dragEl.style.right = "unset"; } - dragEl.style.top = newPosY + "px"; - dragEl.style.bottom = "unset"; + + dragEl.style.top = newPosY + "px"; + dragEl.style.bottom = "unset"; if (savePos) { localStorage.setItem( From 61ec3c9d5d3e11f94682170be1454221512899c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 31 Mar 2023 13:04:39 -0400 Subject: [PATCH 08/83] Add a way to pass options to the transformers blocks. --- comfy/ldm/models/diffusion/ddim.py | 14 +++++++------- comfy/ldm/models/diffusion/ddpm.py | 18 +++++++++--------- comfy/ldm/modules/attention.py | 10 +++++----- .../modules/diffusionmodules/openaimodel.py | 13 +++++++------ comfy/samplers.py | 7 +++++-- 5 files changed, 33 insertions(+), 29 deletions(-) diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index 5e2d73645..e00ffd3f5 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -78,7 +78,7 @@ class DDIMSampler(object): dynamic_threshold=None, ucg_schedule=None, denoise_function=None, - cond_concat=None, + extra_args=None, to_zero=True, end_step=None, **kwargs @@ -101,7 +101,7 @@ class DDIMSampler(object): dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=denoise_function, - cond_concat=cond_concat, + extra_args=extra_args, to_zero=to_zero, end_step=end_step ) @@ -174,7 +174,7 @@ class DDIMSampler(object): dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=None, - cond_concat=None + extra_args=None ) return samples, intermediates @@ -185,7 +185,7 @@ class DDIMSampler(object): mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None): + ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None): device = self.model.betas.device b = shape[0] if x_T is None: @@ -225,7 +225,7 @@ class DDIMSampler(object): corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat) + dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) @@ -249,11 +249,11 @@ class DDIMSampler(object): def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None, denoise_function=None, cond_concat=None): + dynamic_threshold=None, denoise_function=None, extra_args=None): b, *_, device = *x.shape, x.device if denoise_function is not None: - model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat) + model_output = denoise_function(self.model.apply_model, x, t, **extra_args) elif unconditional_conditioning is None or unconditional_guidance_scale == 1.: model_output = self.model.apply_model(x, t, c) else: diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index 42ed2add7..6af961242 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module): self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None): + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}): if self.conditioning_key is None: - out = self.diffusion_model(x, t, control=control) + out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t, control=control) + out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'crossattn': if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) @@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module): # TorchScript changes names of the arguments # with argument cc defined as context=cc scripted model will produce # an error: RuntimeError: forward() is missing value for argument 'argument_3'. - out = self.scripted_diffusion_model(x, t, cc, control=control) + out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options) else: - out = self.diffusion_model(x, t, context=cc, control=control) + out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, control=control) + out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'hybrid-adm': assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control) + out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'crossattn-adm': assert c_adm is not None cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control) + out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'adm': cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc, control=control) + out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options) else: raise NotImplementedError() diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 23b047342..25051b339 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -504,10 +504,10 @@ class BasicTransformerBlock(nn.Module): self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint - def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + def forward(self, x, context=None, transformer_options={}): + return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) - def _forward(self, x, context=None): + def _forward(self, x, context=None, transformer_options={}): x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x @@ -557,7 +557,7 @@ class SpatialTransformer(nn.Module): self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.use_linear = use_linear - def forward(self, x, context=None): + def forward(self, x, context=None, transformer_options={}): # note: if no context is given, cross-attention defaults to self-attention if not isinstance(context, list): context = [context] @@ -570,7 +570,7 @@ class SpatialTransformer(nn.Module): if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i]) + x = block(x, context=context[i], transformer_options=transformer_options) if self.use_linear: x = self.proj_out(x) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 09ab1a066..7b2f5b531 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): support it as an extra input. """ - def forward(self, x, emb, context=None): + def forward(self, x, emb, context=None, transformer_options={}): for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer): - x = layer(x, context) + x = layer(x, context, transformer_options) else: x = layer(x) return x @@ -753,7 +753,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. @@ -762,6 +762,7 @@ class UNetModel(nn.Module): :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ + transformer_options["original_shape"] = list(x.shape) assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -775,13 +776,13 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): - h = module(h, emb, context) + h = module(h, emb, context, transformer_options) if control is not None and 'input' in control and len(control['input']) > 0: ctrl = control['input'].pop() if ctrl is not None: h += ctrl hs.append(h) - h = self.middle_block(h, emb, context) + h = self.middle_block(h, emb, context, transformer_options) if control is not None and 'middle' in control and len(control['middle']) > 0: h += control['middle'].pop() @@ -793,7 +794,7 @@ class UNetModel(nn.Module): hsp += ctrl h = th.cat([h, hsp], dim=1) del hsp - h = module(h, emb, context) + h = module(h, emb, context, transformer_options) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/samplers.py b/comfy/samplers.py index 66218f887..40d5d332b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module): #The main sampling function shared by all the samplers #Returns predicted noise -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None): +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}): def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 @@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if control is not None: c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) + if 'transformer_options' in model_options: + c['transformer_options'] = model_options['transformer_options'] + output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x @@ -467,7 +470,7 @@ class KSampler: x_T=z_enc, x0=latent_image, denoise_function=sampling_function, - cond_concat=cond_concat, + extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, end_step=sigmas.shape[0] - 1) From 1716aaa7a6823f3eb7542fefac3257e8f2d8191c Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 31 Mar 2023 18:04:53 +0100 Subject: [PATCH 09/83] Swap order to prevent being cleared --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 2aabd29e7..2d55e885e 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -225,10 +225,10 @@ class ComfyList { $el("button", { textContent: "Load", onclick: () => { + app.loadGraphData(item.prompt[3].extra_pnginfo.workflow); if (item.outputs) { app.nodeOutputs = item.outputs; } - app.loadGraphData(item.prompt[3].extra_pnginfo.workflow); }, }), $el("button", { From 06c2c19b5a2db59dc28ff48a817d399c9148576e Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 31 Mar 2023 20:35:26 +0100 Subject: [PATCH 10/83] Clone default graph before using --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index b29981091..501c7ea65 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -802,7 +802,7 @@ class ComfyApp { this.clean(); if (!graphData) { - graphData = defaultGraph; + graphData = structuredClone(defaultGraph); } // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now From 18a6c1db3335c6898181920aa6c9bb5b060fd85f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 31 Mar 2023 17:19:58 -0400 Subject: [PATCH 11/83] Add a TomePatchModel node to the _for_testing section. Tome increases sampling speed at the expense of quality. --- comfy/ldm/modules/attention.py | 15 ++++- comfy/ldm/modules/tomesd.py | 117 +++++++++++++++++++++++++++++++++ comfy/samplers.py | 17 ++--- comfy/sd.py | 9 +++ nodes.py | 19 +++++- 5 files changed, 166 insertions(+), 11 deletions(-) create mode 100644 comfy/ldm/modules/tomesd.py diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 25051b339..07553627c 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention import model_management +from . import tomesd if model_management.xformers_enabled(): import xformers @@ -508,8 +509,18 @@ class BasicTransformerBlock(nn.Module): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) def _forward(self, x, context=None, transformer_options={}): - x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x - x = self.attn2(self.norm2(x), context=context) + x + n = self.norm1(x) + if "tomesd" in transformer_options: + m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) + n = u(self.attn1(m(n), context=context if self.disable_self_attn else None)) + else: + n = self.attn1(n, context=context if self.disable_self_attn else None) + + x += n + n = self.norm2(x) + n = self.attn2(n, context=context) + + x += n x = self.ff(self.norm3(x)) + x return x diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py new file mode 100644 index 000000000..5bf1acec9 --- /dev/null +++ b/comfy/ldm/modules/tomesd.py @@ -0,0 +1,117 @@ + + +import torch +from typing import Tuple, Callable +import math + +def do_nothing(x: torch.Tensor, mode:str=None): + return x + + +def bipartite_soft_matching_random2d(metric: torch.Tensor, + w: int, h: int, sx: int, sy: int, r: int, + no_rand: bool = False) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. + + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + """ + B, N, _ = metric.shape + + if r <= 0: + return do_nothing, do_nothing + + with torch.no_grad(): + + hsy, wsx = h // sy, w // sx + + # For each sy by sx kernel, randomly assign one token to be dst and the rest src + idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device) + + if no_rand: + rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64) + else: + rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device) + + idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype)) + idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1) + rand_idx = idx_buffer.argsort(dim=1) + + num_dst = int((1 / (sx*sy)) * N) + a_idx = rand_idx[:, num_dst:, :] # src + b_idx = rand_idx[:, :num_dst, :] # dst + + def split(x): + C = x.shape[-1] + src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C)) + dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C)) + return src, dst + + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], r) + + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: + src, dst = split(x) + n, t1, c = src.shape + + unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = src.gather(dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + _, _, c = unm.shape + + src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c)) + + # Combine back to the original shape + out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) + out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) + out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src) + + return out + + return merge, unmerge + + +def get_functions(x, ratio, original_shape): + b, c, original_h, original_w = original_shape + original_tokens = original_h * original_w + downsample = int(math.sqrt(original_tokens // x.shape[1])) + stride_x = 2 + stride_y = 2 + max_downsample = 1 + + if downsample <= max_downsample: + w = original_w // downsample + h = original_h // downsample + r = int(x.shape[1] * ratio) + no_rand = True + m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) + return m, u + + nothing = lambda y: y + return nothing, nothing diff --git a/comfy/samplers.py b/comfy/samplers.py index 40d5d332b..15e78bbd7 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con out['c_concat'] = [torch.cat(c_concat)] return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in): + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -195,7 +195,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) return uncond + (cond - uncond) * cond_scale @@ -212,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module): super().__init__() self.inner_model = model self.alphas_cumprod = model.alphas_cumprod - def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat) + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options) return out @@ -221,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}): if denoise_mask is not None: latent_mask = 1. - denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options) if denoise_mask is not None: out *= denoise_mask @@ -333,7 +333,7 @@ class KSampler: "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] - def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): + def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model self.model_denoise = CFGNoisePredictor(self.model) if self.model.parameterization == "v": @@ -353,6 +353,7 @@ class KSampler: self.sigma_max=float(self.model_wrap.sigma_max) self.set_steps(steps, denoise) self.denoise = denoise + self.model_options = model_options def _calculate_sigmas(self, steps): sigmas = None @@ -421,7 +422,7 @@ class KSampler: else: precision_scope = contextlib.nullcontext - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} cond_concat = None if hasattr(self.model, 'concat_keys'): diff --git a/comfy/sd.py b/comfy/sd.py index 2e1ae8409..2a38ceb15 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,5 +1,6 @@ import torch import contextlib +import copy import sd1_clip import sd2_clip @@ -274,12 +275,20 @@ class ModelPatcher: self.model = model self.patches = [] self.backup = {} + self.model_options = {"transformer_options":{}} def clone(self): n = ModelPatcher(self.model) n.patches = self.patches[:] + n.model_options = copy.deepcopy(self.model_options) return n + def set_model_tomesd(self, ratio): + self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} + + def model_dtype(self): + return self.model.diffusion_model.dtype + def add_patches(self, patches, strength=1.0): p = {} model_sd = self.model.state_dict() diff --git a/nodes.py b/nodes.py index 6fb7f0175..e69832c56 100644 --- a/nodes.py +++ b/nodes.py @@ -254,6 +254,22 @@ class LoraLoader: model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) +class TomePatchModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, ratio): + m = model.clone() + m.set_model_tomesd(ratio) + return (m, ) + class VAELoader: @classmethod def INPUT_TYPES(s): @@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, model_management.load_controlnet_gpu(control_net_models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) else: #other samplers pass @@ -1016,6 +1032,7 @@ NODE_CLASS_MAPPINGS = { "CLIPVisionLoader": CLIPVisionLoader, "VAEDecodeTiled": VAEDecodeTiled, "VAEEncodeTiled": VAEEncodeTiled, + "TomePatchModel": TomePatchModel, } def load_custom_node(module_path): From 0d972b85e616979c5832a15341972ba861197b4e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 31 Mar 2023 18:36:18 -0400 Subject: [PATCH 12/83] This seems to give better quality in tome. --- comfy/ldm/modules/tomesd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py index 5bf1acec9..1eafcd0aa 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy/ldm/modules/tomesd.py @@ -109,7 +109,7 @@ def get_functions(x, ratio, original_shape): w = original_w // downsample h = original_h // downsample r = int(x.shape[1] * ratio) - no_rand = True + no_rand = False m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) return m, u From 313f1f83a6f41ccff589663850c22c4e71e2819f Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 1 Apr 2023 12:44:29 +0100 Subject: [PATCH 13/83] Tweak server/custom node load order - Load custom nodes after creating server - Add routes after loading custom nodes Custom nodes can now add routes via PromptServer.instance --- main.py | 4 +++- nodes.py | 6 +++--- server.py | 6 ++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index c9809137a..824530fb1 100644 --- a/main.py +++ b/main.py @@ -40,6 +40,7 @@ if __name__ == "__main__": except: pass +from nodes import init_custom_nodes import execution import server import folder_paths @@ -98,6 +99,8 @@ if __name__ == "__main__": server = server.PromptServer(loop) q = execution.PromptQueue(server) + init_custom_nodes() + server.add_routes() hijack_progress(server) threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() @@ -113,7 +116,6 @@ if __name__ == "__main__": except: address = '127.0.0.1' - dont_print = False if '--dont-print-server' in sys.argv: dont_print = True diff --git a/nodes.py b/nodes.py index 6fb7f0175..b422f2cbb 100644 --- a/nodes.py +++ b/nodes.py @@ -1050,6 +1050,6 @@ def load_custom_nodes(): if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue load_custom_node(module_path) -load_custom_nodes() - -load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) \ No newline at end of file +def init_custom_nodes(): + load_custom_nodes() + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) \ No newline at end of file diff --git a/server.py b/server.py index 80fb2dc72..963daefff 100644 --- a/server.py +++ b/server.py @@ -42,6 +42,7 @@ class PromptServer(): self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") routes = web.RouteTableDef() + self.routes = routes self.last_node_id = None self.client_id = None @@ -239,8 +240,9 @@ class PromptServer(): self.prompt_queue.delete_history_item(id_to_delete) return web.Response(status=200) - - self.app.add_routes(routes) + + def add_routes(self): + self.app.add_routes(self.routes) self.app.add_routes([ web.static('/', self.web_root), ]) From 9586de9dc8eccb7f9c4934b7661a90fb208a81a8 Mon Sep 17 00:00:00 2001 From: flyingshutter Date: Sat, 1 Apr 2023 17:30:47 +0200 Subject: [PATCH 14/83] fix client freeze on connect reroutes in a circle --- web/extensions/core/rerouteNode.js | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index 7188dfd26..1342cae92 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -43,8 +43,15 @@ app.registerExtension({ const node = app.graph.getNodeById(link.origin_id); const type = node.constructor.type; if (type === "Reroute") { + if (node === this) { + // We've found a circle + currentNode.disconnectInput(link.target_slot); + currentNode = null; + } + else { // Move the previous node - currentNode = node; + currentNode = node; + } } else { // We've found the end inputNode = currentNode; From 178fc763635fc6784a6c1cb00ee08012c7bd72fe Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 1 Apr 2023 18:46:05 +0100 Subject: [PATCH 15/83] Added a queue for the queue action --- web/scripts/app.js | 63 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 501c7ea65..5af6d5fc0 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -5,6 +5,15 @@ import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, importA1111 } from "./pnginfo.js"; class ComfyApp { + /** + * List of {number, batchCount} entries to queue + */ + #queueItems = []; + /** + * If the queue is currently being processed + */ + #processingQueue = false; + constructor() { this.ui = new ComfyUI(this); this.extensions = []; @@ -915,31 +924,47 @@ class ComfyApp { } async queuePrompt(number, batchCount = 1) { - for (let i = 0; i < batchCount; i++) { - const p = await this.graphToPrompt(); + this.#queueItems.push({ number, batchCount }); - try { - await api.queuePrompt(number, p); - } catch (error) { - this.ui.dialog.show(error.response || error.toString()); - return; - } + // Only have one action process the items so each one gets a unique seed correctly + if (this.#processingQueue) { + return; + } + + this.#processingQueue = true; + try { + while (this.#queueItems.length) { + ({ number, batchCount } = this.#queueItems.pop()); - for (const n of p.workflow.nodes) { - const node = graph.getNodeById(n.id); - if (node.widgets) { - for (const widget of node.widgets) { - // Allow widgets to run callbacks after a prompt has been queued - // e.g. random seed after every gen - if (widget.afterQueued) { - widget.afterQueued(); + for (let i = 0; i < batchCount; i++) { + const p = await this.graphToPrompt(); + + try { + await api.queuePrompt(number, p); + } catch (error) { + this.ui.dialog.show(error.response || error.toString()); + break; + } + + for (const n of p.workflow.nodes) { + const node = graph.getNodeById(n.id); + if (node.widgets) { + for (const widget of node.widgets) { + // Allow widgets to run callbacks after a prompt has been queued + // e.g. random seed after every gen + if (widget.afterQueued) { + widget.afterQueued(); + } + } } } + + this.canvas.draw(true, true); + await this.ui.queue.update(); } } - - this.canvas.draw(true, true); - await this.ui.queue.update(); + } finally { + this.#processingQueue = false; } } From 809bcc8cebba6d53565fd6acaac4dd0314054373 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 1 Apr 2023 23:19:15 -0400 Subject: [PATCH 16/83] Add support for unCLIP SD2.x models. See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load them. unCLIPConditioning is used to add the image cond and takes as input a CLIPVisionEncode output which has been moved to the conditioning section. --- comfy/clip_vision.py | 62 +++++++++ comfy/clip_vision_config_h.json | 18 +++ .../clip_vision_config_vitl.json | 7 +- comfy/ldm/models/diffusion/ddpm.py | 72 ++++++++++ .../models/diffusion/dpm_solver/dpm_solver.py | 11 +- .../models/diffusion/dpm_solver/sampler.py | 24 ++-- .../modules/diffusionmodules/openaimodel.py | 19 +++ comfy/ldm/modules/diffusionmodules/util.py | 10 +- .../ldm/modules/encoders/kornia_functions.py | 59 +++++++++ comfy/ldm/modules/encoders/modules.py | 125 ++++++++++++++++-- .../ldm/modules/encoders/noise_aug_modules.py | 35 +++++ comfy/samplers.py | 45 ++++++- comfy/sd.py | 93 +++++++------ comfy/utils.py | 42 ++++++ comfy_extras/clip_vision.py | 32 ----- comfy_extras/nodes_upscale_model.py | 3 +- nodes.py | 49 ++++++- 17 files changed, 593 insertions(+), 113 deletions(-) create mode 100644 comfy/clip_vision.py create mode 100644 comfy/clip_vision_config_h.json rename comfy_extras/clip_vision_config.json => comfy/clip_vision_config_vitl.json (70%) create mode 100644 comfy/ldm/modules/encoders/kornia_functions.py create mode 100644 comfy/ldm/modules/encoders/noise_aug_modules.py delete mode 100644 comfy_extras/clip_vision.py diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py new file mode 100644 index 000000000..cb29df432 --- /dev/null +++ b/comfy/clip_vision.py @@ -0,0 +1,62 @@ +from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor +from .utils import load_torch_file, transformers_convert +import os + +class ClipVisionModel(): + def __init__(self, json_config): + config = CLIPVisionConfig.from_json_file(json_config) + self.model = CLIPVisionModelWithProjection(config) + self.processor = CLIPImageProcessor(crop_size=224, + do_center_crop=True, + do_convert_rgb=True, + do_normalize=True, + do_resize=True, + image_mean=[ 0.48145466,0.4578275,0.40821073], + image_std=[0.26862954,0.26130258,0.27577711], + resample=3, #bicubic + size=224) + + def load_sd(self, sd): + self.model.load_state_dict(sd, strict=False) + + def encode_image(self, image): + inputs = self.processor(images=[image[0]], return_tensors="pt") + outputs = self.model(**inputs) + return outputs + +def convert_to_transformers(sd): + sd_k = sd.keys() + if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k: + keys_to_replace = { + "embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding", + "embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight", + "embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight", + "embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias", + "embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight", + "embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias", + "embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight", + } + + for x in keys_to_replace: + if x in sd_k: + sd[keys_to_replace[x]] = sd.pop(x) + + if "embedder.model.visual.proj" in sd_k: + sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1) + + sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32) + return sd + +def load_clipvision_from_sd(sd): + sd = convert_to_transformers(sd) + if "vision_model.encoder.layers.30.layer_norm1.weight" in sd: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") + else: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") + clip = ClipVisionModel(json_config) + clip.load_sd(sd) + return clip + +def load(ckpt_path): + sd = load_torch_file(ckpt_path) + return load_clipvision_from_sd(sd) diff --git a/comfy/clip_vision_config_h.json b/comfy/clip_vision_config_h.json new file mode 100644 index 000000000..bb71be419 --- /dev/null +++ b/comfy/clip_vision_config_h.json @@ -0,0 +1,18 @@ +{ + "attention_dropout": 0.0, + "dropout": 0.0, + "hidden_act": "gelu", + "hidden_size": 1280, + "image_size": 224, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 5120, + "layer_norm_eps": 1e-05, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 32, + "patch_size": 14, + "projection_dim": 1024, + "torch_dtype": "float32" +} diff --git a/comfy_extras/clip_vision_config.json b/comfy/clip_vision_config_vitl.json similarity index 70% rename from comfy_extras/clip_vision_config.json rename to comfy/clip_vision_config_vitl.json index 0e4db13d9..c59b8ed5a 100644 --- a/comfy_extras/clip_vision_config.json +++ b/comfy/clip_vision_config_vitl.json @@ -1,8 +1,4 @@ { - "_name_or_path": "openai/clip-vit-large-patch14", - "architectures": [ - "CLIPVisionModel" - ], "attention_dropout": 0.0, "dropout": 0.0, "hidden_act": "quick_gelu", @@ -18,6 +14,5 @@ "num_hidden_layers": 24, "patch_size": 14, "projection_dim": 768, - "torch_dtype": "float32", - "transformers_version": "4.24.0" + "torch_dtype": "float32" } diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index 6af961242..d3f0eb2b2 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): log = super().log_images(*args, **kwargs) log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') return log + + +class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion): + def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5, + freeze_embedder=True, noise_aug_config=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.embed_key = embedding_key + self.embedding_dropout = embedding_dropout + # self._init_embedder(embedder_config, freeze_embedder) + self._init_noise_aug(noise_aug_config) + + def _init_embedder(self, config, freeze=True): + embedder = instantiate_from_config(config) + if freeze: + self.embedder = embedder.eval() + self.embedder.train = disabled_train + for param in self.embedder.parameters(): + param.requires_grad = False + + def _init_noise_aug(self, config): + if config is not None: + # use the KARLO schedule for noise augmentation on CLIP image embeddings + noise_augmentor = instantiate_from_config(config) + assert isinstance(noise_augmentor, nn.Module) + noise_augmentor = noise_augmentor.eval() + noise_augmentor.train = disabled_train + self.noise_augmentor = noise_augmentor + else: + self.noise_augmentor = None + + def get_input(self, batch, k, cond_key=None, bs=None, **kwargs): + outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs) + z, c = outputs[0], outputs[1] + img = batch[self.embed_key][:bs] + img = rearrange(img, 'b h w c -> b c h w') + c_adm = self.embedder(img) + if self.noise_augmentor is not None: + c_adm, noise_level_emb = self.noise_augmentor(c_adm) + # assume this gives embeddings of noise levels + c_adm = torch.cat((c_adm, noise_level_emb), 1) + if self.training: + c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0], + device=c_adm.device)[:, None]) * c_adm + all_conds = {"c_crossattn": [c], "c_adm": c_adm} + noutputs = [z, all_conds] + noutputs.extend(outputs[2:]) + return noutputs + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, **kwargs): + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True, + return_original_cond=True) + log["inputs"] = x + log["reconstruction"] = xrec + assert self.model.conditioning_key is not None + assert self.cond_stage_key in ["caption", "txt"] + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', '')) + unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.) + + uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext + with ema_scope(f"Sampling"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True, + ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.), + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_, ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log diff --git a/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py b/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py index 095e5ba3c..da8d41f9c 100644 --- a/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -307,7 +307,16 @@ def model_wrapper( else: x_in = torch.cat([x] * 2) t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) + if isinstance(condition, dict): + assert isinstance(unconditional_condition, dict) + c_in = dict() + for k in condition: + if isinstance(condition[k], list): + c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))] + else: + c_in[k] = torch.cat([unconditional_condition[k], condition[k]]) + else: + c_in = torch.cat([unconditional_condition, condition]) noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) return noise_uncond + guidance_scale * (noise - noise_uncond) diff --git a/comfy/ldm/models/diffusion/dpm_solver/sampler.py b/comfy/ldm/models/diffusion/dpm_solver/sampler.py index 4270c618a..e4d0d0a38 100644 --- a/comfy/ldm/models/diffusion/dpm_solver/sampler.py +++ b/comfy/ldm/models/diffusion/dpm_solver/sampler.py @@ -3,7 +3,6 @@ import torch from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver - MODEL_TYPES = { "eps": "noise", "v": "v" @@ -51,12 +50,20 @@ class DPMSolverSampler(object): ): if conditioning is not None: if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + if isinstance(ctmp, torch.Tensor): + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}") else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + if isinstance(conditioning, torch.Tensor): + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") # sampling C, H, W = shape @@ -83,6 +90,7 @@ class DPMSolverSampler(object): ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, + lower_order_final=True) - return x.to(device), None \ No newline at end of file + return x.to(device), None diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 7b2f5b531..8a4e8b3e1 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -409,6 +409,15 @@ class QKVAttention(nn.Module): return count_flops_attn(model, _x, y) +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + class UNetModel(nn.Module): """ The full UNet model with attention and timestep embedding. @@ -470,6 +479,7 @@ class UNetModel(nn.Module): num_attention_blocks=None, disable_middle_self_attn=False, use_linear_in_transformer=False, + adm_in_channels=None, ): super().__init__() if use_spatial_transformer: @@ -538,6 +548,15 @@ class UNetModel(nn.Module): elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) else: raise ValueError() diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 637363dfe..daf35da7b 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, betas = 1 - alphas[1:] / alphas[:-1] betas = np.clip(betas, a_min=0, a_max=0.999) + elif schedule == "squaredcos_cap_v2": # used for karlo prior + # return early + return betas_for_alpha_bar( + n_timestep, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + elif schedule == "sqrt_linear": betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) elif schedule == "sqrt": @@ -218,6 +225,7 @@ class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) + def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. @@ -267,4 +275,4 @@ class HybridConditioner(nn.Module): def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise() diff --git a/comfy/ldm/modules/encoders/kornia_functions.py b/comfy/ldm/modules/encoders/kornia_functions.py new file mode 100644 index 000000000..912314cd7 --- /dev/null +++ b/comfy/ldm/modules/encoders/kornia_functions.py @@ -0,0 +1,59 @@ + + +from typing import List, Tuple, Union + +import torch +import torch.nn as nn + +#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py + +def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + r"""Normalize an image/video tensor with mean and standard deviation. + .. math:: + \text{input[channel] = (input[channel] - mean[channel]) / std[channel]} + Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels, + Args: + data: Image tensor of size :math:`(B, C, *)`. + mean: Mean for each channel. + std: Standard deviations for each channel. + Return: + Normalised tensor with same size as input :math:`(B, C, *)`. + Examples: + >>> x = torch.rand(1, 4, 3, 3) + >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.])) + >>> out.shape + torch.Size([1, 4, 3, 3]) + >>> x = torch.rand(1, 4, 3, 3) + >>> mean = torch.zeros(4) + >>> std = 255. * torch.ones(4) + >>> out = normalize(x, mean, std) + >>> out.shape + torch.Size([1, 4, 3, 3]) + """ + shape = data.shape + if len(mean.shape) == 0 or mean.shape[0] == 1: + mean = mean.expand(shape[1]) + if len(std.shape) == 0 or std.shape[0] == 1: + std = std.expand(shape[1]) + + # Allow broadcast on channel dimension + if mean.shape and mean.shape[0] != 1: + if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]: + raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.") + + # Allow broadcast on channel dimension + if std.shape and std.shape[0] != 1: + if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]: + raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.") + + mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype) + std = torch.as_tensor(std, device=data.device, dtype=data.dtype) + + if mean.shape: + mean = mean[..., :, None] + if std.shape: + std = std[..., :, None] + + out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std + + return out.view(shape) diff --git a/comfy/ldm/modules/encoders/modules.py b/comfy/ldm/modules/encoders/modules.py index 4edd5496b..bc9fde638 100644 --- a/comfy/ldm/modules/encoders/modules.py +++ b/comfy/ldm/modules/encoders/modules.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from . import kornia_functions from torch.utils.checkpoint import checkpoint from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel @@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module): c = batch[key][:, None] if self.ucg_rate > 0. and not disable_dropout: mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) c = c.long() c = self.embedding(c) return c @@ -57,18 +58,20 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, + freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? if freeze: self.freeze() def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False @@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): "pooled", "hidden" ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 super().__init__() @@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False @@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") if self.layer == "last": z = outputs.last_hidden_state elif self.layer == "pooled": @@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder): return self(text) +class ClipImageEmbedder(nn.Module): + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=True, + ucg_rate=0. + ): + super().__init__() + from clip import load as load_clip + self.model, _ = load_clip(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + # x = kornia_functions.geometry_resize(x, (224, 224), + # interpolation='bicubic', align_corners=True, + # antialias=self.antialias) + x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True) + x = (x + 1.) / 2. + # re-normalize according to clip + x = kornia_functions.enhance_normalize(x, self.mean, self.std) + return x + + def forward(self, x, no_dropout=False): + # x is assumed to be in range [-1,1] + out = self.model.encode_image(self.preprocess(x)) + out = out.to(x.dtype) + if self.ucg_rate > 0. and not no_dropout: + out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out + return out + + class FrozenOpenCLIPEmbedder(AbstractEncoder): """ Uses the OpenCLIP transformer encoder for text """ LAYERS = [ - #"pooled", + # "pooled", "last", "penultimate" ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last"): super().__init__() @@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): x = self.model.ln_final(x) return x - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - self.layer_idx: break @@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): return self(text) +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="pooled", antialias=True, ucg_rate=0.): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), + pretrained=version, ) + del model.transformer + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + # x = kornia.geometry.resize(x, (224, 224), + # interpolation='bicubic', align_corners=True, + # antialias=self.antialias) + x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia_functions.enhance_normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + if self.ucg_rate > 0. and not no_dropout: + z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z + return z + + def encode_with_vision_transformer(self, img): + img = self.preprocess(img) + x = self.model.visual(img) + return x + + def encode(self, text): + return self(text) + + class FrozenCLIPT5Encoder(AbstractEncoder): def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", clip_max_length=77, t5_max_length=77): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") def encode(self, text): return self(text) @@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder): clip_z = self.clip_encoder.encode(text) t5_z = self.t5_encoder.encode(text) return [clip_z, t5_z] - - diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py new file mode 100644 index 000000000..f99e7920a --- /dev/null +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -0,0 +1,35 @@ +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation +from ldm.modules.diffusionmodules.openaimodel import Timestep +import torch + +class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): + def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs): + super().__init__(*args, **kwargs) + if clip_stats_path is None: + clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim) + else: + clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu") + self.register_buffer("data_mean", clip_mean[None, :], persistent=False) + self.register_buffer("data_std", clip_std[None, :], persistent=False) + self.time_embed = Timestep(timestep_dim) + + def scale(self, x): + # re-normalize to centered mean and unit variance + x = (x - self.data_mean) * 1. / self.data_std + return x + + def unscale(self, x): + # back to original data stats + x = (x * self.data_std) + self.data_mean + return x + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + x = self.scale(x) + z = self.q_sample(x, noise_level) + z = self.unscale(z) + noise_level = self.time_embed(noise_level) + return z, noise_level diff --git a/comfy/samplers.py b/comfy/samplers.py index 15e78bbd7..ddec99007 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if 'strength' in cond[1]: strength = cond[1]['strength'] + adm_cond = None + if 'adm' in cond[1]: + adm_cond = cond[1]['adm'] + input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] mult = torch.ones_like(input_x) * strength @@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con cropped.append(cr) conditionning['c_concat'] = torch.cat(cropped, dim=1) + if adm_cond is not None: + conditionning['c_adm'] = adm_cond + control = None if 'control' in cond[1]: control = cond[1]['control'] @@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if 'c_concat' in c1: if c1['c_concat'].shape != c2['c_concat'].shape: return False + if 'c_adm' in c1: + if c1['c_adm'].shape != c2['c_adm'].shape: + return False return True def can_concat_cond(c1, c2): @@ -92,16 +102,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con def cond_cat(c_list): c_crossattn = [] c_concat = [] + c_adm = [] for x in c_list: if 'c_crossattn' in x: c_crossattn.append(x['c_crossattn']) if 'c_concat' in x: c_concat.append(x['c_concat']) + if 'c_adm' in x: + c_adm.append(x['c_adm']) out = {} if len(c_crossattn) > 0: out['c_crossattn'] = [torch.cat(c_crossattn)] if len(c_concat) > 0: out['c_concat'] = [torch.cat(c_concat)] + if len(c_adm) > 0: + out['c_adm'] = torch.cat(c_adm) return out def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options): @@ -327,6 +342,30 @@ def apply_control_net_to_equal_area(conds, uncond): n['control'] = cond_cnets[x] uncond[temp[1]] = [o[0], n] +def encode_adm(noise_augmentor, conds, batch_size, device): + for t in range(len(conds)): + x = conds[t] + if 'adm' in x[1]: + adm_inputs = [] + weights = [] + adm_in = x[1]["adm"] + for adm_c in adm_in: + adm_cond = adm_c[0].image_embeds + weight = adm_c[1] + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight + weights.append(weight) + adm_inputs.append(adm_out) + + adm_out = torch.stack(adm_inputs).sum(0) + #TODO: Apply Noise to Embedding Mix + else: + adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) + x[1] = x[1].copy() + x[1]["adm"] = torch.cat([adm_out] * batch_size) + + return conds + class KSampler: SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", @@ -422,10 +461,14 @@ class KSampler: else: precision_scope = contextlib.nullcontext + if hasattr(self.model, 'noise_augmentor'): #unclip + positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device) + negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device) + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} cond_concat = None - if hasattr(self.model, 'concat_keys'): + if hasattr(self.model, 'concat_keys'): #inpaint cond_concat = [] for ck in self.model.concat_keys: if denoise_mask is not None: diff --git a/comfy/sd.py b/comfy/sd.py index 2a38ceb15..2d7ff5ab0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -12,20 +12,7 @@ from .cldm import cldm from .t2i_adapter import adapter from . import utils - -def load_torch_file(ckpt): - if ckpt.lower().endswith(".safetensors"): - import safetensors.torch - sd = safetensors.torch.load_file(ckpt, device="cpu") - else: - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd - return sd +from . import clip_vision def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) @@ -53,30 +40,7 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): if x in sd: sd[keys_to_replace[x]] = sd.pop(x) - resblock_to_replace = { - "ln_1": "layer_norm1", - "ln_2": "layer_norm2", - "mlp.c_fc": "mlp.fc1", - "mlp.c_proj": "mlp.fc2", - "attn.out_proj": "self_attn.out_proj", - } - - for resblock in range(24): - for x in resblock_to_replace: - for y in ["weight", "bias"]: - k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y) - k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y) - if k in sd: - sd[k_to] = sd.pop(k) - - for y in ["weight", "bias"]: - k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y) - if k_from in sd: - weights = sd.pop(k_from) - for x in range(3): - p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] - k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y) - sd[k_to] = weights[1024*x:1024*(x + 1)] + sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24) for x in load_state_dict_to: x.load_state_dict(sd, strict=False) @@ -123,7 +87,7 @@ LORA_UNET_MAP_RESNET = { } def load_lora(path, to_load): - lora = load_torch_file(path) + lora = utils.load_torch_file(path) patch_dict = {} loaded_keys = set() for x in to_load: @@ -599,7 +563,7 @@ class ControlNet: return out def load_controlnet(ckpt_path, model=None): - controlnet_data = load_torch_file(ckpt_path) + controlnet_data = utils.load_torch_file(ckpt_path) pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' pth = False sd2 = False @@ -793,7 +757,7 @@ class StyleModel: def load_style_model(ckpt_path): - model_data = load_torch_file(ckpt_path) + model_data = utils.load_torch_file(ckpt_path) keys = model_data.keys() if "style_embedding" in keys: model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) @@ -804,7 +768,7 @@ def load_style_model(ckpt_path): def load_clip(ckpt_path, embedding_directory=None): - clip_data = load_torch_file(ckpt_path) + clip_data = utils.load_torch_file(ckpt_path) config = {} if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' @@ -847,7 +811,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e load_state_dict_to = [w] model = instantiate_from_config(config["model"]) - sd = load_torch_file(ckpt_path) + sd = utils.load_torch_file(ckpt_path) model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) if fp16: @@ -856,10 +820,11 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e return (ModelPatcher(model), clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): - sd = load_torch_file(ckpt_path) +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): + sd = utils.load_torch_file(ckpt_path) sd_keys = sd.keys() clip = None + clipvision = None vae = None fp16 = model_management.should_use_fp16() @@ -884,6 +849,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] + clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" + noise_aug_config = None + if clipvision_key in sd_keys: + size = sd[clipvision_key].shape[1] + + if output_clipvision: + clipvision = clip_vision.load_clipvision_from_sd(sd) + + noise_aug_key = "noise_augmentor.betas" + if noise_aug_key in sd_keys: + noise_aug_config = {} + params = {} + noise_schedule_config = {} + noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0] + noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2" + params["noise_schedule_config"] = noise_schedule_config + noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation" + if size == 1280: #h + params["timestep_dim"] = 1024 + elif size == 1024: #l + params["timestep_dim"] = 768 + noise_aug_config['params'] = params + sd_config = { "linear_start": 0.00085, "linear_end": 0.012, @@ -932,7 +920,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} - if unet_config["in_channels"] > 4: #inpainting model + if noise_aug_config is not None: #SD2.x unclip model + sd_config["noise_aug_config"] = noise_aug_config + sd_config["image_size"] = 96 + sd_config["embedding_dropout"] = 0.25 + sd_config["conditioning_key"] = 'crossattn-adm' + model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" + elif unet_config["in_channels"] > 4: #inpainting model sd_config["conditioning_key"] = "hybrid" sd_config["finetune_keys"] = None model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" @@ -944,6 +938,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e else: unet_config["num_heads"] = 8 #SD1.x + unclip = 'model.diffusion_model.label_emb.0.0.weight' + if unclip in sd_keys: + unet_config["num_classes"] = "sequential" + unet_config["adm_in_channels"] = sd[unclip].shape[1] + if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias" out = sd[k] @@ -956,4 +955,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e if fp16: model = model.half() - return (ModelPatcher(model), clip, vae) + return (ModelPatcher(model), clip, vae, clipvision) diff --git a/comfy/utils.py b/comfy/utils.py index 798ac1c45..0380b91dd 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,47 @@ import torch +def load_torch_file(ckpt): + if ckpt.lower().endswith(".safetensors"): + import safetensors.torch + sd = safetensors.torch.load_file(ckpt, device="cpu") + else: + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd + return sd + +def transformers_convert(sd, prefix_from, prefix_to, number): + resblock_to_replace = { + "ln_1": "layer_norm1", + "ln_2": "layer_norm2", + "mlp.c_fc": "mlp.fc1", + "mlp.c_proj": "mlp.fc2", + "attn.out_proj": "self_attn.out_proj", + } + + for resblock in range(number): + for x in resblock_to_replace: + for y in ["weight", "bias"]: + k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y) + k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) + if k in sd: + sd[k_to] = sd.pop(k) + + for y in ["weight", "bias"]: + k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y) + if k_from in sd: + weights = sd.pop(k_from) + shape_from = weights.shape[0] // 3 + for x in range(3): + p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] + k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) + sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return sd + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] diff --git a/comfy_extras/clip_vision.py b/comfy_extras/clip_vision.py deleted file mode 100644 index 58d79a83e..000000000 --- a/comfy_extras/clip_vision.py +++ /dev/null @@ -1,32 +0,0 @@ -from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor -from comfy.sd import load_torch_file -import os - -class ClipVisionModel(): - def __init__(self): - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json") - config = CLIPVisionConfig.from_json_file(json_config) - self.model = CLIPVisionModel(config) - self.processor = CLIPImageProcessor(crop_size=224, - do_center_crop=True, - do_convert_rgb=True, - do_normalize=True, - do_resize=True, - image_mean=[ 0.48145466,0.4578275,0.40821073], - image_std=[0.26862954,0.26130258,0.27577711], - resample=3, #bicubic - size=224) - - def load_sd(self, sd): - self.model.load_state_dict(sd, strict=False) - - def encode_image(self, image): - inputs = self.processor(images=[image[0]], return_tensors="pt") - outputs = self.model(**inputs) - return outputs - -def load(ckpt_path): - clip_data = load_torch_file(ckpt_path) - clip = ClipVisionModel() - clip.load_sd(clip_data) - return clip diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index b79b78511..6a7d0e516 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -1,6 +1,5 @@ import os from comfy_extras.chainner_models import model_loading -from comfy.sd import load_torch_file import model_management import torch import comfy.utils @@ -18,7 +17,7 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) - sd = load_torch_file(model_path) + sd = comfy.utils.load_torch_file(model_path) out = model_loading.load_state_dict(sd).eval() return (out, ) diff --git a/nodes.py b/nodes.py index e69832c56..1555c19c9 100644 --- a/nodes.py +++ b/nodes.py @@ -18,7 +18,7 @@ import comfy.samplers import comfy.sd import comfy.utils -import comfy_extras.clip_vision +import comfy.clip_vision import model_management import importlib @@ -219,6 +219,21 @@ class CheckpointLoaderSimple: out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out +class unCLIPCheckpointLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + }} + RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") + FUNCTION = "load_checkpoint" + + CATEGORY = "_for_testing/unclip" + + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return out + class CLIPSetLastLayer: @classmethod def INPUT_TYPES(s): @@ -370,7 +385,7 @@ class CLIPVisionLoader: def load_clip(self, clip_name): clip_path = folder_paths.get_full_path("clip_vision", clip_name) - clip_vision = comfy_extras.clip_vision.load(clip_path) + clip_vision = comfy.clip_vision.load(clip_path) return (clip_vision,) class CLIPVisionEncode: @@ -382,7 +397,7 @@ class CLIPVisionEncode: RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" - CATEGORY = "conditioning/style_model" + CATEGORY = "conditioning" def encode(self, clip_vision, image): output = clip_vision.encode_image(image) @@ -424,6 +439,32 @@ class StyleModelApply: c.append(n) return (c, ) +class unCLIPConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_adm" + + CATEGORY = "_for_testing/unclip" + + def apply_adm(self, conditioning, clip_vision_output, strength): + c = [] + for t in conditioning: + o = t[1].copy() + x = (clip_vision_output, strength) + if "adm" in o: + o["adm"] = o["adm"][:] + [x] + else: + o["adm"] = [x] + n = [t[0], o] + c.append(n) + return (c, ) + + class EmptyLatentImage: def __init__(self, device="cpu"): self.device = device @@ -1025,6 +1066,7 @@ NODE_CLASS_MAPPINGS = { "CLIPLoader": CLIPLoader, "CLIPVisionEncode": CLIPVisionEncode, "StyleModelApply": StyleModelApply, + "unCLIPConditioning": unCLIPConditioning, "ControlNetApply": ControlNetApply, "ControlNetLoader": ControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader, @@ -1033,6 +1075,7 @@ NODE_CLASS_MAPPINGS = { "VAEDecodeTiled": VAEDecodeTiled, "VAEEncodeTiled": VAEEncodeTiled, "TomePatchModel": TomePatchModel, + "unCLIPCheckpointLoader": unCLIPCheckpointLoader, } def load_custom_node(module_path): From 66f1f576151ba6da13ebd34a540bc1f7301fb52a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 2 Apr 2023 01:54:44 -0400 Subject: [PATCH 17/83] Add --extra-model-paths-config to --help. --- main.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/main.py b/main.py index c9809137a..9c9fb7613 100644 --- a/main.py +++ b/main.py @@ -11,9 +11,14 @@ if os.name == "nt": if __name__ == "__main__": if '--help' in sys.argv: + print() print("Valid Command line Arguments:") print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.") print("\t--port 8188\t\t\tSet the listen port.") + print() + print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.") + print() + print() print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.") print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.") From 5aefd6cdf3504119e10132f44dd5863581dc337d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 15:16:08 +0100 Subject: [PATCH 18/83] Support numeric settings, tooltip, extra attrs --- web/scripts/ui.js | 86 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 28 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index c27fbf986..679f10b20 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -198,7 +198,7 @@ class ComfySettingsDialog extends ComfyDialog { localStorage[settingId] = JSON.stringify(value); } - addSetting({ id, name, type, defaultValue, onChange }) { + addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", }) { if (!id) { throw new Error("Settings must have an ID"); } @@ -225,42 +225,72 @@ class ComfySettingsDialog extends ComfyDialog { value = v; }; + let element; + if (typeof type === "function") { - return type(name, setter, value); + element = type(name, setter, value, attrs); + } else { + switch (type) { + case "boolean": + element = $el("div", [ + $el("label", { textContent: name || id }, [ + $el("input", { + type: "checkbox", + checked: !!value, + oninput: (e) => { + setter(e.target.checked); + }, + ...attrs + }), + ]), + ]); + break; + case "number": + element = $el("div", [ + $el("label", { textContent: name || id }, [ + $el("input", { + type, + value, + oninput: (e) => { + setter(e.target.value); + }, + ...attrs + }), + ]), + ]); + break; + default: + console.warn("Unsupported setting type, defaulting to text"); + element = $el("div", [ + $el("label", { textContent: name || id }, [ + $el("input", { + value, + oninput: (e) => { + setter(e.target.value); + }, + ...attrs + }), + ]), + ]); + break; + } + } + if(tooltip) { + element.title = tooltip; } - switch (type) { - case "boolean": - return $el("div", [ - $el("label", { textContent: name || id }, [ - $el("input", { - type: "checkbox", - checked: !!value, - oninput: (e) => { - setter(e.target.checked); - }, - }), - ]), - ]); - default: - console.warn("Unsupported setting type, defaulting to text"); - return $el("div", [ - $el("label", { textContent: name || id }, [ - $el("input", { - value, - oninput: (e) => { - setter(e.target.value); - }, - }), - ]), - ]); - } + return element; }, }); } show() { super.show(); + Object.assign(this.textElement.style, { + display: "flex", + flexDirection: "column", + gap: "10px" + }); this.textElement.replaceChildren(...this.settings.map((s) => s.render())); } } From 940893f92c9ba0d7ae98713e18404c623afe4789 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 2 Apr 2023 10:27:01 -0400 Subject: [PATCH 19/83] Update the example_node.py.example with RETURN_NAMES. --- custom_nodes/example_node.py.example | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 1bb1a5a37..fb8172648 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -11,6 +11,8 @@ class Example: ---------- RETURN_TYPES (`tuple`): The type of each element in the output tulple. + RETURN_NAMES (`tuple`): + Optional: The name of each output in the output tulple. FUNCTION (`str`): The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute() OUTPUT_NODE ([`bool`]): @@ -61,6 +63,8 @@ class Example: } RETURN_TYPES = ("IMAGE",) + #RETURN_NAMES = ("image_output_name",) + FUNCTION = "test" #OUTPUT_NODE = False From d027ff121c904f5d21b5d9fd8607fcdb2b166ec3 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 15:33:34 +0100 Subject: [PATCH 20/83] Snap to grid --- web/extensions/core/snapToGrid.js | 86 +++++++++++++++++++++++++++++++ web/scripts/app.js | 8 ++- 2 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 web/extensions/core/snapToGrid.js diff --git a/web/extensions/core/snapToGrid.js b/web/extensions/core/snapToGrid.js new file mode 100644 index 000000000..80e836a0b --- /dev/null +++ b/web/extensions/core/snapToGrid.js @@ -0,0 +1,86 @@ +import { app } from "/scripts/app.js"; + +// Shift + drag/resize to snap to grid + +app.registerExtension({ + name: "Comfy.SnapToGrid", + init() { + // Add setting to control grid size + app.ui.settings.addSetting({ + id: "Comfy.SnapToGrid.GridSize", + name: "Grid Size", + type: "number", + attrs: { + min: 1, + max: 500, + }, + tooltip: + "When dragging and resizing nodes while holding shift they will be aligned to the grid, this controls the size of that grid.", + defaultValue: LiteGraph.CANVAS_GRID_SIZE, + onChange(value) { + LiteGraph.CANVAS_GRID_SIZE = +value; + }, + }); + + // After moving a node, if the shift key is down align it to grid + const onNodeMoved = app.canvas.onNodeMoved; + app.canvas.onNodeMoved = function (node) { + const r = onNodeMoved?.apply(this, arguments); + + if (app.shiftDown) { + node.alignToGrid(); + } + + return r; + }; + + // When a node is added, add a resize handler to it so we can fix align the size with the grid + const onNodeAdded = app.graph.onNodeAdded; + app.graph.onNodeAdded = function (node) { + const onResize = node.onResize; + node.onResize = function () { + if(app.shiftDown) { + const w = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[0] / LiteGraph.CANVAS_GRID_SIZE); + const h = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[1] / LiteGraph.CANVAS_GRID_SIZE); + node.size[0] = w; + node.size[1] = h; + } + return onResize?.apply(this, arguments); + }; + return onNodeAdded?.apply(this, arguments); + }; + + // Draw a preview of where the node will go if holding shift + const origDrawNode = LGraphCanvas.prototype.drawNode; + LGraphCanvas.prototype.drawNode = function (node, ctx) { + if (app.shiftDown && node === this.node_dragged) { + const x = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[0] / LiteGraph.CANVAS_GRID_SIZE); + const y = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[1] / LiteGraph.CANVAS_GRID_SIZE); + + const shiftX = x - node.pos[0]; + let shiftY = y - node.pos[1]; + + let w, h; + if (node.flags.collapsed) { + w = node._collapsed_width; + h = LiteGraph.NODE_TITLE_HEIGHT; + shiftY -= LiteGraph.NODE_TITLE_HEIGHT; + } else { + w = node.size[0]; + h = node.size[1]; + let titleMode = node.constructor.title_mode; + if (titleMode !== LiteGraph.TRANSPARENT_TITLE && titleMode !== LiteGraph.NO_TITLE) { + h += LiteGraph.NODE_TITLE_HEIGHT; + shiftY -= LiteGraph.NODE_TITLE_HEIGHT; + } + } + const f = ctx.fillStyle; + ctx.fillStyle = "rgba(100, 100, 100, 0.5)"; + ctx.fillRect(shiftX, shiftY, w, h); + ctx.fillStyle = f; + } + + return origDrawNode.apply(this, arguments); + }; + }, +}); diff --git a/web/scripts/app.js b/web/scripts/app.js index 5af6d5fc0..6f8ac067b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -18,6 +18,7 @@ class ComfyApp { this.ui = new ComfyUI(this); this.extensions = []; this.nodeOutputs = {}; + this.shiftDown = false; } /** @@ -538,7 +539,7 @@ class ComfyApp { color = "#0f0"; } else if (self.dragOverNode && node.id === self.dragOverNode.id) { color = "dodgerblue"; - } + } if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; @@ -637,11 +638,16 @@ class ComfyApp { #addKeyboardHandler() { window.addEventListener("keydown", (e) => { + this.shiftDown = e.shiftKey; + // Queue prompt using ctrl or command + enter if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) { this.queuePrompt(e.shiftKey ? -1 : 0); } }); + window.addEventListener("keyup", (e) => { + this.shiftDown = e.shiftKey; + }); } /** From 26dc8e3056c18bf08e83cfc868665cea25a90868 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 15:36:27 +0100 Subject: [PATCH 21/83] formatting --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 6f8ac067b..8612d5a34 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -539,7 +539,7 @@ class ComfyApp { color = "#0f0"; } else if (self.dragOverNode && node.id === self.dragOverNode.id) { color = "dodgerblue"; - } + } if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; From 1917064b56b6c2a7206e6abfca219d742acee9ff Mon Sep 17 00:00:00 2001 From: Tomoaki Hayasaka Date: Sun, 2 Apr 2023 21:43:40 +0900 Subject: [PATCH 22/83] Fix "extra filename replacements in SaveImage is not done when prefix is supplied by Primitive". --- web/extensions/core/widgetInputs.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 7e6688261..865af7763 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -20,7 +20,7 @@ function hideWidget(node, widget, suffix = "") { if (link == null) { return undefined; } - return widget.value; + return widget.origSerializeValue ? widget.origSerializeValue() : widget.value; }; // Hide any linked widgets, e.g. seed+randomize From 519890a5cc5c09d1eabf1fbb355863db0deae17e Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 15:57:12 +0100 Subject: [PATCH 23/83] Adds middle click for default node creation Enable two useful properties --- web/extensions/core/slotDefaults.js | 21 +++++++++++++++++++++ web/scripts/app.js | 3 +++ 2 files changed, 24 insertions(+) create mode 100644 web/extensions/core/slotDefaults.js diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js new file mode 100644 index 000000000..0b6a0a150 --- /dev/null +++ b/web/extensions/core/slotDefaults.js @@ -0,0 +1,21 @@ +import { app } from "/scripts/app.js"; + +// Adds defaults for quickly adding nodes with middle click on the input/output + +app.registerExtension({ + name: "Comfy.SlotDefaults", + init() { + LiteGraph.middle_click_slot_add_default_node = true; + LiteGraph.slot_types_default_in = { + MODEL: "CheckpointLoaderSimple", + LATENT: "EmptyLatentImage", + VAE: "VAELoader", + }; + + LiteGraph.slot_types_default_out = { + LATENT: "VAEDecode", + IMAGE: "SaveImage", + CLIP: "CLIPTextEncode", + }; + }, +}); diff --git a/web/scripts/app.js b/web/scripts/app.js index 5af6d5fc0..c216d2614 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -676,6 +676,9 @@ class ComfyApp { const canvas = (this.canvas = new LGraphCanvas(canvasEl, this.graph)); this.ctx = canvasEl.getContext("2d"); + LiteGraph.release_link_on_empty_shows_menu = true; + LiteGraph.alt_drag_do_clone_nodes = true; + this.graph.start(); function resizeCanvas() { From 8a0a85e0fa68b6e400d508b61d97621ebb9bff29 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 19:03:34 +0100 Subject: [PATCH 24/83] Added filter input to combos --- web/extensions/core/contextMenuFilter.js | 66 ++++++++++++++++++++++ web/extensions/core/invertMenuScrolling.js | 2 +- 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 web/extensions/core/contextMenuFilter.js diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js new file mode 100644 index 000000000..4867a30a6 --- /dev/null +++ b/web/extensions/core/contextMenuFilter.js @@ -0,0 +1,66 @@ +import { app } from "/scripts/app.js"; + +// Adds filtering to context menus + +const id = "Comfy.ContextMenuFilter"; +app.registerExtension({ + name: id, + init() { + const ctxMenu = LiteGraph.ContextMenu; + LiteGraph.ContextMenu = function (values, options) { + const ctx = ctxMenu.call(this, values, options); + + // If we are a dark menu (only used for combo boxes) then add a filter input + if (options?.className === "dark" && values?.length > 10) { + const filter = document.createElement("input"); + Object.assign(filter.style, { + width: "calc(100% - 10px)", + border: "0", + boxSizing: "border-box", + background: "#333", + border: "1px solid #999", + margin: "0 0 5px 5px", + color: "#fff", + }); + filter.placeholder = "Filter list"; + this.root.prepend(filter); + + filter.addEventListener("input", () => { + // Hide all items that dont match our filter + const term = filter.value.toLocaleLowerCase(); + const items = this.root.querySelectorAll(".litemenu-entry"); + for (const item of items) { + item.style.display = !term || item.textContent.toLocaleLowerCase().includes(term) ? "block" : "none"; + } + + // If we have an event then we can try and position the list under the source + if (options.event) { + let top = options.event.clientY - 10; + + const bodyRect = document.body.getBoundingClientRect(); + const rootRect = this.root.getBoundingClientRect(); + if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) { + top = Math.max(0, bodyRect.height - rootRect.height - 10); + } + + this.root.style.top = top + "px"; + } + }); + + requestAnimationFrame(() => { + // Focus the filter box when opening + filter.focus(); + + // If the top is off screen then shift the element + if (parseInt(this.root.style.top) < 0) { + this.root.style.top = 0; + } + }); + } + + return ctx; + }; + + LiteGraph.ContextMenu.prototype = ctxMenu.prototype; + }, +}); diff --git a/web/extensions/core/invertMenuScrolling.js b/web/extensions/core/invertMenuScrolling.js index 34523d55c..f900fccf4 100644 --- a/web/extensions/core/invertMenuScrolling.js +++ b/web/extensions/core/invertMenuScrolling.js @@ -3,10 +3,10 @@ import { app } from "/scripts/app.js"; // Inverts the scrolling of context menus const id = "Comfy.InvertMenuScrolling"; -const ctxMenu = LiteGraph.ContextMenu; app.registerExtension({ name: id, init() { + const ctxMenu = LiteGraph.ContextMenu; const replace = () => { LiteGraph.ContextMenu = function (values, options) { options = options || {}; From 04234152c14e58cf1e09abad442b0586c5bf2339 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 19:12:00 +0100 Subject: [PATCH 25/83] Add support for multiselect --- web/extensions/core/snapToGrid.js | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/web/extensions/core/snapToGrid.js b/web/extensions/core/snapToGrid.js index 80e836a0b..20b245e18 100644 --- a/web/extensions/core/snapToGrid.js +++ b/web/extensions/core/snapToGrid.js @@ -1,6 +1,6 @@ import { app } from "/scripts/app.js"; -// Shift + drag/resize to snap to grid +// Shift + drag/resize to snap to grid app.registerExtension({ name: "Comfy.SnapToGrid", @@ -28,32 +28,35 @@ app.registerExtension({ const r = onNodeMoved?.apply(this, arguments); if (app.shiftDown) { - node.alignToGrid(); + // Ensure all selected nodes are realigned + for (const id in this.selected_nodes) { + this.selected_nodes[id].alignToGrid(); + } } return r; }; - // When a node is added, add a resize handler to it so we can fix align the size with the grid + // When a node is added, add a resize handler to it so we can fix align the size with the grid const onNodeAdded = app.graph.onNodeAdded; app.graph.onNodeAdded = function (node) { const onResize = node.onResize; node.onResize = function () { - if(app.shiftDown) { - const w = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[0] / LiteGraph.CANVAS_GRID_SIZE); - const h = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[1] / LiteGraph.CANVAS_GRID_SIZE); - node.size[0] = w; - node.size[1] = h; - } + if (app.shiftDown) { + const w = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[0] / LiteGraph.CANVAS_GRID_SIZE); + const h = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[1] / LiteGraph.CANVAS_GRID_SIZE); + node.size[0] = w; + node.size[1] = h; + } return onResize?.apply(this, arguments); }; return onNodeAdded?.apply(this, arguments); }; - // Draw a preview of where the node will go if holding shift + // Draw a preview of where the node will go if holding shift and the node is selected const origDrawNode = LGraphCanvas.prototype.drawNode; LGraphCanvas.prototype.drawNode = function (node, ctx) { - if (app.shiftDown && node === this.node_dragged) { + if (app.shiftDown && this.node_dragged && node.id in this.selected_nodes) { const x = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[0] / LiteGraph.CANVAS_GRID_SIZE); const y = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[1] / LiteGraph.CANVAS_GRID_SIZE); From 74893be1ce6b8350d1eafea823450e9a002380e8 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 21:01:39 +0100 Subject: [PATCH 26/83] Added keyboard navigation + selection --- web/extensions/core/contextMenuFilter.js | 64 +++++++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index 4867a30a6..ced5a0a34 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -25,13 +25,73 @@ app.registerExtension({ filter.placeholder = "Filter list"; this.root.prepend(filter); + let selectedIndex = 0; + let items = this.root.querySelectorAll(".litemenu-entry"); + let itemCount = items.length; + let selectedItem; + + // Apply highlighting to the selected item + function updateSelected() { + if (selectedItem) { + selectedItem.style.setProperty("background-color", ""); + selectedItem.style.setProperty("color", ""); + } + selectedItem = items[selectedIndex]; + if (selectedItem) { + selectedItem.style.setProperty("background-color", "#ccc", "important"); + selectedItem.style.setProperty("color", "#000", "important"); + } + } + + updateSelected(); + + // Arrow up/down to select items + filter.addEventListener("keydown", (e) => { + if (e.key === "ArrowUp") { + if (selectedIndex === 0) { + selectedIndex = itemCount - 1; + } else { + selectedIndex--; + } + updateSelected(); + e.preventDefault(); + } else if (e.key === "ArrowDown") { + if (selectedIndex === itemCount - 1) { + selectedIndex = 0; + } else { + selectedIndex++; + } + updateSelected(); + e.preventDefault(); + } else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) { + selectedItem.click(); + } + }); + filter.addEventListener("input", () => { // Hide all items that dont match our filter const term = filter.value.toLocaleLowerCase(); - const items = this.root.querySelectorAll(".litemenu-entry"); + items = this.root.querySelectorAll(".litemenu-entry"); + // When filtering recompute which items are visible for arrow up/down + // Try and maintain selection + let visibleItems = []; for (const item of items) { - item.style.display = !term || item.textContent.toLocaleLowerCase().includes(term) ? "block" : "none"; + const visible = !term || item.textContent.toLocaleLowerCase().includes(term); + if (visible) { + item.style.display = "block"; + if (item === selectedItem) { + selectedIndex = visibleItems.length; + } + visibleItems.push(item); + } else { + item.style.display = "none"; + if (item === selectedItem) { + selectedIndex = 0; + } + } } + items = visibleItems; + updateSelected(); // If we have an event then we can try and position the list under the source if (options.event) { From 32fd39b4245a50757fd8257ea199f74ade348b9a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 21:02:40 +0100 Subject: [PATCH 27/83] Update comment --- web/extensions/core/contextMenuFilter.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index ced5a0a34..8aac84d5b 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -1,6 +1,6 @@ import { app } from "/scripts/app.js"; -// Adds filtering to context menus +// Adds filtering to combo context menus const id = "Comfy.ContextMenuFilter"; app.registerExtension({ From 1a322ca67a29cedd8b33da85fbae0c27a99cd24b Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 2 Apr 2023 21:37:24 +0100 Subject: [PATCH 28/83] Fix scaled position --- web/extensions/core/contextMenuFilter.js | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index 8aac84d5b..fa4cb2422 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -111,9 +111,13 @@ app.registerExtension({ // Focus the filter box when opening filter.focus(); - // If the top is off screen then shift the element - if (parseInt(this.root.style.top) < 0) { - this.root.style.top = 0; + const rect = this.root.getBoundingClientRect(); + + // If the top is off screen then shift the element with scaling applied + if (rect.top < 0) { + const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight; + const shift = (this.root.clientHeight * scale) / 2; + this.root.style.top = -shift + "px"; } }); } From 4c7a9dbcb66d3a53764d4725f92f7c116bcb4821 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 2 Apr 2023 18:44:27 -0400 Subject: [PATCH 29/83] adds Blend, Blur, Dither, Sharpen nodes --- comfy_extras/nodes_post_processing.py | 215 ++++++++++++++++++++++++++ nodes.py | 3 +- 2 files changed, 217 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_post_processing.py diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py new file mode 100644 index 000000000..3f3bddd75 --- /dev/null +++ b/comfy_extras/nodes_post_processing.py @@ -0,0 +1,215 @@ +import torch +import torch.nn.functional as F + + +class Blend: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image1": ("IMAGE",), + "image2": ("IMAGE",), + "blend_factor": ("FLOAT", { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01 + }), + "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "blend_images" + + CATEGORY = "postprocessing" + + def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + blended_image = self.blend_mode(image1, image2, blend_mode) + blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor + blended_image = torch.clamp(blended_image, 0, 1) + return (blended_image,) + + def blend_mode(self, img1, img2, mode): + if mode == "normal": + return img2 + elif mode == "multiply": + return img1 * img2 + elif mode == "screen": + return 1 - (1 - img1) * (1 - img2) + elif mode == "overlay": + return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) + elif mode == "soft_light": + return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + else: + raise ValueError(f"Unsupported blend mode: {mode}") + + def g(self, x): + return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) + +class Blur: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "blur_radius": ("INT", { + "default": 1, + "min": 1, + "max": 31, + "step": 1 + }), + "sigma": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "blur" + + CATEGORY = "postprocessing" + + def gaussian_kernel(self, kernel_size: int, sigma: float): + x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + + def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): + if blur_radius == 0: + return (image,) + + batch_size, height, width, channels = image.shape + + kernel_size = blur_radius * 2 + 1 + kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + + image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) + blurred = blurred.permute(0, 2, 3, 1) + + return (blurred,) + +class Dither: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "bits": ("INT", { + "default": 4, + "min": 1, + "max": 8, + "step": 1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "dither" + + CATEGORY = "postprocessing" + + def dither(self, image: torch.Tensor, bits: int): + batch_size, height, width, _ = image.shape + result = torch.zeros_like(image) + + for b in range(batch_size): + tensor_image = image[b] + img = (tensor_image * 255) + height, width, _ = img.shape + + scale = 255 / (2**bits - 1) + + for y in range(height): + for x in range(width): + old_pixel = img[y, x].clone() + new_pixel = torch.round(old_pixel / scale) * scale + img[y, x] = new_pixel + + quant_error = old_pixel - new_pixel + + if x + 1 < width: + img[y, x + 1] += quant_error * 7 / 16 + if y + 1 < height: + if x - 1 >= 0: + img[y + 1, x - 1] += quant_error * 3 / 16 + img[y + 1, x] += quant_error * 5 / 16 + if x + 1 < width: + img[y + 1, x + 1] += quant_error * 1 / 16 + + dithered = img / 255 + tensor = dithered.unsqueeze(0) + result[b] = tensor + + return (result,) + +class Sharpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "sharpen_radius": ("INT", { + "default": 1, + "min": 1, + "max": 31, + "step": 1 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 5.0, + "step": 0.1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "sharpen" + + CATEGORY = "postprocessing" + + def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): + if sharpen_radius == 0: + return (image,) + + batch_size, height, width, channels = image.shape + + kernel_size = sharpen_radius * 2 + 1 + kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 + center = kernel_size // 2 + kernel[center, center] = kernel_size**2 + kernel *= alpha + kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) + + tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) + sharpened = sharpened.permute(0, 2, 3, 1) + + result = torch.clamp(sharpened, 0, 1) + + return (result,) + +NODE_CLASS_MAPPINGS = { + "Blend": Blend, + "Blur": Blur, + "Dither": Dither, + "Sharpen": Sharpen, +} diff --git a/nodes.py b/nodes.py index 963ff32a0..a93f04108 100644 --- a/nodes.py +++ b/nodes.py @@ -1112,4 +1112,5 @@ def load_custom_nodes(): def init_custom_nodes(): load_custom_nodes() - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) \ No newline at end of file + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) From 028e1f7ad2a50efea8391ea54b606cf865d788db Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 3 Apr 2023 08:11:44 +0100 Subject: [PATCH 30/83] Fix scaled position when filtering Add esc to close --- web/extensions/core/contextMenuFilter.js | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index fa4cb2422..51e66f924 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -43,6 +43,17 @@ app.registerExtension({ } } + const positionList = () => { + const rect = this.root.getBoundingClientRect(); + + // If the top is off screen then shift the element with scaling applied + if (rect.top < 0) { + const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight; + const shift = (this.root.clientHeight * scale) / 2; + this.root.style.top = -shift + "px"; + } + } + updateSelected(); // Arrow up/down to select items @@ -65,6 +76,8 @@ app.registerExtension({ e.preventDefault(); } else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) { selectedItem.click(); + } else if(e.key === "Escape") { + this.close(); } }); @@ -104,6 +117,7 @@ app.registerExtension({ } this.root.style.top = top + "px"; + positionList(); } }); @@ -111,14 +125,7 @@ app.registerExtension({ // Focus the filter box when opening filter.focus(); - const rect = this.root.getBoundingClientRect(); - - // If the top is off screen then shift the element with scaling applied - if (rect.top < 0) { - const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight; - const shift = (this.root.clientHeight * scale) / 2; - this.root.style.top = -shift + "px"; - } + positionList(); }); } From fa2febc0624678362cc758d316bb59afce9c8f06 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 3 Apr 2023 09:52:04 -0400 Subject: [PATCH 31/83] blend supports any size, dither -> quantize --- comfy_extras/nodes_post_processing.py | 74 ++++++++++++++++----------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 3f3bddd75..322f3ca89 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -1,5 +1,7 @@ +import numpy as np import torch import torch.nn.functional as F +from PIL import Image class Blend: @@ -28,6 +30,9 @@ class Blend: CATEGORY = "postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + if image1.shape != image2.shape: + image2 = self.crop_and_resize(image2, image1.shape) + blended_image = self.blend_mode(image1, image2, blend_mode) blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor blended_image = torch.clamp(blended_image, 0, 1) @@ -50,6 +55,29 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) + def crop_and_resize(self, img: torch.Tensor, target_shape: tuple): + batch_size, img_h, img_w, img_c = img.shape + _, target_h, target_w, _ = target_shape + img_aspect_ratio = img_w / img_h + target_aspect_ratio = target_w / target_h + + # Crop center of the image to the target aspect ratio + if img_aspect_ratio > target_aspect_ratio: + new_width = int(img_h * target_aspect_ratio) + left = (img_w - new_width) // 2 + img = img[:, :, left:left + new_width, :] + else: + new_height = int(img_w / target_aspect_ratio) + top = (img_h - new_height) // 2 + img = img[:, top:top + new_height, :, :] + + # Resize to target size + img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False) + img = img.permute(0, 2, 3, 1) + + return img + class Blur: def __init__(self): pass @@ -100,7 +128,7 @@ class Blur: return (blurred,) -class Dither: +class Quantize: def __init__(self): pass @@ -109,51 +137,37 @@ class Dither: return { "required": { "image": ("IMAGE",), - "bits": ("INT", { - "default": 4, + "colors": ("INT", { + "default": 256, "min": 1, - "max": 8, + "max": 256, "step": 1 }), + "dither": (["none", "floyd-steinberg"],), }, } RETURN_TYPES = ("IMAGE",) - FUNCTION = "dither" + FUNCTION = "quantize" CATEGORY = "postprocessing" - def dither(self, image: torch.Tensor, bits: int): + def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): batch_size, height, width, _ = image.shape result = torch.zeros_like(image) + dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE + for b in range(batch_size): tensor_image = image[b] - img = (tensor_image * 255) - height, width, _ = img.shape + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') - scale = 255 / (2**bits - 1) + palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 + quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option) - for y in range(height): - for x in range(width): - old_pixel = img[y, x].clone() - new_pixel = torch.round(old_pixel / scale) * scale - img[y, x] = new_pixel - - quant_error = old_pixel - new_pixel - - if x + 1 < width: - img[y, x + 1] += quant_error * 7 / 16 - if y + 1 < height: - if x - 1 >= 0: - img[y + 1, x - 1] += quant_error * 3 / 16 - img[y + 1, x] += quant_error * 5 / 16 - if x + 1 < width: - img[y + 1, x + 1] += quant_error * 1 / 16 - - dithered = img / 255 - tensor = dithered.unsqueeze(0) - result[b] = tensor + quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 + result[b] = quantized_array return (result,) @@ -210,6 +224,6 @@ class Sharpen: NODE_CLASS_MAPPINGS = { "Blend": Blend, "Blur": Blur, - "Dither": Dither, + "Quantize": Quantize, "Sharpen": Sharpen, } From f50b1fec695cecc8f7c87ce1f39db3f6b49bb3a1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 3 Apr 2023 13:50:29 -0400 Subject: [PATCH 32/83] Add noise augmentation setting to unCLIPConditioning. --- comfy/samplers.py | 16 +++++++++++++--- nodes.py | 5 +++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index ddec99007..59dbab53d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -348,17 +348,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device): if 'adm' in x[1]: adm_inputs = [] weights = [] + noise_aug = [] adm_in = x[1]["adm"] for adm_c in adm_in: adm_cond = adm_c[0].image_embeds weight = adm_c[1] - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device)) + noise_augment = adm_c[2] + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight weights.append(weight) + noise_aug.append(noise_augment) adm_inputs.append(adm_out) - adm_out = torch.stack(adm_inputs).sum(0) - #TODO: Apply Noise to Embedding Mix + if len(noise_aug) > 1: + adm_out = torch.stack(adm_inputs).sum(0) + #TODO: add a way to control this + noise_augment = 0.05 + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + print(noise_level) + c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) else: adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) x[1] = x[1].copy() diff --git a/nodes.py b/nodes.py index 963ff32a0..ffbba9f94 100644 --- a/nodes.py +++ b/nodes.py @@ -445,17 +445,18 @@ class unCLIPConditioning: return {"required": {"conditioning": ("CONDITIONING", ), "clip_vision_output": ("CLIP_VISION_OUTPUT", ), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_adm" CATEGORY = "_for_testing/unclip" - def apply_adm(self, conditioning, clip_vision_output, strength): + def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): c = [] for t in conditioning: o = t[1].copy() - x = (clip_vision_output, strength) + x = (clip_vision_output, strength, noise_augmentation) if "adm" in o: o["adm"] = o["adm"][:] + [x] else: From 4e437582365cd1b9f8261be9405180feef382357 Mon Sep 17 00:00:00 2001 From: omar92 Date: Mon, 3 Apr 2023 21:27:43 +0200 Subject: [PATCH 33/83] fix bug in reroute node , that didnt allow to load old worflows --- web/extensions/core/rerouteNode.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index 1342cae92..aee5d147c 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -50,12 +50,12 @@ app.registerExtension({ } else { // Move the previous node - currentNode = node; + currentNode = node; } } else { // We've found the end inputNode = currentNode; - inputType = node.outputs[link.origin_slot].type; + inputType = node.outputs[link.origin_slot]?.type ?? null; break; } } else { @@ -87,7 +87,7 @@ app.registerExtension({ updateNodes.push(node); } else { // We've found an output - const nodeOutType = node.inputs[link.target_slot].type; + const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null; if (inputType && nodeOutType !== inputType) { // The output doesnt match our input so disconnect it node.disconnectInput(link.target_slot); From 539ff487a81f4ed4f51ca9ece57756b573e52190 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 3 Apr 2023 15:49:28 -0400 Subject: [PATCH 34/83] Pull latest tomesd code from upstream. --- comfy/ldm/modules/tomesd.py | 69 ++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 21 deletions(-) diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py index 1eafcd0aa..6a13b80c9 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy/ldm/modules/tomesd.py @@ -1,4 +1,4 @@ - +#Taken from: https://github.com/dbolya/tomesd import torch from typing import Tuple, Callable @@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None): return x +def mps_gather_workaround(input, dim, index): + if input.shape[-1] == 1: + return torch.gather( + input.unsqueeze(-1), + dim - 1 if dim < 0 else dim, + index.unsqueeze(-1) + ).squeeze(-1) + else: + return torch.gather(input, dim, index) + + def bipartite_soft_matching_random2d(metric: torch.Tensor, w: int, h: int, sx: int, sy: int, r: int, no_rand: bool = False) -> Tuple[Callable, Callable]: """ Partitions the tokens into src and dst and merges r tokens from src to dst. Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. - Args: - metric [B, N, C]: metric to use for similarity - w: image width in tokens @@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, if r <= 0: return do_nothing, do_nothing + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather with torch.no_grad(): hsy, wsx = h // sy, w // sx # For each sy by sx kernel, randomly assign one token to be dst and the rest src - idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device) - if no_rand: - rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64) + rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64) else: - rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device) + rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device) - idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype)) - idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1) - rand_idx = idx_buffer.argsort(dim=1) + # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead + idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64) + idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)) + idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx) - num_dst = int((1 / (sx*sy)) * N) + # Image is not divisible by sx or sy so we need to move it into a new buffer + if (hsy * sy) < h or (wsx * sx) < w: + idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64) + idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view + else: + idx_buffer = idx_buffer_view + + # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices + rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1) + + # We're finished with these + del idx_buffer, idx_buffer_view + + # rand_idx is currently dst|src, so split them + num_dst = hsy * wsx a_idx = rand_idx[:, num_dst:, :] # src b_idx = rand_idx[:, :num_dst, :] # dst def split(x): C = x.shape[-1] - src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C)) - dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C)) + src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C)) + dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) return src, dst + # Cosine similarity between A and B metric = metric / metric.norm(dim=-1, keepdim=True) a, b = split(metric) scores = a @ b.transpose(-1, -2) @@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, # Can't reduce more than the # tokens in src r = min(a.shape[1], r) + # Find the most similar greedily node_max, node_idx = scores.max(dim=-1) edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] unm_idx = edge_idx[..., r:, :] # Unmerged Tokens src_idx = edge_idx[..., :r, :] # Merged Tokens - dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: src, dst = split(x) n, t1, c = src.shape - unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c)) - src = src.gather(dim=-2, index=src_idx.expand(n, r, c)) + unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = gather(src, dim=-2, index=src_idx.expand(n, r, c)) dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) return torch.cat([unm, dst], dim=1) @@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] _, _, c = unm.shape - src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c)) + src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c)) # Combine back to the original shape out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) - out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) - out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src) return out @@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, def get_functions(x, ratio, original_shape): b, c, original_h, original_w = original_shape original_tokens = original_h * original_w - downsample = int(math.sqrt(original_tokens // x.shape[1])) + downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) stride_x = 2 stride_y = 2 max_downsample = 1 if downsample <= max_downsample: - w = original_w // downsample - h = original_h // downsample + w = int(math.ceil(original_w / downsample)) + h = int(math.ceil(original_h / downsample)) r = int(x.shape[1] * ratio) no_rand = False m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) From ca2ae98470fdebd951cdb750998b82ecb532c901 Mon Sep 17 00:00:00 2001 From: omar92 Date: Mon, 3 Apr 2023 22:01:18 +0200 Subject: [PATCH 35/83] check if workflowNode And widgets_values are defined as they were causing errors on QueuePrompt after loading workFlow --- web/extensions/core/dynamicPrompts.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/dynamicPrompts.js b/web/extensions/core/dynamicPrompts.js index 8528201d3..7dae07f4d 100644 --- a/web/extensions/core/dynamicPrompts.js +++ b/web/extensions/core/dynamicPrompts.js @@ -30,7 +30,8 @@ app.registerExtension({ } // Overwrite the value in the serialized workflow pnginfo - workflowNode.widgets_values[widgetIndex] = prompt; + if (workflowNode?.widgets_values) + workflowNode.widgets_values[widgetIndex] = prompt; return prompt; }; From dc24d7e2fd2967de6d3bc45971bfcb0274724f8b Mon Sep 17 00:00:00 2001 From: mligaintart <> Date: Mon, 3 Apr 2023 16:46:00 -0400 Subject: [PATCH 36/83] Adds orientation settings to reroute nodes, allowing for cleaner graphes. --- web/extensions/core/rerouteNode.js | 32 ++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index aee5d147c..c31f63cd0 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -11,11 +11,14 @@ app.registerExtension({ this.properties = {}; } this.properties.showOutputText = RerouteNode.defaultVisibility; + this.properties.horizontal = false; this.addInput("", "*"); this.addOutput(this.properties.showOutputText ? "*" : "", "*"); this.onConnectionsChange = function (type, index, connected, link_info) { + this.applyOrientation(); + // Prevent multiple connections to different types when we have no input if (connected && type === LiteGraph.OUTPUT) { // Ignore wildcard nodes as these will be updated to real types @@ -49,8 +52,8 @@ app.registerExtension({ currentNode = null; } else { - // Move the previous node - currentNode = node; + // Move the previous node + currentNode = node; } } else { // We've found the end @@ -112,6 +115,7 @@ app.registerExtension({ node.__outputType = displayType; node.outputs[0].name = node.properties.showOutputText ? displayType : ""; node.size = node.computeSize(); + node.applyOrientation(); for (const l of node.outputs[0].links || []) { const link = app.graph.links[l]; @@ -153,6 +157,7 @@ app.registerExtension({ this.outputs[0].name = ""; } this.size = this.computeSize(); + this.applyOrientation(); app.graph.setDirtyCanvas(true, true); }, }, @@ -161,9 +166,32 @@ app.registerExtension({ callback: () => { RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility); }, + }, + { + // naming is inverted with respect to LiteGraphNode.horizontal + // LiteGraphNode.horizontal == true means that + // each slot in the inputs and outputs are layed out horizontally, + // which is the opposite of the visual orientation of the inputs and outputs as a node + content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"), + callback: () => { + this.properties.horizontal = !this.properties.horizontal; + this.applyOrientation(); + }, } ); } + applyOrientation() { + this.horizontal = this.properties.horizontal; + if (this.horizontal) { + // we correct the input position, because LiteGraphNode.horizontal + // doesn't account for title presence + // which reroute nodes don't have + this.inputs[0].pos = [this.size[0] / 2, 0]; + } else { + delete this.inputs[0].pos; + } + app.graph.setDirtyCanvas(true, true); + } computeSize() { return [ From c02baed00fe8ea910d6def31d98308c4a92ae16a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 3 Apr 2023 20:13:43 -0400 Subject: [PATCH 37/83] Add that unCLIP models are supported to the README. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 84e0061ff..0f7d24c45 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models. - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/) +- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) - Starts up very fast. - Works fully offline: will never download anything. - [Config file](extra_model_paths.yaml.example) to set the search paths for models. From 23524ad8c5027e5691d749b6ae778106c469f16a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 3 Apr 2023 22:58:54 -0400 Subject: [PATCH 38/83] Remove print. --- comfy/samplers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 59dbab53d..93f5d361b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -366,7 +366,6 @@ def encode_adm(noise_augmentor, conds, batch_size, device): #TODO: add a way to control this noise_augment = 0.05 noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - print(noise_level) c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) adm_out = torch.cat((c_adm, noise_level_emb), 1) else: From 5036fecbddcd4b7108d196a9c88d91c4480f390f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Apr 2023 03:42:40 -0400 Subject: [PATCH 39/83] Update colab notebook. --- notebooks/comfyui_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index a86ccc753..3e59fbde7 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -47,7 +47,7 @@ " !git pull\n", "\n", "!echo -= Install dependencies =-\n", - "!pip install xformers==0.0.16 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117" + "!pip install xformers -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118" ] }, { From 56196ab0f72c8f671bd85b425744f80f02c823ea Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 4 Apr 2023 10:57:34 -0400 Subject: [PATCH 40/83] use common_upcale in blend --- comfy_extras/nodes_post_processing.py | 29 +++++---------------------- 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 322f3ca89..703deaabf 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -3,6 +3,8 @@ import torch import torch.nn.functional as F from PIL import Image +import comfy.utils + class Blend: def __init__(self): @@ -31,7 +33,9 @@ class Blend: def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): if image1.shape != image2.shape: - image2 = self.crop_and_resize(image2, image1.shape) + image2 = image2.permute(0, 3, 1, 2) + image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') + image2 = image2.permute(0, 2, 3, 1) blended_image = self.blend_mode(image1, image2, blend_mode) blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor @@ -55,29 +59,6 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) - def crop_and_resize(self, img: torch.Tensor, target_shape: tuple): - batch_size, img_h, img_w, img_c = img.shape - _, target_h, target_w, _ = target_shape - img_aspect_ratio = img_w / img_h - target_aspect_ratio = target_w / target_h - - # Crop center of the image to the target aspect ratio - if img_aspect_ratio > target_aspect_ratio: - new_width = int(img_h * target_aspect_ratio) - left = (img_w - new_width) // 2 - img = img[:, :, left:left + new_width, :] - else: - new_height = int(img_w / target_aspect_ratio) - top = (img_h - new_height) // 2 - img = img[:, top:top + new_height, :, :] - - # Resize to target size - img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False) - img = img.permute(0, 2, 3, 1) - - return img - class Blur: def __init__(self): pass From 1718730e80549c35ce3c5d3fb7926ce5654a2fdd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Apr 2023 11:49:29 -0400 Subject: [PATCH 41/83] Ignore embeddings when sizes don't match and print a WARNING. --- comfy/sd1_clip.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 93036b1ae..4f51657c3 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -74,9 +74,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): if isinstance(y, int): tokens_temp += [y] else: - embedding_weights += [y] - tokens_temp += [next_new_token] - next_new_token += 1 + if y.shape[0] == current_embeds.weight.shape[1]: + embedding_weights += [y] + tokens_temp += [next_new_token] + next_new_token += 1 + else: + print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) out_tokens += [tokens_temp] if len(embedding_weights) > 0: From 080c758cda19288039de6941876dbdf6f3f9d357 Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Tue, 4 Apr 2023 18:16:23 +0200 Subject: [PATCH 42/83] Ask for confirmation before clearing nodes --- web/scripts/ui.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 679f10b20..68bfc792a 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -510,6 +510,7 @@ export class ComfyUI { $el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Clear", onclick: () => { + if (!confirm("Are you sure you want to remove all nodes?")) return; app.clean(); app.graph.clear(); }}), From af291e6f69a66bce6460de58e6e9328f48640dd5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Apr 2023 13:56:13 -0400 Subject: [PATCH 43/83] Convert line endings to unix. --- comfy_extras/nodes_post_processing.py | 420 +++++++++++++------------- 1 file changed, 210 insertions(+), 210 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 703deaabf..de9ef0838 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -1,210 +1,210 @@ -import numpy as np -import torch -import torch.nn.functional as F -from PIL import Image - -import comfy.utils - - -class Blend: - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "image2": ("IMAGE",), - "blend_factor": ("FLOAT", { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01 - }), - "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "blend_images" - - CATEGORY = "postprocessing" - - def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): - if image1.shape != image2.shape: - image2 = image2.permute(0, 3, 1, 2) - image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') - image2 = image2.permute(0, 2, 3, 1) - - blended_image = self.blend_mode(image1, image2, blend_mode) - blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor - blended_image = torch.clamp(blended_image, 0, 1) - return (blended_image,) - - def blend_mode(self, img1, img2, mode): - if mode == "normal": - return img2 - elif mode == "multiply": - return img1 * img2 - elif mode == "screen": - return 1 - (1 - img1) * (1 - img2) - elif mode == "overlay": - return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) - elif mode == "soft_light": - return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) - else: - raise ValueError(f"Unsupported blend mode: {mode}") - - def g(self, x): - return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) - -class Blur: - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "blur_radius": ("INT", { - "default": 1, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.1 - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "blur" - - CATEGORY = "postprocessing" - - def gaussian_kernel(self, kernel_size: int, sigma: float): - x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") - d = torch.sqrt(x * x + y * y) - g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) - return g / g.sum() - - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): - if blur_radius == 0: - return (image,) - - batch_size, height, width, channels = image.shape - - kernel_size = blur_radius * 2 + 1 - kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) - - image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) - blurred = blurred.permute(0, 2, 3, 1) - - return (blurred,) - -class Quantize: - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "colors": ("INT", { - "default": 256, - "min": 1, - "max": 256, - "step": 1 - }), - "dither": (["none", "floyd-steinberg"],), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "quantize" - - CATEGORY = "postprocessing" - - def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): - batch_size, height, width, _ = image.shape - result = torch.zeros_like(image) - - dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE - - for b in range(batch_size): - tensor_image = image[b] - img = (tensor_image * 255).to(torch.uint8).numpy() - pil_image = Image.fromarray(img, mode='RGB') - - palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 - quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option) - - quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 - result[b] = quantized_array - - return (result,) - -class Sharpen: - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "sharpen_radius": ("INT", { - "default": 1, - "min": 1, - "max": 31, - "step": 1 - }), - "alpha": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 5.0, - "step": 0.1 - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "sharpen" - - CATEGORY = "postprocessing" - - def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): - if sharpen_radius == 0: - return (image,) - - batch_size, height, width, channels = image.shape - - kernel_size = sharpen_radius * 2 + 1 - kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 - center = kernel_size // 2 - kernel[center, center] = kernel_size**2 - kernel *= alpha - kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) - - tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) - sharpened = sharpened.permute(0, 2, 3, 1) - - result = torch.clamp(sharpened, 0, 1) - - return (result,) - -NODE_CLASS_MAPPINGS = { - "Blend": Blend, - "Blur": Blur, - "Quantize": Quantize, - "Sharpen": Sharpen, -} +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + +import comfy.utils + + +class Blend: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image1": ("IMAGE",), + "image2": ("IMAGE",), + "blend_factor": ("FLOAT", { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01 + }), + "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "blend_images" + + CATEGORY = "postprocessing" + + def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + if image1.shape != image2.shape: + image2 = image2.permute(0, 3, 1, 2) + image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') + image2 = image2.permute(0, 2, 3, 1) + + blended_image = self.blend_mode(image1, image2, blend_mode) + blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor + blended_image = torch.clamp(blended_image, 0, 1) + return (blended_image,) + + def blend_mode(self, img1, img2, mode): + if mode == "normal": + return img2 + elif mode == "multiply": + return img1 * img2 + elif mode == "screen": + return 1 - (1 - img1) * (1 - img2) + elif mode == "overlay": + return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) + elif mode == "soft_light": + return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + else: + raise ValueError(f"Unsupported blend mode: {mode}") + + def g(self, x): + return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) + +class Blur: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "blur_radius": ("INT", { + "default": 1, + "min": 1, + "max": 31, + "step": 1 + }), + "sigma": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "blur" + + CATEGORY = "postprocessing" + + def gaussian_kernel(self, kernel_size: int, sigma: float): + x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + + def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): + if blur_radius == 0: + return (image,) + + batch_size, height, width, channels = image.shape + + kernel_size = blur_radius * 2 + 1 + kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + + image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) + blurred = blurred.permute(0, 2, 3, 1) + + return (blurred,) + +class Quantize: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "colors": ("INT", { + "default": 256, + "min": 1, + "max": 256, + "step": 1 + }), + "dither": (["none", "floyd-steinberg"],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "quantize" + + CATEGORY = "postprocessing" + + def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): + batch_size, height, width, _ = image.shape + result = torch.zeros_like(image) + + dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE + + for b in range(batch_size): + tensor_image = image[b] + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') + + palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 + quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option) + + quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 + result[b] = quantized_array + + return (result,) + +class Sharpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "sharpen_radius": ("INT", { + "default": 1, + "min": 1, + "max": 31, + "step": 1 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 5.0, + "step": 0.1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "sharpen" + + CATEGORY = "postprocessing" + + def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): + if sharpen_radius == 0: + return (image,) + + batch_size, height, width, channels = image.shape + + kernel_size = sharpen_radius * 2 + 1 + kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 + center = kernel_size // 2 + kernel[center, center] = kernel_size**2 + kernel *= alpha + kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) + + tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) + sharpened = sharpened.permute(0, 2, 3, 1) + + result = torch.clamp(sharpened, 0, 1) + + return (result,) + +NODE_CLASS_MAPPINGS = { + "Blend": Blend, + "Blur": Blur, + "Quantize": Quantize, + "Sharpen": Sharpen, +} From de3d5f46ce0544339884fe454a59b342fcf28cf3 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 4 Apr 2023 19:32:49 -0600 Subject: [PATCH 44/83] Fix .graphdialog style --- web/style.css | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/web/style.css b/web/style.css index 9162bbba9..393d1667e 100644 --- a/web/style.css +++ b/web/style.css @@ -237,3 +237,28 @@ button.comfy-queue-btn { visibility:hidden } } + +.graphdialog { + min-height: 1em; +} + +.graphdialog .name { + font-size: 14px; + font-family: sans-serif; + color: #999999; +} + +.graphdialog button { + margin-top: unset; + vertical-align: unset; + height: 1.6em; + padding-right: 8px; +} + +.graphdialog input, .graphdialog textarea, .graphdialog select { + background-color: #222; + border: 2px solid; + border-color: #444444; + color: #ddd; + border-radius: 12px 0 0 12px; +} From bf7dbe4702ccfd02f92862238a8da3b6addc656b Mon Sep 17 00:00:00 2001 From: Adam Schwalm Date: Mon, 3 Apr 2023 20:05:46 -0500 Subject: [PATCH 45/83] Add left/right/escape hotkeys for image nodes --- web/scripts/app.js | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 501c7ea65..1ecd4610f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -102,6 +102,46 @@ class ComfyApp { }; } + #addNodeKeyHandler(node) { + const app = this; + const origNodeOnKeyDown = node.prototype.onKeyDown; + + node.prototype.onKeyDown = function(e) { + if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) { + return false; + } + + if (this.flags.collapsed || !this.imgs || this.imageIndex === null) { + return; + } + + let handled = false; + + if (e.key === "ArrowLeft" || e.key === "ArrowRight") { + if (e.key === "ArrowLeft") { + this.imageIndex -= 1; + } else if (e.key === "ArrowRight") { + this.imageIndex += 1; + } + this.imageIndex %= this.imgs.length; + + if (this.imageIndex < 0) { + this.imageIndex = this.imgs.length + this.imageIndex; + } + handled = true; + } else if (e.key === "Escape") { + this.imageIndex = null; + handled = true; + } + + if (handled === true) { + e.preventDefault(); + e.stopImmediatePropagation(); + return false; + } + } + } + /** * Adds Custom drawing logic for nodes * e.g. Draws images and handles thumbnail navigation on nodes that output images @@ -785,6 +825,7 @@ class ComfyApp { this.#addNodeContextMenuHandler(node); this.#addDrawBackgroundHandler(node, app); + this.#addNodeKeyHandler(node); await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData); LiteGraph.registerNodeType(nodeId, node); From e46b1c3034a23eeb048e279d0d285737d39a4b1a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Apr 2023 22:22:02 -0400 Subject: [PATCH 46/83] Disable xformers in VAE when xformers == 0.0.18 --- comfy/ldm/modules/diffusionmodules/model.py | 4 ++-- comfy/model_management.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 94f5510b9..788a6fc4f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -9,7 +9,7 @@ from typing import Optional, Any from ldm.modules.attention import MemoryEfficientCrossAttention import model_management -if model_management.xformers_enabled(): +if model_management.xformers_enabled_vae(): import xformers import xformers.ops @@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' - if model_management.xformers_enabled() and attn_type == "vanilla": + if model_management.xformers_enabled_vae() and attn_type == "vanilla": attn_type = "vanilla-xformers" if model_management.pytorch_attention_enabled() and attn_type == "vanilla": attn_type = "vanilla-pytorch" diff --git a/comfy/model_management.py b/comfy/model_management.py index 4aa47ff16..052dfb775 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -199,11 +199,25 @@ def get_autocast_device(dev): return dev.type return "cuda" + def xformers_enabled(): if vram_state == CPU: return False return XFORMERS_IS_AVAILBLE + +def xformers_enabled_vae(): + enabled = xformers_enabled() + if not enabled: + return False + try: + #0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above) + if xformers.version.__version__ == "0.0.18": + return False + except: + pass + return enabled + def pytorch_attention_enabled(): return ENABLE_PYTORCH_ATTENTION From 10ad4c1d17d8ea469565d904a8f47f1d2eeee459 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Apr 2023 22:48:11 -0400 Subject: [PATCH 47/83] Move unclip stuff out of _for_testing --- nodes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 28beb37b1..935e28b84 100644 --- a/nodes.py +++ b/nodes.py @@ -197,7 +197,7 @@ class CheckpointLoader: RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "advanced/loaders" def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): config_path = folder_paths.get_full_path("configs", config_name) @@ -227,7 +227,7 @@ class unCLIPCheckpointLoader: RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") FUNCTION = "load_checkpoint" - CATEGORY = "_for_testing/unclip" + CATEGORY = "loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) @@ -450,7 +450,7 @@ class unCLIPConditioning: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_adm" - CATEGORY = "_for_testing/unclip" + CATEGORY = "conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): c = [] @@ -1038,7 +1038,6 @@ class ImagePadForOutpaint: NODE_CLASS_MAPPINGS = { "KSampler": KSampler, - "CheckpointLoader": CheckpointLoader, "CheckpointLoaderSimple": CheckpointLoaderSimple, "CLIPTextEncode": CLIPTextEncode, "CLIPSetLastLayer": CLIPSetLastLayer, @@ -1077,6 +1076,7 @@ NODE_CLASS_MAPPINGS = { "VAEEncodeTiled": VAEEncodeTiled, "TomePatchModel": TomePatchModel, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, + "CheckpointLoader": CheckpointLoader, } def load_custom_node(module_path): From 871a76b77b9cafa8615da1cedeaafc1b10cf85e3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Apr 2023 22:54:33 -0400 Subject: [PATCH 48/83] Rename and reorganize post processing nodes. --- comfy_extras/nodes_post_processing.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index de9ef0838..ba699e2b8 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -29,7 +29,7 @@ class Blend: RETURN_TYPES = ("IMAGE",) FUNCTION = "blend_images" - CATEGORY = "postprocessing" + CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): if image1.shape != image2.shape: @@ -86,7 +86,7 @@ class Blur: RETURN_TYPES = ("IMAGE",) FUNCTION = "blur" - CATEGORY = "postprocessing" + CATEGORY = "image/postprocessing" def gaussian_kernel(self, kernel_size: int, sigma: float): x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") @@ -131,7 +131,7 @@ class Quantize: RETURN_TYPES = ("IMAGE",) FUNCTION = "quantize" - CATEGORY = "postprocessing" + CATEGORY = "image/postprocessing" def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): batch_size, height, width, _ = image.shape @@ -179,7 +179,7 @@ class Sharpen: RETURN_TYPES = ("IMAGE",) FUNCTION = "sharpen" - CATEGORY = "postprocessing" + CATEGORY = "image/postprocessing" def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): if sharpen_radius == 0: @@ -203,8 +203,8 @@ class Sharpen: return (result,) NODE_CLASS_MAPPINGS = { - "Blend": Blend, - "Blur": Blur, - "Quantize": Quantize, - "Sharpen": Sharpen, + "ImageBlend": Blend, + "ImageBlur": Blur, + "ImageQuantize": Quantize, + "ImageSharpen": Sharpen, } From 1b556ea9f43c4ead1235dfefd7d84e193667f6ef Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 4 Apr 2023 21:20:49 -0600 Subject: [PATCH 49/83] Add confirmation for clearing canvas --- web/scripts/ui.js | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 679f10b20..3f6308f24 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -115,6 +115,13 @@ function dragElement(dragEl, settings) { savePos = value; }, }); + + settings.addSetting({ + id: "Comfy.ConfirmClear", + name: "Require confirmation when clearing workflow", + type: "boolean", + defaultValue: false, + }); function dragMouseDown(e) { e = e || window.event; @@ -510,10 +517,16 @@ export class ComfyUI { $el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Clear", onclick: () => { - app.clean(); - app.graph.clear(); + if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Clear workflow?")) { + app.clean(); + app.graph.clear(); + } + }}), + $el("button", { textContent: "Load Default", onclick: () => { + if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Load default workflow?")) { + app.loadGraphData() + } }}), - $el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }), ]); dragElement(this.menuContainer, this.settings); From 30f274bf48419d98f646211df30fe9e074a28a66 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 4 Apr 2023 21:53:02 -0600 Subject: [PATCH 50/83] Make the default true --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 3f6308f24..91821fac0 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -120,7 +120,7 @@ function dragElement(dragEl, settings) { id: "Comfy.ConfirmClear", name: "Require confirmation when clearing workflow", type: "boolean", - defaultValue: false, + defaultValue: true, }); function dragMouseDown(e) { From a595c56872309e310fed7bb877bcd7caee8ef563 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 4 Apr 2023 22:03:22 -0600 Subject: [PATCH 51/83] Remove menu drag handle --- web/scripts/ui.js | 3 +-- web/style.css | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 68bfc792a..df0d8b4a3 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -414,8 +414,7 @@ export class ComfyUI { }); this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ - $el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [ - $el("span.drag-handle"), + $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [ $el("span", { $: (q) => (this.queueSize = q) }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), ]), diff --git a/web/style.css b/web/style.css index 393d1667e..1263c6648 100644 --- a/web/style.css +++ b/web/style.css @@ -105,7 +105,7 @@ body { background-color: #353535; font-family: sans-serif; padding: 10px; - border-radius: 0 8px 8px 8px; + border-radius: 8px; box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } From 8af2fe1e8747e142e133640659187136eb330d0f Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 4 Apr 2023 22:10:45 -0600 Subject: [PATCH 52/83] Remove redundant lines --- web/style.css | 3 --- 1 file changed, 3 deletions(-) diff --git a/web/style.css b/web/style.css index 1263c6648..c04b40ec4 100644 --- a/web/style.css +++ b/web/style.css @@ -88,13 +88,10 @@ body { } .comfy-menu { - width: 200px; font-size: 15px; position: absolute; top: 50%; right: 0%; - background-color: white; - color: #000; text-align: center; z-index: 100; width: 170px; From 623afa2ced69085d7996921a0d312968a448109b Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 5 Apr 2023 10:51:04 +0100 Subject: [PATCH 53/83] Made accessing setting value easier Updated clear check to use this --- web/scripts/ui.js | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 91821fac0..aea1a94b8 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -115,14 +115,6 @@ function dragElement(dragEl, settings) { savePos = value; }, }); - - settings.addSetting({ - id: "Comfy.ConfirmClear", - name: "Require confirmation when clearing workflow", - type: "boolean", - defaultValue: true, - }); - function dragMouseDown(e) { e = e || window.event; e.preventDefault(); @@ -289,6 +281,16 @@ class ComfySettingsDialog extends ComfyDialog { return element; }, }); + + const self = this; + return { + get value() { + return self.getSettingValue(id); + }, + set value(v) { + self.setSettingValue(id, value); + }, + }; } show() { @@ -410,6 +412,13 @@ export class ComfyUI { this.history.update(); }); + const confirmClear = this.settings.addSetting({ + id: "Comfy.ConfirmClear", + name: "Require confirmation when clearing workflow", + type: "boolean", + defaultValue: true, + }); + const fileInput = $el("input", { type: "file", accept: ".json,image/png", @@ -517,13 +526,13 @@ export class ComfyUI { $el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Clear", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Clear workflow?")) { + if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); app.graph.clear(); } }}), $el("button", { textContent: "Load Default", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Load default workflow?")) { + if (!confirmClear.value || confirm("Load default workflow?")) { app.loadGraphData() } }}), From db16932be5eec5446fbae898ca1365bfae58d90a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 5 Apr 2023 10:52:35 +0100 Subject: [PATCH 54/83] Fix setting --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index aea1a94b8..9952606d4 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -288,7 +288,7 @@ class ComfySettingsDialog extends ComfyDialog { return self.getSettingValue(id); }, set value(v) { - self.setSettingValue(id, value); + self.setSettingValue(id, v); }, }; } From 1030ab0d8fd91e5c1167a087397047603102f069 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 5 Apr 2023 11:02:34 +0100 Subject: [PATCH 55/83] Reload setting value --- web/scripts/ui.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 9952606d4..b6b8e06b2 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -225,6 +225,7 @@ class ComfySettingsDialog extends ComfyDialog { }; let element; + value = this.getSettingValue(id, defaultValue); if (typeof type === "function") { element = type(name, setter, value, attrs); @@ -418,7 +419,7 @@ export class ComfyUI { type: "boolean", defaultValue: true, }); - + const fileInput = $el("input", { type: "file", accept: ".json,image/png", From 37713e3b0acfc576f4eafc0b47582374ab5987dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Wed, 5 Apr 2023 21:22:14 +0800 Subject: [PATCH 56/83] Add basic XPU device support closed #387 --- comfy/model_management.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 052dfb775..f0b8be55e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -5,6 +5,7 @@ LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 MPS = 5 +XPU = 6 accelerate_enabled = False vram_state = NORMAL_VRAM @@ -85,10 +86,17 @@ try: except: pass +try: + import intel_extension_for_pytorch + if torch.xpu.is_available(): + vram_state = XPU +except: + pass + if forced_cpu: vram_state = CPU -print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) +print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state]) current_loaded_model = None @@ -141,6 +149,9 @@ def load_model_gpu(model): mps_device = torch.device("mps") real_model.to(mps_device) pass + elif vram_state == XPU: + real_model.to("xpu") + pass elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: model_accelerated = False real_model.cuda() @@ -189,6 +200,8 @@ def unload_if_low_vram(model): def get_torch_device(): if vram_state == MPS: return torch.device("mps") + if vram_state == XPU: + return torch.device("xpu") if vram_state == CPU: return torch.device("cpu") else: @@ -228,6 +241,9 @@ def get_free_memory(dev=None, torch_free_too=False): if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total + elif hasattr(dev, 'type') and (dev.type == 'xpu'): + mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) + mem_free_torch = mem_free_total else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] @@ -258,8 +274,12 @@ def mps_mode(): global vram_state return vram_state == MPS +def xpu_mode(): + global vram_state + return vram_state == XPU + def should_use_fp16(): - if cpu_mode() or mps_mode(): + if cpu_mode() or mps_mode() or xpu_mode(): return False #TODO ? if torch.cuda.is_bf16_supported(): From 1ced2bdd2da9a13caf72d7bff36d7f645f443fc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Wed, 5 Apr 2023 21:25:37 +0800 Subject: [PATCH 57/83] Specify safetensors version to avoid upstream errors https://github.com/huggingface/safetensors/issues/142 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3b4040a29..0527b31df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ torchsde einops open-clip-torch transformers>=4.25.1 -safetensors +safetensors>=0.3.0 pytorch_lightning aiohttp accelerate From 3536a7c8d148f738d30a375eab859c74da91a25a Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 5 Apr 2023 08:57:44 -0600 Subject: [PATCH 58/83] Put drag icon back --- web/scripts/ui.js | 1 + web/style.css | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index df0d8b4a3..621ca70ee 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -415,6 +415,7 @@ export class ComfyUI { this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [ + $el("span.drag-handle"), $el("span", { $: (q) => (this.queueSize = q) }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), ]), diff --git a/web/style.css b/web/style.css index c04b40ec4..f2dd4e956 100644 --- a/web/style.css +++ b/web/style.css @@ -102,7 +102,7 @@ body { background-color: #353535; font-family: sans-serif; padding: 10px; - border-radius: 8px; + border-radius: 0 8px 8px 8px; box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } From f816964847d557d2ec94cf52531c43f91751cc28 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Apr 2023 14:01:01 -0400 Subject: [PATCH 59/83] Add a way to set output directory with --output-directory --- folder_paths.py | 34 ++++++++++++++++++++++++++++++++++ main.py | 9 +++++++++ nodes.py | 29 ++++++++++++++--------------- server.py | 6 +++--- 4 files changed, 60 insertions(+), 18 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index af56a6da1..f13e4895f 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -27,6 +27,40 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")] folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") +temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") +input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + +if not os.path.exists(input_directory): + os.makedirs(input_directory) + +def set_output_directory(output_dir): + global output_directory + output_directory = output_dir + +def get_output_directory(): + global output_directory + return output_directory + +def get_temp_directory(): + global temp_directory + return temp_directory + +def get_input_directory(): + global input_directory + return input_directory + + +#NOTE: used in http server so don't put folders that should not be accessed remotely +def get_directory_by_type(type_name): + if type_name == "output": + return get_output_directory() + if type_name == "temp": + return get_temp_directory() + if type_name == "input": + return get_input_directory() + return None + def add_model_folder_path(folder_name, full_folder_path): global folder_names_and_paths diff --git a/main.py b/main.py index fbfaf6be5..a3549b86f 100644 --- a/main.py +++ b/main.py @@ -17,6 +17,7 @@ if __name__ == "__main__": print("\t--port 8188\t\t\tSet the listen port.") print() print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.") + print("\t--output-directory path/to/output\tSet the ComfyUI output directory.") print() print() print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") @@ -134,6 +135,14 @@ if __name__ == "__main__": for i in indices: load_extra_path_config(sys.argv[i]) + try: + output_dir = sys.argv[sys.argv.index('--output-directory') + 1] + output_dir = os.path.abspath(output_dir) + print("setting output directory to:", output_dir) + folder_paths.set_output_directory(output_dir) + except: + pass + port = 8188 try: p_index = sys.argv.index('--port') diff --git a/nodes.py b/nodes.py index 935e28b84..187d54a11 100644 --- a/nodes.py +++ b/nodes.py @@ -777,7 +777,7 @@ class KSamplerAdvanced: class SaveImage: def __init__(self): - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") + self.output_dir = folder_paths.get_output_directory() self.type = "output" @classmethod @@ -829,9 +829,6 @@ class SaveImage: os.makedirs(full_output_folder, exist_ok=True) counter = 1 - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) - results = list() for image in images: i = 255. * image.cpu().numpy() @@ -856,7 +853,7 @@ class SaveImage: class PreviewImage(SaveImage): def __init__(self): - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") + self.output_dir = folder_paths.get_temp_directory() self.type = "temp" @classmethod @@ -867,13 +864,11 @@ class PreviewImage(SaveImage): } class LoadImage: - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") @classmethod def INPUT_TYPES(s): - if not os.path.exists(s.input_dir): - os.makedirs(s.input_dir) + input_dir = folder_paths.get_input_directory() return {"required": - {"image": (sorted(os.listdir(s.input_dir)), )}, + {"image": (sorted(os.listdir(input_dir)), )}, } CATEGORY = "image" @@ -881,7 +876,8 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): - image_path = os.path.join(self.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -895,18 +891,19 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - image_path = os.path.join(s.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() class LoadImageMask: - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") @classmethod def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() return {"required": - {"image": (sorted(os.listdir(s.input_dir)), ), + {"image": (sorted(os.listdir(input_dir)), ), "channel": (["alpha", "red", "green", "blue"], ),} } @@ -915,7 +912,8 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - image_path = os.path.join(self.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) i = Image.open(image_path) mask = None c = channel[0].upper() @@ -930,7 +928,8 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - image_path = os.path.join(s.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) diff --git a/server.py b/server.py index 963daefff..840d9a4e7 100644 --- a/server.py +++ b/server.py @@ -89,7 +89,7 @@ class PromptServer(): @routes.post("/upload/image") async def upload_image(request): - upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + upload_dir = folder_paths.get_input_directory() if not os.path.exists(upload_dir): os.makedirs(upload_dir) @@ -122,10 +122,10 @@ class PromptServer(): async def view_image(request): if "filename" in request.rel_url.query: type = request.rel_url.query.get("type", "output") - if type not in ["output", "input", "temp"]: + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) - output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type) if "subfolder" in request.rel_url.query: full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: From 5456b7555c6cc40a302ac9404603bfdf9c08f95c Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 5 Apr 2023 19:58:06 +0100 Subject: [PATCH 60/83] Add missing defaultValue arg --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index b6b8e06b2..3af29ba73 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -286,7 +286,7 @@ class ComfySettingsDialog extends ComfyDialog { const self = this; return { get value() { - return self.getSettingValue(id); + return self.getSettingValue(id, defaultValue); }, set value(v) { self.setSettingValue(id, v); From 1a74611c6e725f1ffb6629d08fbd04bb658f2704 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 5 Apr 2023 15:56:41 -0600 Subject: [PATCH 61/83] Style modals to match rest of UI --- web/scripts/ui.js | 32 +++++++++++++-------- web/style.css | 71 +++++++++++++++++++++++------------------------ 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 91821fac0..4ef24e007 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -115,14 +115,6 @@ function dragElement(dragEl, settings) { savePos = value; }, }); - - settings.addSetting({ - id: "Comfy.ConfirmClear", - name: "Require confirmation when clearing workflow", - type: "boolean", - defaultValue: true, - }); - function dragMouseDown(e) { e = e || window.event; e.preventDefault(); @@ -170,7 +162,7 @@ class ComfyDialog { $el("p", { $: (p) => (this.textElement = p) }), $el("button", { type: "button", - textContent: "CLOSE", + textContent: "Close", onclick: () => this.close(), }), ]), @@ -233,6 +225,7 @@ class ComfySettingsDialog extends ComfyDialog { }; let element; + value = this.getSettingValue(id, defaultValue); if (typeof type === "function") { element = type(name, setter, value, attrs); @@ -289,6 +282,16 @@ class ComfySettingsDialog extends ComfyDialog { return element; }, }); + + const self = this; + return { + get value() { + return self.getSettingValue(id, defaultValue); + }, + set value(v) { + self.setSettingValue(id, v); + }, + }; } show() { @@ -410,6 +413,13 @@ export class ComfyUI { this.history.update(); }); + const confirmClear = this.settings.addSetting({ + id: "Comfy.ConfirmClear", + name: "Require confirmation when clearing workflow", + type: "boolean", + defaultValue: true, + }); + const fileInput = $el("input", { type: "file", accept: ".json,image/png", @@ -517,13 +527,13 @@ export class ComfyUI { $el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Clear", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Clear workflow?")) { + if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); app.graph.clear(); } }}), $el("button", { textContent: "Load Default", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Load default workflow?")) { + if (!confirmClear.value || confirm("Load default workflow?")) { app.loadGraphData() } }}), diff --git a/web/style.css b/web/style.css index 393d1667e..d347bd454 100644 --- a/web/style.css +++ b/web/style.css @@ -39,18 +39,19 @@ body { position: fixed; /* Stay in place */ z-index: 100; /* Sit on top */ padding: 30px 30px 10px 30px; - background-color: #ff0000; /* Modal background */ + background-color: #353535; /* Modal background */ + color: #ff4444; box-shadow: 0px 0px 20px #888888; border-radius: 10px; - text-align: center; top: 50%; left: 50%; max-width: 80vw; max-height: 80vh; transform: translate(-50%, -50%); overflow: hidden; - min-width: 60%; justify-content: center; + font-family: monospace; + font-size: 15px; } .comfy-modal-content { @@ -70,23 +71,6 @@ body { margin: 3px 3px 3px 4px; } -.comfy-modal button { - cursor: pointer; - color: #aaaaaa; - border: none; - background-color: transparent; - font-size: 24px; - font-weight: bold; - width: 100%; -} - -.comfy-modal button:hover, -.comfy-modal button:focus { - color: #000; - text-decoration: none; - cursor: pointer; -} - .comfy-menu { width: 200px; font-size: 15px; @@ -109,7 +93,8 @@ body { box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } -.comfy-menu button { +.comfy-menu button, +.comfy-modal button { font-size: 20px; } @@ -130,7 +115,8 @@ body { .comfy-menu > button, .comfy-menu-btns button, -.comfy-menu .comfy-list button { +.comfy-menu .comfy-list button, +.comfy-modal button{ color: #ddd; background-color: #222; border-radius: 8px; @@ -220,11 +206,22 @@ button.comfy-queue-btn { } .comfy-modal.comfy-settings { - background-color: var(--bg-color); - color: var(--fg-color); + text-align: center; + font-family: sans-serif; + color: #999; z-index: 99; } +.comfy-modal input, +.comfy-modal select { + color: #ddd; + background-color: #222; + border-radius: 8px; + border-color: #4e4e4e; + border-style: solid; + font-size: inherit; +} + @media only screen and (max-height: 850px) { .comfy-menu { top: 0 !important; @@ -239,26 +236,26 @@ button.comfy-queue-btn { } .graphdialog { - min-height: 1em; + min-height: 1em; } .graphdialog .name { - font-size: 14px; - font-family: sans-serif; - color: #999999; + font-size: 14px; + font-family: sans-serif; + color: #999999; } .graphdialog button { - margin-top: unset; - vertical-align: unset; - height: 1.6em; - padding-right: 8px; + margin-top: unset; + vertical-align: unset; + height: 1.6em; + padding-right: 8px; } .graphdialog input, .graphdialog textarea, .graphdialog select { - background-color: #222; - border: 2px solid; - border-color: #444444; - color: #ddd; - border-radius: 12px 0 0 12px; + background-color: #222; + border: 2px solid; + border-color: #444444; + color: #ddd; + border-radius: 12px 0 0 12px; } From dd29966f8a2973529ea50de2ef3d0e7c72b5114e Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 5 Apr 2023 20:32:59 -0400 Subject: [PATCH 62/83] changes main.py to use argparse --- main.py | 118 ++++++++++++++++++++++---------------------------------- 1 file changed, 47 insertions(+), 71 deletions(-) diff --git a/main.py b/main.py index a3549b86f..20c8a49e8 100644 --- a/main.py +++ b/main.py @@ -1,57 +1,54 @@ -import os -import sys -import shutil - -import threading +import argparse import asyncio +import os +import shutil +import sys +import threading if os.name == "nt": import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": - if '--help' in sys.argv: - print() - print("Valid Command line Arguments:") - print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.") - print("\t--port 8188\t\t\tSet the listen port.") - print() - print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.") - print("\t--output-directory path/to/output\tSet the ComfyUI output directory.") - print() - print() - print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") - print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.") - print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.") - print("\t--disable-xformers\t\tdisables xformers") - print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.") - print() - print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n") - print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.") - print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.") - print("\t--novram\t\t\tWhen lowvram isn't enough.") - print() - print("\t--cpu\t\t\tTo use the CPU for everything (slow).") - exit() + parser = argparse.ArgumentParser(description="Script Arguments") - if '--dont-upcast-attention' in sys.argv: + parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 0.0.0.0 if none given so the UI can be accessed from other computers.") + parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") + parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") + parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") + parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") + parser.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") + parser.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") + parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") + parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") + parser.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") + parser.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") + parser.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") + parser.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") + parser.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") + parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") + parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") + parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.") + + args = parser.parse_args() + + if args.dont_upcast_attention: print("disabling upcasting of attention") os.environ['ATTN_PRECISION'] = "fp16" - try: - index = sys.argv.index('--cuda-device') - device = sys.argv[index + 1] - os.environ['CUDA_VISIBLE_DEVICES'] = device - print("Set cuda device to:", device) - except: - pass + if args.cuda_device is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + print("Set cuda device to:", args.cuda_device) + -from nodes import init_custom_nodes -import execution -import server -import folder_paths import yaml +import execution +import folder_paths +import server +from nodes import init_custom_nodes + + def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: @@ -110,51 +107,30 @@ if __name__ == "__main__": hijack_progress(server) threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() - try: - address = '0.0.0.0' - p_index = sys.argv.index('--listen') - try: - ip = sys.argv[p_index + 1] - if ip[:2] != '--': - address = ip - except: - pass - except: - address = '127.0.0.1' - dont_print = False - if '--dont-print-server' in sys.argv: - dont_print = True + address = args.listen + + dont_print = args.dont_print_server extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): load_extra_path_config(extra_model_paths_config_path) - if '--extra-model-paths-config' in sys.argv: - indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config'] - for i in indices: - load_extra_path_config(sys.argv[i]) + if args.extra_model_paths_config: + load_extra_path_config(args.extra_model_paths_config) - try: - output_dir = sys.argv[sys.argv.index('--output-directory') + 1] - output_dir = os.path.abspath(output_dir) + if args.output_directory: + output_dir = os.path.abspath(args.output_directory) print("setting output directory to:", output_dir) folder_paths.set_output_directory(output_dir) - except: - pass - port = 8188 - try: - p_index = sys.argv.index('--port') - port = int(sys.argv[p_index + 1]) - except: - pass + port = args.port - if '--quick-test-for-ci' in sys.argv: + if args.quick_test_for_ci: exit(0) call_on_start = None - if "--windows-standalone-build" in sys.argv: + if args.windows_standalone_build: def startup_server(address, port): import webbrowser webbrowser.open("http://{}:{}".format(address, port)) From e5e587b1c0c5dc728d65b3e84592445cdb5e6e9b Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 5 Apr 2023 23:41:23 -0400 Subject: [PATCH 63/83] seperates out arg parser and imports args --- comfy/cli_args.py | 29 +++++++++ comfy/ldm/modules/attention.py | 5 +- comfy/model_management.py | 111 ++++++++++++++++----------------- main.py | 27 +------- 4 files changed, 88 insertions(+), 84 deletions(-) create mode 100644 comfy/cli_args.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py new file mode 100644 index 000000000..6a56e315c --- /dev/null +++ b/comfy/cli_args.py @@ -0,0 +1,29 @@ +import argparse + +parser = argparse.ArgumentParser() + +parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 127.0.0.1 if none given so the UI can be accessed from other computers.") +parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") +parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") +parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") + +attn_group = parser.add_mutually_exclusive_group() +attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") +attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") + +parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") +parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") + +vram_group = parser.add_mutually_exclusive_group() +vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") +vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") +vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") +vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") +vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") + +parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") +parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") +parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.") + +args = parser.parse_args() diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 07553627c..92b3eca7c 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -21,6 +21,8 @@ if model_management.xformers_enabled(): import os _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") +from cli_args import args + def exists(val): return val is not None @@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module): return self.to_out(out) -import sys if model_management.xformers_enabled(): print("Using xformers cross attention") CrossAttention = MemoryEfficientCrossAttention @@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled(): print("Using pytorch cross attention") CrossAttention = CrossAttentionPytorch else: - if "--use-split-cross-attention" in sys.argv: + if args.use_split_cross_attention: print("Using split optimization for cross attention") CrossAttention = CrossAttentionDoggettx else: diff --git a/comfy/model_management.py b/comfy/model_management.py index 052dfb775..7dda073dc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,36 +1,35 @@ +import psutil +from enum import Enum +from cli_args import args -CPU = 0 -NO_VRAM = 1 -LOW_VRAM = 2 -NORMAL_VRAM = 3 -HIGH_VRAM = 4 -MPS = 5 +class VRAMState(Enum): + CPU = 0 + NO_VRAM = 1 + LOW_VRAM = 2 + NORMAL_VRAM = 3 + HIGH_VRAM = 4 + MPS = 5 -accelerate_enabled = False -vram_state = NORMAL_VRAM +# Determine VRAM State +vram_state = VRAMState.NORMAL_VRAM +set_vram_to = VRAMState.NORMAL_VRAM total_vram = 0 total_vram_available_mb = -1 -import sys -import psutil - -forced_cpu = "--cpu" in sys.argv - -set_vram_to = NORMAL_VRAM +accelerate_enabled = False try: import torch total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) - forced_normal_vram = "--normalvram" in sys.argv - if not forced_normal_vram and not forced_cpu: + if not args.normalvram and not args.cpu: if total_vram <= 4096: print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = LOW_VRAM + set_vram_to = VRAMState.LOW_VRAM elif total_vram > total_ram * 1.1 and total_vram > 14336: print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = HIGH_VRAM + vram_state = VRAMState.HIGH_VRAM except: pass @@ -39,34 +38,32 @@ try: except: OOM_EXCEPTION = Exception -if "--disable-xformers" in sys.argv: - XFORMERS_IS_AVAILBLE = False +if args.disable_xformers: + XFORMERS_IS_AVAILABLE = False else: try: import xformers import xformers.ops - XFORMERS_IS_AVAILBLE = True + XFORMERS_IS_AVAILABLE = True except: - XFORMERS_IS_AVAILBLE = False + XFORMERS_IS_AVAILABLE = False -ENABLE_PYTORCH_ATTENTION = False -if "--use-pytorch-cross-attention" in sys.argv: +ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention +if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) - ENABLE_PYTORCH_ATTENTION = True - XFORMERS_IS_AVAILBLE = False + XFORMERS_IS_AVAILABLE = False + +if args.lowvram: + set_vram_to = VRAMState.LOW_VRAM +elif args.novram: + set_vram_to = VRAMState.NO_VRAM +elif args.highvram: + vram_state = VRAMState.HIGH_VRAM -if "--lowvram" in sys.argv: - set_vram_to = LOW_VRAM -if "--novram" in sys.argv: - set_vram_to = NO_VRAM -if "--highvram" in sys.argv: - vram_state = HIGH_VRAM - - -if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: +if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): try: import accelerate accelerate_enabled = True @@ -81,14 +78,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: try: if torch.backends.mps.is_available(): - vram_state = MPS + vram_state = VRAMState.MPS except: pass -if forced_cpu: - vram_state = CPU +if args.cpu: + vram_state = VRAMState.CPU -print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) +print(f"Set vram state to: {vram_state.name}") current_loaded_model = None @@ -109,12 +106,12 @@ def unload_model(): model_accelerated = False #never unload models from GPU on high vram - if vram_state != HIGH_VRAM: + if vram_state != VRAMState.HIGH_VRAM: current_loaded_model.model.cpu() current_loaded_model.unpatch_model() current_loaded_model = None - if vram_state != HIGH_VRAM: + if vram_state != VRAMState.HIGH_VRAM: if len(current_gpu_controlnets) > 0: for n in current_gpu_controlnets: n.cpu() @@ -135,19 +132,19 @@ def load_model_gpu(model): model.unpatch_model() raise e current_loaded_model = model - if vram_state == CPU: + if vram_state == VRAMState.CPU: pass - elif vram_state == MPS: + elif vram_state == VRAMState.MPS: mps_device = torch.device("mps") real_model.to(mps_device) pass - elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: + elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: model_accelerated = False real_model.cuda() else: - if vram_state == NO_VRAM: + if vram_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_state == LOW_VRAM: + elif vram_state == VRAMState.LOW_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") @@ -157,10 +154,10 @@ def load_model_gpu(model): def load_controlnet_gpu(models): global current_gpu_controlnets global vram_state - if vram_state == CPU: + if vram_state == VRAMState.CPU: return - if vram_state == LOW_VRAM or vram_state == NO_VRAM: + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return @@ -176,20 +173,20 @@ def load_controlnet_gpu(models): def load_if_low_vram(model): global vram_state - if vram_state == LOW_VRAM or vram_state == NO_VRAM: + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: return model.cuda() return model def unload_if_low_vram(model): global vram_state - if vram_state == LOW_VRAM or vram_state == NO_VRAM: + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: return model.cpu() return model def get_torch_device(): - if vram_state == MPS: + if vram_state == VRAMState.MPS: return torch.device("mps") - if vram_state == CPU: + if vram_state == VRAMState.CPU: return torch.device("cpu") else: return torch.cuda.current_device() @@ -201,9 +198,9 @@ def get_autocast_device(dev): def xformers_enabled(): - if vram_state == CPU: + if vram_state == VRAMState.CPU: return False - return XFORMERS_IS_AVAILBLE + return XFORMERS_IS_AVAILABLE def xformers_enabled_vae(): @@ -243,7 +240,7 @@ def get_free_memory(dev=None, torch_free_too=False): def maximum_batch_area(): global vram_state - if vram_state == NO_VRAM: + if vram_state == VRAMState.NO_VRAM: return 0 memory_free = get_free_memory() / (1024 * 1024) @@ -252,11 +249,11 @@ def maximum_batch_area(): def cpu_mode(): global vram_state - return vram_state == CPU + return vram_state == VRAMState.CPU def mps_mode(): global vram_state - return vram_state == MPS + return vram_state == VRAMState.MPS def should_use_fp16(): if cpu_mode() or mps_mode(): diff --git a/main.py b/main.py index 20c8a49e8..51a48fc6d 100644 --- a/main.py +++ b/main.py @@ -1,37 +1,14 @@ -import argparse import asyncio import os import shutil -import sys import threading +from comfy.cli_args import args if os.name == "nt": import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Script Arguments") - - parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 0.0.0.0 if none given so the UI can be accessed from other computers.") - parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") - parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") - parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") - parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") - parser.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") - parser.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") - parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") - parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") - parser.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") - parser.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") - parser.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") - parser.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") - parser.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") - parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") - parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") - parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.") - - args = parser.parse_args() - if args.dont_upcast_attention: print("disabling upcasting of attention") os.environ['ATTN_PRECISION'] = "fp16" @@ -121,7 +98,7 @@ if __name__ == "__main__": if args.output_directory: output_dir = os.path.abspath(args.output_directory) - print("setting output directory to:", output_dir) + print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) port = args.port From 84b9c0ac2ff49b5b18b8e7804f8fe42a379a0787 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Thu, 6 Apr 2023 12:27:22 +0800 Subject: [PATCH 64/83] Import intel_extension_for_pytorch as ipex --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index f0b8be55e..379cc18d7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -87,7 +87,7 @@ except: pass try: - import intel_extension_for_pytorch + import intel_extension_for_pytorch as ipex if torch.xpu.is_available(): vram_state = XPU except: From 7cb924f68469cd2481b2313f8e5fc02587279bf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Thu, 6 Apr 2023 14:24:47 +0800 Subject: [PATCH 65/83] Use separate variables instead of `vram_state` --- comfy/model_management.py | 70 +++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 379cc18d7..a84167746 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -5,9 +5,9 @@ LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 MPS = 5 -XPU = 6 accelerate_enabled = False +xpu_available = False vram_state = NORMAL_VRAM total_vram = 0 @@ -22,7 +22,12 @@ set_vram_to = NORMAL_VRAM try: import torch - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True + total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) + else: + total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) forced_normal_vram = "--normalvram" in sys.argv if not forced_normal_vram and not forced_cpu: @@ -86,17 +91,10 @@ try: except: pass -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - vram_state = XPU -except: - pass - if forced_cpu: vram_state = CPU -print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state]) +print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) current_loaded_model = None @@ -133,6 +131,7 @@ def load_model_gpu(model): global current_loaded_model global vram_state global model_accelerated + global xpu_available if model is current_loaded_model: return @@ -149,19 +148,19 @@ def load_model_gpu(model): mps_device = torch.device("mps") real_model.to(mps_device) pass - elif vram_state == XPU: - real_model.to("xpu") - pass elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: model_accelerated = False - real_model.cuda() + if xpu_available: + real_model.to("xpu") + else: + real_model.cuda() else: if vram_state == NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) elif vram_state == LOW_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") + accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda") model_accelerated = True return current_loaded_model @@ -187,8 +186,12 @@ def load_controlnet_gpu(models): def load_if_low_vram(model): global vram_state + global xpu_available if vram_state == LOW_VRAM or vram_state == NO_VRAM: - return model.cuda() + if xpu_available: + return model.to("xpu") + else: + return model.cuda() return model def unload_if_low_vram(model): @@ -198,14 +201,16 @@ def unload_if_low_vram(model): return model def get_torch_device(): + global xpu_available if vram_state == MPS: return torch.device("mps") - if vram_state == XPU: - return torch.device("xpu") if vram_state == CPU: return torch.device("cpu") else: - return torch.cuda.current_device() + if xpu_available: + return torch.device("xpu") + else: + return torch.cuda.current_device() def get_autocast_device(dev): if hasattr(dev, 'type'): @@ -235,22 +240,24 @@ def pytorch_attention_enabled(): return ENABLE_PYTORCH_ATTENTION def get_free_memory(dev=None, torch_free_too=False): + global xpu_available if dev is None: dev = get_torch_device() if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total - elif hasattr(dev, 'type') and (dev.type == 'xpu'): - mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) - mem_free_torch = mem_free_total else: - stats = torch.cuda.memory_stats(dev) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(dev) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + if xpu_available: + mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) + mem_free_torch = mem_free_total + else: + stats = torch.cuda.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch if torch_free_too: return (mem_free_total, mem_free_torch) @@ -274,12 +281,9 @@ def mps_mode(): global vram_state return vram_state == MPS -def xpu_mode(): - global vram_state - return vram_state == XPU - def should_use_fp16(): - if cpu_mode() or mps_mode() or xpu_mode(): + global xpu_available + if cpu_mode() or mps_mode() or xpu_available: return False #TODO ? if torch.cuda.is_bf16_supported(): From 60127a83040b3b243457980d04f3bb25c4491978 Mon Sep 17 00:00:00 2001 From: sALTaccount Date: Wed, 5 Apr 2023 23:57:31 -0700 Subject: [PATCH 66/83] diffusers loader --- comfy/diffusers_convert.py | 364 +++++++++++++++++++++ models/diffusers/put_diffusers_models_here | 0 nodes.py | 19 +- 3 files changed, 382 insertions(+), 1 deletion(-) create mode 100644 comfy/diffusers_convert.py create mode 100644 models/diffusers/put_diffusers_models_here diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py new file mode 100644 index 000000000..a31c1c11b --- /dev/null +++ b/comfy/diffusers_convert.py @@ -0,0 +1,364 @@ +import json +import os +import yaml + +# because of local import nonsense +import sys +sys.path.append(os.path.dirname(os.path.realpath(__file__))) + +import folder_paths +from comfy.ldm.util import instantiate_from_config +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +import os.path as osp +import re +import torch +from safetensors.torch import load_file, save_file + +# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py + +# =================# +# UNet Conversion # +# =================# + +unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), +] + +unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), +] + +unet_conversion_map_layer = [] +# hardcoded number of downblocks and resnets/attentions... +# would need smarter logic for other networks. +for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3 * i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + +hf_mid_atn_prefix = "mid_block.attentions.0." +sd_mid_atn_prefix = "middle_block.1." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2 * j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +def convert_unet_state_dict(unet_state_dict): + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + +vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), +] + +for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3 - i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3 - i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + +# this part accounts for mid blocks in both the encoder and the decoder +for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i + 1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + +vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), +] + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + return new_state_dict + + +# =========================# +# Text Encoder Conversion # +# =========================# + + +textenc_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + +# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp +code2idx = {"q": 0, "k": 1, "v": 2} + + +def convert_text_enc_state_dict_v20(text_enc_dict): + new_state_dict = {} + capture_qkv_weight = {} + capture_qkv_bias = {} + for k, v in text_enc_dict.items(): + if ( + k.endswith(".self_attn.q_proj.weight") + or k.endswith(".self_attn.k_proj.weight") + or k.endswith(".self_attn.v_proj.weight") + ): + k_pre = k[: -len(".q_proj.weight")] + k_code = k[-len("q_proj.weight")] + if k_pre not in capture_qkv_weight: + capture_qkv_weight[k_pre] = [None, None, None] + capture_qkv_weight[k_pre][code2idx[k_code]] = v + continue + + if ( + k.endswith(".self_attn.q_proj.bias") + or k.endswith(".self_attn.k_proj.bias") + or k.endswith(".self_attn.v_proj.bias") + ): + k_pre = k[: -len(".q_proj.bias")] + k_code = k[-len("q_proj.bias")] + if k_pre not in capture_qkv_bias: + capture_qkv_bias[k_pre] = [None, None, None] + capture_qkv_bias[k_pre][code2idx[k_code]] = v + continue + + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v + + for k_pre, tensors in capture_qkv_weight.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + + for k_pre, tensors in capture_qkv_bias.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + + return new_state_dict + + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + + +def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): + diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) + diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) + + # magic + v2 = diffusers_unet_conf["sample_size"] == 96 + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + + if v2: + if v_pred: + config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) + + model_config_params = config['model']['params'] + clip_config = model_config_params['cond_stage_config'] + scale_factor = model_config_params['scale_factor'] + vae_config = model_config_params['first_stage_config'] + vae_config['scale_factor'] = scale_factor + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict(unet_state_dict) + unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + + clip = None + vae = None + + class WeightsLoader(torch.nn.Module): + pass + + w = WeightsLoader() + load_state_dict_to = [] + if output_vae: + vae = VAE(scale_factor=scale_factor, config=vae_config) + w.first_stage_model = vae.first_stage_model + load_state_dict_to = [w] + + if output_clip: + clip = CLIP(config=clip_config, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_state_dict_to = [w] + + model = instantiate_from_config(config["model"]) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + + if fp16: + model = model.half() + + return ModelPatcher(model), clip, vae diff --git a/models/diffusers/put_diffusers_models_here b/models/diffusers/put_diffusers_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 187d54a11..776bc3819 100644 --- a/nodes.py +++ b/nodes.py @@ -4,13 +4,14 @@ import os import sys import json import hashlib -import copy import traceback from PIL import Image from PIL.PngImagePlugin import PngInfo import numpy as np +from comfy.diffusers_convert import load_diffusers + sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -219,6 +220,21 @@ class CheckpointLoaderSimple: out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out +class DiffusersLoader: + @classmethod + def INPUT_TYPES(cls): + return {"required": {"model_path": (os.listdir(os.path.join(folder_paths.models_dir, 'diffusers'), ),), + }} + RETURN_TYPES = ("MODEL", "CLIP", "VAE") + FUNCTION = "load_checkpoint" + + CATEGORY = "loaders" + + def load_checkpoint(self, model_path, output_vae=True, output_clip=True): + model_path = os.path.join(folder_paths.models_dir, 'diffusers', model_path) + return load_diffusers(model_path, fp16=True, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + + class unCLIPCheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -1076,6 +1092,7 @@ NODE_CLASS_MAPPINGS = { "TomePatchModel": TomePatchModel, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "CheckpointLoader": CheckpointLoader, + "DiffusersLoader": DiffusersLoader, } def load_custom_node(module_path): From c418d988ba59b3114770a0fa111d301f04880fca Mon Sep 17 00:00:00 2001 From: sALTaccount Date: Wed, 5 Apr 2023 23:59:03 -0700 Subject: [PATCH 67/83] update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0f7d24c45..90931141d 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) - Works even if you don't have a GPU with: ```--cpu``` (slow) -- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models. +- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. - Embeddings/Textual inversion - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) - Loading full workflows (with seeds) from generated PNG files. From 3d16077e3806b0817b1d43dc14f61e5dee5495c8 Mon Sep 17 00:00:00 2001 From: sALTaccount Date: Thu, 6 Apr 2023 00:24:52 -0700 Subject: [PATCH 68/83] empty list if diffusers directory doesn't exist --- nodes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 776bc3819..1af62887d 100644 --- a/nodes.py +++ b/nodes.py @@ -223,8 +223,11 @@ class CheckpointLoaderSimple: class DiffusersLoader: @classmethod def INPUT_TYPES(cls): - return {"required": {"model_path": (os.listdir(os.path.join(folder_paths.models_dir, 'diffusers'), ),), - }} + paths = [] + search_path = os.path.join(folder_paths.models_dir, 'diffusers') + if os.path.exists(search_path): + paths = next(os.walk(search_path))[1] + return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" From 42fd67b5cb0de9bd1228af7a93dec08b2f1486c3 Mon Sep 17 00:00:00 2001 From: sALTaccount Date: Thu, 6 Apr 2023 00:28:06 -0700 Subject: [PATCH 69/83] use precision determined by model management --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 1af62887d..8271da04c 100644 --- a/nodes.py +++ b/nodes.py @@ -235,7 +235,7 @@ class DiffusersLoader: def load_checkpoint(self, model_path, output_vae=True, output_clip=True): model_path = os.path.join(folder_paths.models_dir, 'diffusers', model_path) - return load_diffusers(model_path, fp16=True, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: From 3e2608e12b312fd5d2396d4146d992cd4f8b9ab4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Thu, 6 Apr 2023 15:44:05 +0800 Subject: [PATCH 70/83] Fix auto lowvram detection on CUDA --- comfy/model_management.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a84167746..b0123b5fc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -22,11 +22,12 @@ set_vram_to = NORMAL_VRAM try: import torch - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - else: + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True + total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) + except: total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) forced_normal_vram = "--normalvram" in sys.argv From 01c1fc669fb8cd41f627dad871257acbaaf24b47 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 6 Apr 2023 13:19:00 -0400 Subject: [PATCH 71/83] set listen flag to listen on all if specifed --- comfy/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 6a56e315c..a27dc7a7f 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -2,7 +2,7 @@ import argparse parser = argparse.ArgumentParser() -parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 127.0.0.1 if none given so the UI can be accessed from other computers.") +parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") From 7d62d89f9325348179fc9b0db146ff50fa7c808c Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 5 Apr 2023 13:08:08 -0400 Subject: [PATCH 72/83] add cors middleware --- server.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 840d9a4e7..005bf9b2c 100644 --- a/server.py +++ b/server.py @@ -27,6 +27,19 @@ async def cache_control(request: web.Request, handler): response.headers.setdefault('Cache-Control', 'no-cache') return response +@web.middleware +async def cors_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + response.headers['Access-Control-Allow-Origin'] = '*' + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' + response.headers['Access-Control-Allow-Credentials'] = 'true' + return response + class PromptServer(): def __init__(self, loop): PromptServer.instance = self @@ -37,7 +50,7 @@ class PromptServer(): self.loop = loop self.messages = asyncio.Queue() self.number = 0 - self.app = web.Application(client_max_size=20971520, middlewares=[cache_control]) + self.app = web.Application(client_max_size=20971520, middlewares=[cache_control, cors_middleware]) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") From 48efae16084b423166f9a1930b989489169d22cf Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 6 Apr 2023 15:06:22 -0400 Subject: [PATCH 73/83] makes cors a cli parameter --- comfy/cli_args.py | 3 ++- server.py | 36 +++++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a27dc7a7f..5133e0ae5 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -4,8 +4,10 @@ parser = argparse.ArgumentParser() parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") +parser.add_argument("--cors", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") attn_group = parser.add_mutually_exclusive_group() @@ -13,7 +15,6 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") -parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") diff --git a/server.py b/server.py index 005bf9b2c..a9c0b4599 100644 --- a/server.py +++ b/server.py @@ -18,6 +18,7 @@ except ImportError: sys.exit() import mimetypes +from comfy.cli_args import args @web.middleware @@ -27,18 +28,22 @@ async def cache_control(request: web.Request, handler): response.headers.setdefault('Cache-Control', 'no-cache') return response -@web.middleware -async def cors_middleware(request: web.Request, handler): - if request.method == "OPTIONS": - # Pre-flight request. Reply successfully: - response = web.Response() - else: - response = await handler(request) - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' - response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' - response.headers['Access-Control-Allow-Credentials'] = 'true' - return response +def create_cors_middleware(allowed_origin: str): + @web.middleware + async def cors_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + + response.headers['Access-Control-Allow-Origin'] = allowed_origin + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' + response.headers['Access-Control-Allow-Credentials'] = 'true' + return response + + return cors_middleware class PromptServer(): def __init__(self, loop): @@ -50,7 +55,12 @@ class PromptServer(): self.loop = loop self.messages = asyncio.Queue() self.number = 0 - self.app = web.Application(client_max_size=20971520, middlewares=[cache_control, cors_middleware]) + + middlewares = [cache_control] + if args.cors: + middlewares.append(create_cors_middleware(args.cors)) + + self.app = web.Application(client_max_size=20971520, middlewares=middlewares) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") From f84f2508cc45a014cc27e023e9623db0450d237e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Apr 2023 15:24:55 -0400 Subject: [PATCH 74/83] Rename the cors parameter to something more verbose. --- comfy/cli_args.py | 2 +- server.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 5133e0ae5..f2960ae31 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -4,7 +4,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") -parser.add_argument("--cors", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") +parser.add_argument("--enable-cors-header", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") diff --git a/server.py b/server.py index a9c0b4599..95cdeb051 100644 --- a/server.py +++ b/server.py @@ -57,8 +57,8 @@ class PromptServer(): self.number = 0 middlewares = [cache_control] - if args.cors: - middlewares.append(create_cors_middleware(args.cors)) + if args.enable_cors_header: + middlewares.append(create_cors_middleware(args.enable_cors_header)) self.app = web.Application(client_max_size=20971520, middlewares=middlewares) self.sockets = dict() From 28fff5d1dbba8b4a546e31c69240133f35b2235f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 6 Apr 2023 19:06:39 -0400 Subject: [PATCH 75/83] fixes lack of support for multi configs also adds some metavars to argarse --- comfy/cli_args.py | 8 ++++---- main.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index f2960ae31..b6898cea9 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -2,12 +2,12 @@ import argparse parser = argparse.ArgumentParser() -parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") +parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") -parser.add_argument("--enable-cors-header", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") -parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") +parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") +parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") -parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") +parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") attn_group = parser.add_mutually_exclusive_group() diff --git a/main.py b/main.py index 51a48fc6d..9c0a3d8a1 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,9 @@ import asyncio +import itertools import os import shutil import threading + from comfy.cli_args import args if os.name == "nt": @@ -94,7 +96,8 @@ if __name__ == "__main__": load_extra_path_config(extra_model_paths_config_path) if args.extra_model_paths_config: - load_extra_path_config(args.extra_model_paths_config) + for config_path in itertools.chain(*args.extra_model_paths_config): + load_extra_path_config(config_path) if args.output_directory: output_dir = os.path.abspath(args.output_directory) From 60b4c31ab3c2ec16575c26d9d08ecabc8643b381 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Apr 2023 22:22:59 -0400 Subject: [PATCH 76/83] Add webp images to upload accept list. --- web/scripts/widgets.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 5f5043cd0..d1a9c6c6e 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -306,7 +306,7 @@ export const ComfyWidgets = { const fileInput = document.createElement("input"); Object.assign(fileInput, { type: "file", - accept: "image/jpeg,image/png", + accept: "image/jpeg,image/png,image/webp", style: "display: none", onchange: async () => { if (fileInput.files.length) { From bceccca0e59862c3410b5d99b47fe1e01ba914af Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Apr 2023 23:52:34 -0400 Subject: [PATCH 77/83] Small refactor. --- comfy/model_management.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 92c59efe7..504da2190 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -129,7 +129,6 @@ def load_model_gpu(model): global current_loaded_model global vram_state global model_accelerated - global xpu_available if model is current_loaded_model: return @@ -148,17 +147,14 @@ def load_model_gpu(model): pass elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: model_accelerated = False - if xpu_available: - real_model.to("xpu") - else: - real_model.cuda() + real_model.to(get_torch_device()) else: if vram_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) elif vram_state == VRAMState.LOW_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda") + accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) model_accelerated = True return current_loaded_model @@ -184,12 +180,8 @@ def load_controlnet_gpu(models): def load_if_low_vram(model): global vram_state - global xpu_available if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: - if xpu_available: - return model.to("xpu") - else: - return model.cuda() + return model.to(get_torch_device()) return model def unload_if_low_vram(model): From 64557d67810c81f72bd6a7544bd8930488868319 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Apr 2023 00:27:54 -0400 Subject: [PATCH 78/83] Add a --force-fp32 argument to force fp32 for debugging. --- comfy/cli_args.py | 1 + comfy/model_management.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b6898cea9..739891f71 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -9,6 +9,7 @@ parser.add_argument("--extra-model-paths-config", type=str, default=None, metava parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") +parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 504da2190..2407140fd 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -69,6 +69,11 @@ elif args.novram: elif args.highvram: vram_state = VRAMState.HIGH_VRAM +FORCE_FP32 = False +if args.force_fp32: + print("Forcing FP32, if this improves things please report it.") + FORCE_FP32 = True + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): try: @@ -273,6 +278,9 @@ def mps_mode(): def should_use_fp16(): global xpu_available + if FORCE_FP32: + return False + if cpu_mode() or mps_mode() or xpu_available: return False #TODO ? From 72a8973bd56b7cc179eb603ccd61385fdca5766d Mon Sep 17 00:00:00 2001 From: sALTaccount Date: Thu, 6 Apr 2023 21:45:08 -0700 Subject: [PATCH 79/83] allow configurable path for diffusers models --- folder_paths.py | 1 + nodes.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/folder_paths.py b/folder_paths.py index f13e4895f..ab3359347 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -23,6 +23,7 @@ folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) +folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) diff --git a/nodes.py b/nodes.py index 8271da04c..934b458f2 100644 --- a/nodes.py +++ b/nodes.py @@ -224,7 +224,7 @@ class DiffusersLoader: @classmethod def INPUT_TYPES(cls): paths = [] - search_path = os.path.join(folder_paths.models_dir, 'diffusers') + search_path = folder_paths.get_folder_paths("diffusers")[0] if os.path.exists(search_path): paths = next(os.walk(search_path))[1] return {"required": {"model_path": (paths,), }} From f51b7a92c72b5fe7a12d642a545e59f1f6150fb4 Mon Sep 17 00:00:00 2001 From: sALTaccount Date: Thu, 6 Apr 2023 21:48:58 -0700 Subject: [PATCH 80/83] search all diffusers paths (oops) --- nodes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 934b458f2..a4366f834 100644 --- a/nodes.py +++ b/nodes.py @@ -224,9 +224,10 @@ class DiffusersLoader: @classmethod def INPUT_TYPES(cls): paths = [] - search_path = folder_paths.get_folder_paths("diffusers")[0] - if os.path.exists(search_path): - paths = next(os.walk(search_path))[1] + search_paths = folder_paths.get_folder_paths("diffusers") + for search_path in search_paths: + if os.path.exists(search_path): + paths = next(os.walk(search_path))[1] return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" From 7734d65f22a8f30f73cb72e81586b2d015229060 Mon Sep 17 00:00:00 2001 From: sALTaccount Date: Thu, 6 Apr 2023 22:02:26 -0700 Subject: [PATCH 81/83] fix loading alt folders --- nodes.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index a4366f834..274ae2f1f 100644 --- a/nodes.py +++ b/nodes.py @@ -224,10 +224,9 @@ class DiffusersLoader: @classmethod def INPUT_TYPES(cls): paths = [] - search_paths = folder_paths.get_folder_paths("diffusers") - for search_path in search_paths: + for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths = next(os.walk(search_path))[1] + paths += next(os.walk(search_path))[1] return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -235,7 +234,13 @@ class DiffusersLoader: CATEGORY = "loaders" def load_checkpoint(self, model_path, output_vae=True, output_clip=True): - model_path = os.path.join(folder_paths.models_dir, 'diffusers', model_path) + for search_path in folder_paths.get_folder_paths("diffusers"): + if os.path.exists(search_path): + paths = next(os.walk(search_path))[1] + if model_path in paths: + model_path = os.path.join(search_path, model_path) + break + search_paths = folder_paths.get_folder_paths("diffusers") return load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) From 58ed0f2da438aaf253f9880578d694ad917819f8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Apr 2023 01:28:15 -0400 Subject: [PATCH 82/83] Fix loading SD1.5 diffusers checkpoint. --- comfy/diffusers_convert.py | 4 +++- nodes.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index a31c1c11b..950137f2c 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -272,7 +272,8 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb # magic v2 = diffusers_unet_conf["sample_size"] == 96 - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' if v2: if v_pred: @@ -290,6 +291,7 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb scale_factor = model_config_params['scale_factor'] vae_config = model_config_params['first_stage_config'] vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") diff --git a/nodes.py b/nodes.py index 274ae2f1f..025e4fcb4 100644 --- a/nodes.py +++ b/nodes.py @@ -231,7 +231,7 @@ class DiffusersLoader: RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "advanced/loaders" def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): @@ -240,7 +240,7 @@ class DiffusersLoader: if model_path in paths: model_path = os.path.join(search_path, model_path) break - search_paths = folder_paths.get_folder_paths("diffusers") + return load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) From 44fea050649347ca4b4e7317a83d11c3b4b87f87 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Apr 2023 02:29:56 -0400 Subject: [PATCH 83/83] Cleanup. --- comfy/diffusers_convert.py | 4 ---- nodes.py | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 950137f2c..ceca80305 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -2,10 +2,6 @@ import json import os import yaml -# because of local import nonsense -import sys -sys.path.append(os.path.dirname(os.path.realpath(__file__))) - import folder_paths from comfy.ldm.util import instantiate_from_config from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE diff --git a/nodes.py b/nodes.py index 025e4fcb4..5c3b3a4ee 100644 --- a/nodes.py +++ b/nodes.py @@ -10,11 +10,11 @@ from PIL import Image from PIL.PngImagePlugin import PngInfo import numpy as np -from comfy.diffusers_convert import load_diffusers sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) +import comfy.diffusers_convert import comfy.samplers import comfy.sd import comfy.utils @@ -241,7 +241,7 @@ class DiffusersLoader: model_path = os.path.join(search_path, model_path) break - return load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: