Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2026-02-06 11:32:31 +08:00

Compare commits: e561f71bc4 ... 73a32cc261 (6 commits)
Commits:
- 73a32cc261
- c05a08ae66
- de9ada6a41
- 37f711d4a1
- b65c1b1580
- 389b3325d1
@@ -5,8 +5,10 @@ import base64
 import json
 import time
 import logging
+import requests
 import folder_paths
 import glob
+from tqdm.auto import tqdm
 import comfy.utils
 from aiohttp import web
 from PIL import Image
@@ -15,8 +17,9 @@ from folder_paths import map_legacy, filter_files_extensions, filter_files_conte


 class ModelFileManager:
-    def __init__(self) -> None:
+    def __init__(self, is_download_model_enabled: lambda: bool= lambda: False) -> None:
         self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
+        self.is_download_model_enabled = is_download_model_enabled

     def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
         return self.cache.get(key, default)
@@ -76,6 +79,45 @@ class ModelFileManager:
            except:
                return web.Response(status=404)

+        @routes.post("/download_model")
+        async def post_download_model(request):
+            if not self.is_download_model_enabled():
+                logging.error("Download Model endpoint is disabled")
+                return web.Response(status=403)
+            json_data = await request.json()
+            url = json_data.get("url", None)
+            if url is None:
+                logging.error("URL is not provided")
+                return web.Response(status=401)
+            save_dir = json_data.get("save_dir", None)
+            if save_dir not in folder_paths.folder_names_and_paths:
+                logging.error("Save directory is not valid")
+                return web.Response(status=401)
+            filename = json_data.get("filename", url.split("/")[-1])
+            token = json_data.get("token", None)
+
+            save_path = os.path.join(folder_paths.folder_names_and_paths[save_dir][0][0], filename)
+            tmp_path = save_path + ".tmp"
+            headers = {"Authorization": f"Bearer {token}"} if token else {}
+            try:
+                with requests.get(url, headers=headers, stream=True, timeout=10) as r:
+                    r.raise_for_status()
+                    total_size = int(r.headers.get('content-length', 0))
+                    with open(tmp_path, "wb") as f:
+                        with tqdm(total=total_size, unit='iB', unit_scale=True, desc=filename) as pbar:
+                            for chunk in r.iter_content(chunk_size=1024*1024):
+                                if not chunk:
+                                    break
+                                size = f.write(chunk)
+                                pbar.update(size)
+                os.rename(tmp_path, save_path)
+                return web.Response(status=200)
+            except Exception as e:
+                logging.error(f"Failed to download model: {e}")
+                if os.path.exists(tmp_path):
+                    os.remove(tmp_path)
+                return web.Response(status=500)
+
     def get_model_file_list(self, folder_name: str):
         folder_name = map_legacy(folder_name)
         folders = folder_paths.folder_names_and_paths[folder_name]
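For reference, a client call against the new route might look like the sketch below. This is a hypothetical example, not part of the diff: it assumes a local ComfyUI instance on the default port (8188), that the "Comfy.ModelDownloadEnabled" user setting is turned on (see the PromptServer change at the bottom of this compare), and that "checkpoints" is a valid key in folder_paths.folder_names_and_paths; the model URL and filename are placeholders.

```python
import requests

payload = {
    "url": "https://example.com/models/some-model.safetensors",  # placeholder download URL
    "save_dir": "checkpoints",             # must be a key in folder_paths.folder_names_and_paths
    "filename": "some-model.safetensors",  # optional; defaults to the last URL path segment
    # "token": "...",                      # optional; forwarded as "Authorization: Bearer <token>"
}

# The handler downloads synchronously before responding, so this call blocks
# until the file has been saved (or the download fails).
resp = requests.post("http://127.0.0.1:8188/download_model", json=payload)
print(resp.status_code)  # 200 on success; 401, 403 or 500 per the handler above
```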
@@ -19,7 +19,8 @@
 import psutil
 import logging
 from enum import Enum
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+import threading
 import torch
 import sys
 import platform
@@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
     soft_empty_cache()
     return unloaded_models

-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state

@@ -746,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         current_loaded_models.insert(0, loaded_model)
     return

+def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
+    with torch.inference_mode():
+        load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+    soft_empty_cache()
+
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+    #Deliberately load models outside of the Aimdo mempool so they can be retained accross
+    #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
+    #thread local. So exploit that to escape context
+    if enables_dynamic_vram():
+        t = threading.Thread(
+            target=load_models_gpu_thread,
+            args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+        )
+        t.start()
+        t.join()
+    else:
+        load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
+                             minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
+
 def load_model_gpu(model):
     return load_models_gpu([model])

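The comment in the new load_models_gpu wrapper relies on mempool contexts being thread-local, so work done in a short-lived worker thread escapes the caller's pool. Below is a minimal, self-contained illustration of that escape hatch, using threading.local as a stand-in for PyTorch's mempool context (no CUDA or ComfyUI APIs involved).

```python
import threading

# Stand-in for a thread-local context such as an active CUDA mempool.
_ctx = threading.local()
_ctx.pool = "aimdo-mempool"  # visible only in the thread that set it

def work_outside_context():
    # The worker thread does not see the caller's attribute, so anything
    # "allocated" here is not attributed to the caller's pool.
    print("pool seen by worker:", getattr(_ctx, "pool", None))  # -> None

t = threading.Thread(target=work_outside_context)
t.start()
t.join()  # block until done, mirroring the t.start(); t.join() pattern in the diff
print("pool seen by caller:", _ctx.pool)  # -> "aimdo-mempool"
```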
@@ -1112,11 +1133,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
         return None
     if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
         #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
-        torch.cuda.synchronize()
+        synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
         #FIXME: This doesn't work in Aimdo because mempool cant clear cache
-        torch.cuda.empty_cache()
+        soft_empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
         STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
@@ -1132,9 +1153,7 @@ def reset_cast_buffers():
     for offload_stream in STREAM_CAST_BUFFERS:
         offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
-    if comfy.memory_management.aimdo_allocator is None:
-        #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
-        torch.cuda.empty_cache()
+    soft_empty_cache()

 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1284,7 +1303,7 @@ def discard_cuda_async_error():
         a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         _ = a + b
-        torch.cuda.synchronize()
+        synchronize()
     except torch.AcceleratorError:
         #Dump it! We already know about it from the synchronous return
         pass
@@ -1688,6 +1707,12 @@ def lora_compute_dtype(device):
     LORA_COMPUTE_DTYPES[device] = dtype
     return dtype

+def synchronize():
+    if is_intel_xpu():
+        torch.xpu.synchronize()
+    elif torch.cuda.is_available():
+        torch.cuda.synchronize()
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
@@ -1713,9 +1738,6 @@ def debug_memory_summary():
         return torch.cuda.memory.memory_summary()
     return ""

-#TODO: might be cleaner to put this somewhere else
-import threading
-
 class InterruptProcessingException(Exception):
     pass

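Both direct torch.cuda.empty_cache() call sites above are routed through soft_empty_cache(), consistent with the removed guard ("Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist"). A rough sketch of what such a centralized guard can look like; safe_empty_cache and have_mempool are illustrative names, not ComfyUI's actual soft_empty_cache implementation.

```python
import torch

def safe_empty_cache(have_mempool: bool) -> None:
    # Skip empty_cache() while a custom CUDA mempool is alive, since older
    # PyTorch releases can crash in that situation.
    if not torch.cuda.is_available():
        return
    if have_mempool:
        return  # leave reclamation to the pool's own lifetime management
    torch.cuda.empty_cache()

safe_empty_cache(have_mempool=False)  # call a helper like this instead of torch.cuda.empty_cache() directly
```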
@@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):

         if unpatch_weights:
             self.partially_unload_ram(1e32)
-            self.partially_unload(None)
+            self.partially_unload(None, 1e32)

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
@@ -201,7 +201,7 @@ class PromptServer():
         mimetypes.add_type('image/webp', '.webp')

         self.user_manager = UserManager()
-        self.model_file_manager = ModelFileManager()
+        self.model_file_manager = ModelFileManager(is_download_model_enabled=lambda: self.user_manager.settings.get_settings(None).get("Comfy.ModelDownloadEnabled", False))
         self.custom_node_manager = CustomNodeManager()
         self.subgraph_manager = SubgraphManager()
         self.internal_routes = InternalRoutes(self)
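The server change passes the toggle as a zero-argument callable rather than a boolean, so the "Comfy.ModelDownloadEnabled" setting is re-read on every /download_model request instead of being captured once when PromptServer is constructed. A small sketch of that pattern with hypothetical names (SettingsStore and Manager are illustrative, not ComfyUI classes):

```python
from typing import Callable

class SettingsStore:
    def __init__(self) -> None:
        self._values: dict[str, bool] = {}

    def get(self, key: str, default: bool = False) -> bool:
        return self._values.get(key, default)

    def set(self, key: str, value: bool) -> None:
        self._values[key] = value

class Manager:
    # Accepting a callable defers the lookup until the moment it is needed.
    def __init__(self, is_enabled: Callable[[], bool] = lambda: False) -> None:
        self.is_enabled = is_enabled

settings = SettingsStore()
manager = Manager(is_enabled=lambda: settings.get("Comfy.ModelDownloadEnabled", False))

print(manager.is_enabled())  # False: setting not enabled yet
settings.set("Comfy.ModelDownloadEnabled", True)
print(manager.is_enabled())  # True: picked up without rebuilding Manager
```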