From 700491d81a9faf5370a0c54d869e902bbfc839ec Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 01:47:21 -0400 Subject: [PATCH 1/6] Implement global average pooling for controlnet. --- comfy/sd.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index fa7bd8d32..336fee4a6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): return torch.cat([tensor] * batched_number, dim=0) class ControlNet: - def __init__(self, control_model, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None): self.control_model = control_model self.cond_hint_original = None self.cond_hint = None @@ -630,6 +630,7 @@ class ControlNet: device = model_management.get_torch_device() self.device = device self.previous_controlnet = None + self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond_txt, batched_number): control_prev = None @@ -665,6 +666,9 @@ class ControlNet: key = 'output' index = i x = control[i] + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + x *= self.strength if x.dtype != output_dtype and not autocast_enabled: x = x.to(output_dtype) @@ -695,7 +699,7 @@ class ControlNet: self.cond_hint = None def copy(self): - c = ControlNet(self.control_model) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) c.cond_hint_original = self.cond_hint_original c.strength = self.strength return c @@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None): if use_fp16: control_model = control_model.half() - control = ControlNet(control_model) + global_average_pooling = False + if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling) return control class T2IAdapter: From 0a5fefd6213e3116359e0738533a9e3b733506c5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:05:37 -0400 Subject: [PATCH 2/6] Cleanups and fixes for model_management.py Hopefully fix regression on MPS and CPU. --- comfy/model_management.py | 63 ++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9c3147d76..a492ca6b9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -4,16 +4,22 @@ from comfy.cli_args import args import torch class VRAMState(Enum): - CPU = 0 + DISABLED = 0 NO_VRAM = 1 LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 - MPS = 5 + SHARED = 5 + +class CPUState(Enum): + GPU = 0 + CPU = 1 + MPS = 2 # Determine VRAM State vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM +cpu_state = CPUState.GPU total_vram = 0 @@ -40,15 +46,25 @@ try: except: pass +try: + if torch.backends.mps.is_available(): + cpu_state = CPUState.MPS +except: + pass + +if args.cpu: + cpu_state = CPUState.CPU + def get_torch_device(): global xpu_available global directml_enabled + global cpu_state if directml_enabled: global directml_device return directml_device - if vram_state == VRAMState.MPS: + if cpu_state == CPUState.MPS: return torch.device("mps") - if vram_state == VRAMState.CPU: + if cpu_state == CPUState.CPU: return torch.device("cpu") else: if xpu_available: @@ -143,8 +159,6 @@ if args.force_fp32: print("Forcing FP32, if this improves things please report it.") FORCE_FP32 = True - - if lowvram_available: try: import accelerate @@ -157,17 +171,15 @@ if lowvram_available: lowvram_available = False -try: - if torch.backends.mps.is_available(): - vram_state = VRAMState.MPS -except: - pass +if cpu_state != CPUState.GPU: + vram_state = VRAMState.DISABLED -if args.cpu: - vram_state = VRAMState.CPU +if cpu_state == CPUState.MPS: + vram_state = VRAMState.SHARED print(f"Set vram state to: {vram_state.name}") + def get_torch_device_name(device): if hasattr(device, 'type'): if device.type == "cuda": @@ -241,13 +253,9 @@ def load_model_gpu(model): current_loaded_model = model - if vram_set_state == VRAMState.CPU: + if vram_set_state == VRAMState.DISABLED: pass - elif vram_set_state == VRAMState.MPS: - mps_device = torch.device("mps") - real_model.to(mps_device) - pass - elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM: + elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False real_model.to(get_torch_device()) else: @@ -263,7 +271,7 @@ def load_model_gpu(model): def load_controlnet_gpu(control_models): global current_gpu_controlnets global vram_state - if vram_state == VRAMState.CPU: + if vram_state == VRAMState.DISABLED: return if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: @@ -308,7 +316,8 @@ def get_autocast_device(dev): def xformers_enabled(): global xpu_available global directml_enabled - if vram_state == VRAMState.CPU: + global cpu_state + if cpu_state != CPUState.GPU: return False if xpu_available: return False @@ -380,12 +389,12 @@ def maximum_batch_area(): return int(max(area, 0)) def cpu_mode(): - global vram_state - return vram_state == VRAMState.CPU + global cpu_state + return cpu_state == CPUState.CPU def mps_mode(): - global vram_state - return vram_state == VRAMState.MPS + global cpu_state + return cpu_state == CPUState.MPS def should_use_fp16(): global xpu_available @@ -417,8 +426,8 @@ def should_use_fp16(): def soft_empty_cache(): global xpu_available - global vram_state - if vram_state == VRAMState.MPS: + global cpu_state + if cpu_state == CPUState.MPS: torch.mps.empty_cache() elif xpu_available: torch.xpu.empty_cache() From 32f282c861eabcee42fdec702b96ebc8924c9834 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:19:10 -0400 Subject: [PATCH 3/6] Search box style fix. --- web/style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/web/style.css b/web/style.css index db82887c3..47571a16e 100644 --- a/web/style.css +++ b/web/style.css @@ -336,6 +336,7 @@ button.comfy-queue-btn { z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; overflow: hidden; + display: block; } .litegraph.litesearchbox input, From c092ffcc18f0a44c062fe914ebda05b29bdcfbc0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:46:52 -0400 Subject: [PATCH 4/6] Latest litegraph from upstream. --- web/lib/litegraph.core.js | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 95f4a2735..908ed5f16 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -8099,11 +8099,15 @@ LGraphNode.prototype.executeAction = function(action) bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR; hovercolor = hovercolor || "#555"; textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR; - var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2; // fix the height with the title - var pos = this.mouse; - var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); - pos = this.last_click_position; - var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); + var pos = this.ds.convertOffsetToCanvas(this.graph_mouse); + var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h ); + pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null; + if(pos) { + var rect = this.canvas.getBoundingClientRect(); + pos[0] -= rect.left; + pos[1] -= rect.top; + } + var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h ); ctx.fillStyle = hover ? hovercolor : bgcolor; if(clicked) From 0764bb5218ea49fdeeaebbfc10c6f5b87a8bc879 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:47:20 -0400 Subject: [PATCH 5/6] Move node properties panel from double click to menu option. --- web/lib/litegraph.core.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 908ed5f16..a60848d77 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action) if (this.onShowNodePanel) { this.onShowNodePanel(n); } - else - { - this.showShowNodePanel(n); - } if (this.onNodeDblClicked) { this.onNodeDblClicked(n); @@ -13071,6 +13067,10 @@ LGraphNode.prototype.executeAction = function(action) has_submenu: true, callback: LGraphCanvas.onShowMenuNodeProperties }, + { + content: "Properties Panel", + callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) } + }, null, { content: "Title", From 126b4050dc34daabca51c236bfb5cc31dd48056d Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sun, 4 Jun 2023 01:25:49 +0900 Subject: [PATCH 6/6] Crash fix for intermittent crashes that occur when opening MaskEditor. (#732) --- web/extensions/core/maskeditor.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 4b0c12747..6cb3a5385 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog { imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); // update mask - backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); maskCanvas.width = drawWidth; maskCanvas.height = drawHeight; maskCanvas.style.top = imgCanvas.offsetTop + "px"; maskCanvas.style.left = imgCanvas.offsetLeft + "px"; + backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); });