Merge remote-tracking branch 'upstream/master' into addBatchIndex

This commit is contained in:
flyingshutter 2023-06-03 21:03:39 +02:00
commit e185db848d
7 changed files with 154 additions and 78 deletions

View File

@ -1,18 +1,25 @@
import psutil import psutil
from enum import Enum from enum import Enum
from comfy.cli_args import args from comfy.cli_args import args
import torch
class VRAMState(Enum): class VRAMState(Enum):
CPU = 0 DISABLED = 0
NO_VRAM = 1 NO_VRAM = 1
LOW_VRAM = 2 LOW_VRAM = 2
NORMAL_VRAM = 3 NORMAL_VRAM = 3
HIGH_VRAM = 4 HIGH_VRAM = 4
MPS = 5 SHARED = 5
class CPUState(Enum):
GPU = 0
CPU = 1
MPS = 2
# Determine VRAM State # Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU
total_vram = 0 total_vram = 0
@ -33,28 +40,77 @@ if args.directml is not None:
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try: try:
import torch import intel_extension_for_pytorch as ipex
if directml_enabled: if torch.xpu.is_available():
pass #TODO xpu_available = True
else:
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = VRAMState.HIGH_VRAM
except: except:
pass pass
try:
if torch.backends.mps.is_available():
cpu_state = CPUState.MPS
except:
pass
if args.cpu:
cpu_state = CPUState.CPU
def get_torch_device():
global xpu_available
global directml_enabled
global cpu_state
if directml_enabled:
global directml_device
return directml_device
if cpu_state == CPUState.MPS:
return torch.device("mps")
if cpu_state == CPUState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.device(torch.cuda.current_device())
def get_total_memory(dev=None, torch_total_too=False):
global xpu_available
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
mem_total = 1024 * 1024 * 1024 #TODO
mem_total_torch = mem_total
elif xpu_available:
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_total
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_cuda
if torch_total_too:
return (mem_total, mem_total_torch)
else:
return mem_total
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = VRAMState.HIGH_VRAM
try: try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except: except:
@ -103,8 +159,6 @@ if args.force_fp32:
print("Forcing FP32, if this improves things please report it.") print("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True FORCE_FP32 = True
if lowvram_available: if lowvram_available:
try: try:
import accelerate import accelerate
@ -117,40 +171,26 @@ if lowvram_available:
lowvram_available = False lowvram_available = False
try: if cpu_state != CPUState.GPU:
if torch.backends.mps.is_available(): vram_state = VRAMState.DISABLED
vram_state = VRAMState.MPS
except:
pass
if args.cpu: if cpu_state == CPUState.MPS:
vram_state = VRAMState.CPU vram_state = VRAMState.SHARED
print(f"Set vram state to: {vram_state.name}") print(f"Set vram state to: {vram_state.name}")
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_torch_device_name(device): def get_torch_device_name(device):
if hasattr(device, 'type'): if hasattr(device, 'type'):
return "{}".format(device.type) if device.type == "cuda":
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) return "{} {}".format(device, torch.cuda.get_device_name(device))
else:
return "{}".format(device.type)
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try: try:
print("Using device:", get_torch_device_name(get_torch_device())) print("Device:", get_torch_device_name(get_torch_device()))
except: except:
print("Could not pick default device.") print("Could not pick default device.")
@ -213,13 +253,9 @@ def load_model_gpu(model):
current_loaded_model = model current_loaded_model = model
if vram_set_state == VRAMState.CPU: if vram_set_state == VRAMState.DISABLED:
pass pass
elif vram_set_state == VRAMState.MPS: elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
model_accelerated = False model_accelerated = False
real_model.to(get_torch_device()) real_model.to(get_torch_device())
else: else:
@ -235,7 +271,7 @@ def load_model_gpu(model):
def load_controlnet_gpu(control_models): def load_controlnet_gpu(control_models):
global current_gpu_controlnets global current_gpu_controlnets
global vram_state global vram_state
if vram_state == VRAMState.CPU: if vram_state == VRAMState.DISABLED:
return return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
@ -280,7 +316,8 @@ def get_autocast_device(dev):
def xformers_enabled(): def xformers_enabled():
global xpu_available global xpu_available
global directml_enabled global directml_enabled
if vram_state == VRAMState.CPU: global cpu_state
if cpu_state != CPUState.GPU:
return False return False
if xpu_available: if xpu_available:
return False return False
@ -352,12 +389,12 @@ def maximum_batch_area():
return int(max(area, 0)) return int(max(area, 0))
def cpu_mode(): def cpu_mode():
global vram_state global cpu_state
return vram_state == VRAMState.CPU return cpu_state == CPUState.CPU
def mps_mode(): def mps_mode():
global vram_state global cpu_state
return vram_state == VRAMState.MPS return cpu_state == CPUState.MPS
def should_use_fp16(): def should_use_fp16():
global xpu_available global xpu_available
@ -389,8 +426,8 @@ def should_use_fp16():
def soft_empty_cache(): def soft_empty_cache():
global xpu_available global xpu_available
global vram_state global cpu_state
if vram_state == VRAMState.MPS: if cpu_state == CPUState.MPS:
torch.mps.empty_cache() torch.mps.empty_cache()
elif xpu_available: elif xpu_available:
torch.xpu.empty_cache() torch.xpu.empty_cache()

View File

