mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 13:50:49 +08:00
Merge branch 'Main' into feature/settings
This commit is contained in:
commit
4ca2871848
@ -1,18 +1,25 @@
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args
|
||||
import torch
|
||||
|
||||
class VRAMState(Enum):
|
||||
CPU = 0
|
||||
DISABLED = 0
|
||||
NO_VRAM = 1
|
||||
LOW_VRAM = 2
|
||||
NORMAL_VRAM = 3
|
||||
HIGH_VRAM = 4
|
||||
MPS = 5
|
||||
SHARED = 5
|
||||
|
||||
class CPUState(Enum):
|
||||
GPU = 0
|
||||
CPU = 1
|
||||
MPS = 2
|
||||
|
||||
# Determine VRAM State
|
||||
vram_state = VRAMState.NORMAL_VRAM
|
||||
set_vram_to = VRAMState.NORMAL_VRAM
|
||||
cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
@ -33,28 +40,77 @@ if args.directml is not None:
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import torch
|
||||
if directml_enabled:
|
||||
pass #TODO
|
||||
else:
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
|
||||
except:
|
||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
if not args.normalvram and not args.cpu:
|
||||
if lowvram_available and total_vram <= 4096:
|
||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||
vram_state = VRAMState.HIGH_VRAM
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
cpu_state = CPUState.MPS
|
||||
except:
|
||||
pass
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if cpu_state == CPUState.MPS:
|
||||
return torch.device("mps")
|
||||
if cpu_state == CPUState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
mem_total = psutil.virtual_memory().total
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_total = 1024 * 1024 * 1024 #TODO
|
||||
mem_total_torch = mem_total
|
||||
elif xpu_available:
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = mem_total_cuda
|
||||
|
||||
if torch_total_too:
|
||||
return (mem_total, mem_total_torch)
|
||||
else:
|
||||
return mem_total
|
||||
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
if not args.normalvram and not args.cpu:
|
||||
if lowvram_available and total_vram <= 4096:
|
||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||
vram_state = VRAMState.HIGH_VRAM
|
||||
|
||||
try:
|
||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||
except:
|
||||
@ -103,8 +159,6 @@ if args.force_fp32:
|
||||
print("Forcing FP32, if this improves things please report it.")
|
||||
FORCE_FP32 = True
|
||||
|
||||
|
||||
|
||||
if lowvram_available:
|
||||
try:
|
||||
import accelerate
|
||||
@ -117,40 +171,26 @@ if lowvram_available:
|
||||
lowvram_available = False
|
||||
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
vram_state = VRAMState.MPS
|
||||
except:
|
||||
pass
|
||||
if cpu_state != CPUState.GPU:
|
||||
vram_state = VRAMState.DISABLED
|
||||
|
||||
if args.cpu:
|
||||
vram_state = VRAMState.CPU
|
||||
if cpu_state == CPUState.MPS:
|
||||
vram_state = VRAMState.SHARED
|
||||
|
||||
print(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def get_torch_device_name(device):
|
||||
if hasattr(device, 'type'):
|
||||
return "{}".format(device.type)
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
if device.type == "cuda":
|
||||
return "{} {}".format(device, torch.cuda.get_device_name(device))
|
||||
else:
|
||||
return "{}".format(device.type)
|
||||
else:
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
try:
|
||||
print("Using device:", get_torch_device_name(get_torch_device()))
|
||||
print("Device:", get_torch_device_name(get_torch_device()))
|
||||
except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
@ -213,13 +253,9 @@ def load_model_gpu(model):
|
||||
|
||||
current_loaded_model = model
|
||||
|
||||
if vram_set_state == VRAMState.CPU:
|
||||
if vram_set_state == VRAMState.DISABLED:
|
||||
pass
|
||||
elif vram_set_state == VRAMState.MPS:
|
||||
mps_device = torch.device("mps")
|
||||
real_model.to(mps_device)
|
||||
pass
|
||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
|
||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||
model_accelerated = False
|
||||
real_model.to(get_torch_device())
|
||||
else:
|
||||
@ -235,7 +271,7 @@ def load_model_gpu(model):
|
||||
def load_controlnet_gpu(control_models):
|
||||
global current_gpu_controlnets
|
||||
global vram_state
|
||||
if vram_state == VRAMState.CPU:
|
||||
if vram_state == VRAMState.DISABLED:
|
||||
return
|
||||
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
@ -280,7 +316,8 @@ def get_autocast_device(dev):
|
||||
def xformers_enabled():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if vram_state == VRAMState.CPU:
|
||||
global cpu_state
|
||||
if cpu_state != CPUState.GPU:
|
||||
return False
|
||||
if xpu_available:
|
||||
return False
|
||||
@ -352,12 +389,12 @@ def maximum_batch_area():
|
||||
return int(max(area, 0))
|
||||
|
||||
def cpu_mode():
|
||||
global vram_state
|
||||
return vram_state == VRAMState.CPU
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.CPU
|
||||
|
||||
def mps_mode():
|
||||
global vram_state
|
||||
return vram_state == VRAMState.MPS
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.MPS
|
||||
|
||||
def should_use_fp16():
|
||||
global xpu_available
|
||||
@ -389,8 +426,8 @@ def should_use_fp16():
|
||||
|
||||
def soft_empty_cache():
|
||||
global xpu_available
|
||||
global vram_state
|
||||
if vram_state == VRAMState.MPS:
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
elif xpu_available:
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
14
comfy/sd.py
14
comfy/sd.py
@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
return torch.cat([tensor] * batched_number, dim=0)
|
||||
|
||||
class ControlNet:
|
||||
def __init__(self, control_model, device=None):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
self.control_model = control_model
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
@ -630,6 +630,7 @@ class ControlNet:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = global_average_pooling
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
||||
control_prev = None
|
||||
@ -665,6 +666,9 @@ class ControlNet:
|
||||
key = 'output'
|
||||
index = i
|
||||
x = control[i]
|
||||
if self.global_average_pooling:
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype and not autocast_enabled:
|
||||
x = x.to(output_dtype)
|
||||
@ -695,7 +699,7 @@ class ControlNet:
|
||||
self.cond_hint = None
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model)
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
return c
|
||||
@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if use_fp16:
|
||||
control_model = control_model.half()
|
||||
|
||||
control = ControlNet(control_model)
|
||||
global_average_pooling = False
|
||||
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||
return control
|
||||
|
||||
class T2IAdapter:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
|
||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
|
||||
@ -154,7 +155,7 @@ def get_filename_list_(folder_name):
|
||||
output_list.update(filter_files_extensions(files, folders[1]))
|
||||
output_folders = {**output_folders, **folders_all}
|
||||
|
||||
return (sorted(list(output_list)), output_folders)
|
||||
return (sorted(list(output_list)), output_folders, time.perf_counter())
|
||||
|
||||
def cached_filename_list_(folder_name):
|
||||
global filename_list_cache
|
||||
@ -162,6 +163,8 @@ def cached_filename_list_(folder_name):
|
||||
if folder_name not in filename_list_cache:
|
||||
return None
|
||||
out = filename_list_cache[folder_name]
|
||||
if time.perf_counter() < (out[2] + 0.5):
|
||||
return out
|
||||
for x in out[1]:
|
||||
time_modified = out[1][x]
|
||||
folder = x
|
||||
@ -170,8 +173,9 @@ def cached_filename_list_(folder_name):
|
||||
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
for x in folders[0]:
|
||||
if x not in out[1]:
|
||||
return None
|
||||
if os.path.isdir(x):
|
||||
if x not in out[1]:
|
||||
return None
|
||||
|
||||
return out
|
||||
|
||||
|
||||
24
server.py
24
server.py
@ -23,6 +23,7 @@ except ImportError:
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
@ -280,6 +281,27 @@ class PromptServer():
|
||||
return web.Response(status=404)
|
||||
return web.json_response(dt["__metadata__"])
|
||||
|
||||
@routes.get("/system_stats")
|
||||
async def get_queue(request):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device_name = comfy.model_management.get_torch_device_name(device)
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||
system_stats = {
|
||||
"devices": [
|
||||
{
|
||||
"name": device_name,
|
||||
"type": device.type,
|
||||
"index": device.index,
|
||||
"vram_total": vram_total,
|
||||
"vram_free": vram_free,
|
||||
"torch_vram_total": torch_vram_total,
|
||||
"torch_vram_free": torch_vram_free,
|
||||
}
|
||||
]
|
||||
}
|
||||
return web.json_response(system_stats)
|
||||
|
||||
@routes.get("/prompt")
|
||||
async def get_prompt(request):
|
||||
return web.json_response(self.get_queue_info())
|
||||
@ -363,7 +385,7 @@ class PromptServer():
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, settings))
|
||||
return web.json_response({"prompt_id": prompt_id})
|
||||
return web.json_response({"prompt_id": prompt_id, "number": number})
|
||||
else:
|
||||
print("invalid prompt:", valid[1])
|
||||
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
|
||||
|
||||
@ -336,6 +336,7 @@ button.comfy-queue-btn {
|
||||
z-index: 9999 !important;
|
||||
background-color: var(--comfy-menu-bg) !important;
|
||||
overflow: hidden;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.litegraph.litesearchbox input,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user