mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 22:30:50 +08:00
Merge branch 'master' into reroute-shortcut
This commit is contained in:
commit
62d31d215c
@ -4,16 +4,22 @@ from comfy.cli_args import args
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
CPU = 0
|
DISABLED = 0
|
||||||
NO_VRAM = 1
|
NO_VRAM = 1
|
||||||
LOW_VRAM = 2
|
LOW_VRAM = 2
|
||||||
NORMAL_VRAM = 3
|
NORMAL_VRAM = 3
|
||||||
HIGH_VRAM = 4
|
HIGH_VRAM = 4
|
||||||
MPS = 5
|
SHARED = 5
|
||||||
|
|
||||||
|
class CPUState(Enum):
|
||||||
|
GPU = 0
|
||||||
|
CPU = 1
|
||||||
|
MPS = 2
|
||||||
|
|
||||||
# Determine VRAM State
|
# Determine VRAM State
|
||||||
vram_state = VRAMState.NORMAL_VRAM
|
vram_state = VRAMState.NORMAL_VRAM
|
||||||
set_vram_to = VRAMState.NORMAL_VRAM
|
set_vram_to = VRAMState.NORMAL_VRAM
|
||||||
|
cpu_state = CPUState.GPU
|
||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
@ -40,15 +46,25 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
cpu_state = CPUState.MPS
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if args.cpu:
|
||||||
|
cpu_state = CPUState.CPU
|
||||||
|
|
||||||
def get_torch_device():
|
def get_torch_device():
|
||||||
global xpu_available
|
global xpu_available
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
|
global cpu_state
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
global directml_device
|
global directml_device
|
||||||
return directml_device
|
return directml_device
|
||||||
if vram_state == VRAMState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
return torch.device("mps")
|
return torch.device("mps")
|
||||||
if vram_state == VRAMState.CPU:
|
if cpu_state == CPUState.CPU:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
if xpu_available:
|
if xpu_available:
|
||||||
@ -143,8 +159,6 @@ if args.force_fp32:
|
|||||||
print("Forcing FP32, if this improves things please report it.")
|
print("Forcing FP32, if this improves things please report it.")
|
||||||
FORCE_FP32 = True
|
FORCE_FP32 = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if lowvram_available:
|
if lowvram_available:
|
||||||
try:
|
try:
|
||||||
import accelerate
|
import accelerate
|
||||||
@ -157,17 +171,15 @@ if lowvram_available:
|
|||||||
lowvram_available = False
|
lowvram_available = False
|
||||||
|
|
||||||
|
|
||||||
try:
|
if cpu_state != CPUState.GPU:
|
||||||
if torch.backends.mps.is_available():
|
vram_state = VRAMState.DISABLED
|
||||||
vram_state = VRAMState.MPS
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if args.cpu:
|
if cpu_state == CPUState.MPS:
|
||||||
vram_state = VRAMState.CPU
|
vram_state = VRAMState.SHARED
|
||||||
|
|
||||||
print(f"Set vram state to: {vram_state.name}")
|
print(f"Set vram state to: {vram_state.name}")
|
||||||
|
|
||||||
|
|
||||||
def get_torch_device_name(device):
|
def get_torch_device_name(device):
|
||||||
if hasattr(device, 'type'):
|
if hasattr(device, 'type'):
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
@ -241,13 +253,9 @@ def load_model_gpu(model):
|
|||||||
|
|
||||||
current_loaded_model = model
|
current_loaded_model = model
|
||||||
|
|
||||||
if vram_set_state == VRAMState.CPU:
|
if vram_set_state == VRAMState.DISABLED:
|
||||||
pass
|
pass
|
||||||
elif vram_set_state == VRAMState.MPS:
|
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||||
mps_device = torch.device("mps")
|
|
||||||
real_model.to(mps_device)
|
|
||||||
pass
|
|
||||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
|
|
||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
real_model.to(get_torch_device())
|
real_model.to(get_torch_device())
|
||||||
else:
|
else:
|
||||||
@ -263,7 +271,7 @@ def load_model_gpu(model):
|
|||||||
def load_controlnet_gpu(control_models):
|
def load_controlnet_gpu(control_models):
|
||||||
global current_gpu_controlnets
|
global current_gpu_controlnets
|
||||||
global vram_state
|
global vram_state
|
||||||
if vram_state == VRAMState.CPU:
|
if vram_state == VRAMState.DISABLED:
|
||||||
return
|
return
|
||||||
|
|
||||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||||
@ -308,7 +316,8 @@ def get_autocast_device(dev):
|
|||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global xpu_available
|
global xpu_available
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if vram_state == VRAMState.CPU:
|
global cpu_state
|
||||||
|
if cpu_state != CPUState.GPU:
|
||||||
return False
|
return False
|
||||||
if xpu_available:
|
if xpu_available:
|
||||||
return False
|
return False
|
||||||
@ -380,12 +389,12 @@ def maximum_batch_area():
|
|||||||
return int(max(area, 0))
|
return int(max(area, 0))
|
||||||
|
|
||||||
def cpu_mode():
|
def cpu_mode():
|
||||||
global vram_state
|
global cpu_state
|
||||||
return vram_state == VRAMState.CPU
|
return cpu_state == CPUState.CPU
|
||||||
|
|
||||||
def mps_mode():
|
def mps_mode():
|
||||||
global vram_state
|
global cpu_state
|
||||||
return vram_state == VRAMState.MPS
|
return cpu_state == CPUState.MPS
|
||||||
|
|
||||||
def should_use_fp16():
|
def should_use_fp16():
|
||||||
global xpu_available
|
global xpu_available
|
||||||
@ -417,8 +426,8 @@ def should_use_fp16():
|
|||||||
|
|
||||||
def soft_empty_cache():
|
def soft_empty_cache():
|
||||||
global xpu_available
|
global xpu_available
|
||||||
global vram_state
|
global cpu_state
|
||||||
if vram_state == VRAMState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
elif xpu_available:
|
elif xpu_available:
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
|
|||||||
14
comfy/sd.py
14
comfy/sd.py
@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
|||||||
return torch.cat([tensor] * batched_number, dim=0)
|
return torch.cat([tensor] * batched_number, dim=0)
|
||||||
|
|
||||||
class ControlNet:
|
class ControlNet:
|
||||||
def __init__(self, control_model, device=None):
|
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
@ -630,6 +630,7 @@ class ControlNet:
|
|||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
|
self.global_average_pooling = global_average_pooling
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@ -665,6 +666,9 @@ class ControlNet:
|
|||||||
key = 'output'
|
key = 'output'
|
||||||
index = i
|
index = i
|
||||||
x = control[i]
|
x = control[i]
|
||||||
|
if self.global_average_pooling:
|
||||||
|
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||||
|
|
||||||
x *= self.strength
|
x *= self.strength
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
if x.dtype != output_dtype and not autocast_enabled:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
@ -695,7 +699,7 @@ class ControlNet:
|
|||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(self.control_model)
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
return c
|
return c
|
||||||
@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if use_fp16:
|
if use_fp16:
|
||||||
control_model = control_model.half()
|
control_model = control_model.half()
|
||||||
|
|
||||||
control = ControlNet(control_model)
|
global_average_pooling = False
|
||||||
|
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
global_average_pooling = True
|
||||||
|
|
||||||
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
class T2IAdapter:
|
class T2IAdapter:
|
||||||
|
|||||||
@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
|
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
|
||||||
|
|
||||||
// update mask
|
// update mask
|
||||||
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
|
|
||||||
maskCanvas.width = drawWidth;
|
maskCanvas.width = drawWidth;
|
||||||
maskCanvas.height = drawHeight;
|
maskCanvas.height = drawHeight;
|
||||||
maskCanvas.style.top = imgCanvas.offsetTop + "px";
|
maskCanvas.style.top = imgCanvas.offsetTop + "px";
|
||||||
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
|
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
|
||||||
|
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||||
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
|
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
if (this.onShowNodePanel) {
|
if (this.onShowNodePanel) {
|
||||||
this.onShowNodePanel(n);
|
this.onShowNodePanel(n);
|
||||||
}
|
}
|
||||||
else
|
|
||||||
{
|
|
||||||
this.showShowNodePanel(n);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.onNodeDblClicked) {
|
if (this.onNodeDblClicked) {
|
||||||
this.onNodeDblClicked(n);
|
this.onNodeDblClicked(n);
|
||||||
@ -8099,11 +8095,15 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR;
|
bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR;
|
||||||
hovercolor = hovercolor || "#555";
|
hovercolor = hovercolor || "#555";
|
||||||
textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR;
|
textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR;
|
||||||
var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2; // fix the height with the title
|
var pos = this.ds.convertOffsetToCanvas(this.graph_mouse);
|
||||||
var pos = this.mouse;
|
var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
|
||||||
var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h );
|
pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null;
|
||||||
pos = this.last_click_position;
|
if(pos) {
|
||||||
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h );
|
var rect = this.canvas.getBoundingClientRect();
|
||||||
|
pos[0] -= rect.left;
|
||||||
|
pos[1] -= rect.top;
|
||||||
|
}
|
||||||
|
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
|
||||||
|
|
||||||
ctx.fillStyle = hover ? hovercolor : bgcolor;
|
ctx.fillStyle = hover ? hovercolor : bgcolor;
|
||||||
if(clicked)
|
if(clicked)
|
||||||
@ -13078,6 +13078,10 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
has_submenu: true,
|
has_submenu: true,
|
||||||
callback: LGraphCanvas.onShowMenuNodeProperties
|
callback: LGraphCanvas.onShowMenuNodeProperties
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
content: "Properties Panel",
|
||||||
|
callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) }
|
||||||
|
},
|
||||||
null,
|
null,
|
||||||
{
|
{
|
||||||
content: "Title",
|
content: "Title",
|
||||||
|
|||||||
@ -336,6 +336,7 @@ button.comfy-queue-btn {
|
|||||||
z-index: 9999 !important;
|
z-index: 9999 !important;
|
||||||
background-color: var(--comfy-menu-bg) !important;
|
background-color: var(--comfy-menu-bg) !important;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
|
display: block;
|
||||||
}
|
}
|
||||||
|
|
||||||
.litegraph.litesearchbox input,
|
.litegraph.litesearchbox input,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user