@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
return torch.cat([tensor] * batched_number, dim=0) return torch.cat([tensor] * batched_number, dim=0)
class ControlNet: class ControlNet:
def __init__(self, control_model, device=None): def __init__(self, control_model, global_average_pooling=False, device=None):
self.control_model = control_model self.control_model = control_model
self.cond_hint_original = None self.cond_hint_original = None
self.cond_hint = None self.cond_hint = None
@ -630,6 +630,7 @@ class ControlNet:
device = model_management.get_torch_device() device = model_management.get_torch_device()
self.device = device self.device = device
self.previous_controlnet = None self.previous_controlnet = None
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond_txt, batched_number): def get_control(self, x_noisy, t, cond_txt, batched_number):
control_prev = None control_prev = None
@ -665,6 +666,9 @@ class ControlNet:
key = 'output' key = 'output'
index = i index = i
x = control[i] x = control[i]
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength x *= self.strength
if x.dtype != output_dtype and not autocast_enabled: if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype) x = x.to(output_dtype)
@ -695,7 +699,7 @@ class ControlNet:
self.cond_hint = None self.cond_hint = None
def copy(self): def copy(self):
c = ControlNet(self.control_model) c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
return c return c
@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
if use_fp16: if use_fp16:
control_model = control_model.half() control_model = control_model.half()
control = ControlNet(control_model) global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
return control return control
class T2IAdapter: class T2IAdapter:

View File

@ -1,4 +1,5 @@
import os import os
import time
supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
@ -154,7 +155,7 @@ def get_filename_list_(folder_name):
output_list.update(filter_files_extensions(files, folders[1])) output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all} output_folders = {**output_folders, **folders_all}
return (sorted(list(output_list)), output_folders) return (sorted(list(output_list)), output_folders, time.perf_counter())
def cached_filename_list_(folder_name): def cached_filename_list_(folder_name):
global filename_list_cache global filename_list_cache
@ -162,6 +163,8 @@ def cached_filename_list_(folder_name):
if folder_name not in filename_list_cache: if folder_name not in filename_list_cache:
return None return None
out = filename_list_cache[folder_name] out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]: for x in out[1]:
time_modified = out[1][x] time_modified = out[1][x]
folder = x folder = x
@ -170,8 +173,9 @@ def cached_filename_list_(folder_name):
folders = folder_names_and_paths[folder_name] folders = folder_names_and_paths[folder_name]
for x in folders[0]: for x in folders[0]:
if x not in out[1]: if os.path.isdir(x):
return None if x not in out[1]:
return None
return out return out

View File

@ -23,6 +23,7 @@ except ImportError:
import mimetypes import mimetypes
from comfy.cli_args import args from comfy.cli_args import args
import comfy.utils import comfy.utils
import comfy.model_management
@web.middleware @web.middleware
async def cache_control(request: web.Request, handler): async def cache_control(request: web.Request, handler):
@ -280,6 +281,27 @@ class PromptServer():
return web.Response(status=404) return web.Response(status=404)
return web.json_response(dt["__metadata__"]) return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"devices": [
{
"name": device_name,
"type": device.type,
"index": device.index,
"vram_total": vram_total,
"vram_free": vram_free,
"torch_vram_total": torch_vram_total,
"torch_vram_free": torch_vram_free,
}
]
}
return web.json_response(system_stats)
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())

View File

@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog {
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
// update mask // update mask
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCanvas.width = drawWidth; maskCanvas.width = drawWidth;
maskCanvas.height = drawHeight; maskCanvas.height = drawHeight;
maskCanvas.style.top = imgCanvas.offsetTop + "px"; maskCanvas.style.top = imgCanvas.offsetTop + "px";
maskCanvas.style.left = imgCanvas.offsetLeft + "px"; maskCanvas.style.left = imgCanvas.offsetLeft + "px";
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
}); });

View File

@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action)
if (this.onShowNodePanel) { if (this.onShowNodePanel) {
this.onShowNodePanel(n); this.onShowNodePanel(n);
} }
else
{
this.showShowNodePanel(n);
}
if (this.onNodeDblClicked) { if (this.onNodeDblClicked) {
this.onNodeDblClicked(n); this.onNodeDblClicked(n);
@ -8099,11 +8095,15 @@ LGraphNode.prototype.executeAction = function(action)
bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR; bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR;
hovercolor = hovercolor || "#555"; hovercolor = hovercolor || "#555";
textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR; textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR;
var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2; // fix the height with the title var pos = this.ds.convertOffsetToCanvas(this.graph_mouse);
var pos = this.mouse; var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null;
pos = this.last_click_position; if(pos) {
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); var rect = this.canvas.getBoundingClientRect();
pos[0] -= rect.left;
pos[1] -= rect.top;
}
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
ctx.fillStyle = hover ? hovercolor : bgcolor; ctx.fillStyle = hover ? hovercolor : bgcolor;
if(clicked) if(clicked)
@ -13067,6 +13067,10 @@ LGraphNode.prototype.executeAction = function(action)
has_submenu: true, has_submenu: true,
callback: LGraphCanvas.onShowMenuNodeProperties callback: LGraphCanvas.onShowMenuNodeProperties
}, },
{
content: "Properties Panel",
callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) }
},
null, null,
{ {
content: "Title", content: "Title",

View File

@ -336,6 +336,7 @@ button.comfy-queue-btn {
z-index: 9999 !important; z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important; background-color: var(--comfy-menu-bg) !important;
overflow: hidden; overflow: hidden;
display: block;
} }
.litegraph.litesearchbox input, .litegraph.litesearchbox input,