Compare commits

...

6 Commits

Author SHA1 Message Date
yt-koike
73a32cc261
Merge b65c1b1580 into c05a08ae66 2026-02-03 09:52:51 +09:00
comfyanonymous
c05a08ae66
Add back function. (#12234)
2026-02-02 19:52:07 -05:00
rattus
de9ada6a41
Dynamic VRAM unloading fix (#12227)
* mp: fix full dynamic unloading

This was not unloading dynamic models when requesting a full unload via
the unpatch() code path.

This was ok if your workflow was all dynamic models, but it fails with big
VRAM leaks if you need to fully unload something for a regular ModelPatcher.

It also fixes the "unload models" button.

* mm: load models outside of Aimdo Mempool

In dynamic_vram mode, escape the Aimdo mempool and load into the regular
mempool. Use a dummy thread to do it.
2026-02-02 17:35:20 -05:00
rattus
37f711d4a1
mm: Fix cast buffers with intel offloading (#12229)
Intel has offloading support, but there were some NVIDIA-specific calls in the
new cast buffer code.
2026-02-02 17:34:46 -05:00
yt-koike
b65c1b1580 replace print() with logging.error() 2026-01-30 12:14:28 +09:00
yt-koike
389b3325d1 add model downloader endpoint 2026-01-30 11:41:07 +09:00
4 changed files with 78 additions and 14 deletions

View File

@ -5,8 +5,10 @@ import base64
import json
import time
import logging
import requests
import folder_paths
import glob
from tqdm.auto import tqdm
import comfy.utils
from aiohttp import web
from PIL import Image
@ -15,8 +17,9 @@ from folder_paths import map_legacy, filter_files_extensions, filter_files_conte
class ModelFileManager:
def __init__(self) -> None:
def __init__(self, is_download_model_enabled: lambda: bool= lambda: False) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
self.is_download_model_enabled = is_download_model_enabled
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
@ -76,6 +79,45 @@ class ModelFileManager:
except:
return web.Response(status=404)
@routes.post("/download_model")
async def post_download_model(request):
if not self.is_download_model_enabled():
logging.error("Download Model endpoint is disabled")
return web.Response(status=403)
json_data = await request.json()
url = json_data.get("url", None)
if url is None:
logging.error("URL is not provided")
return web.Response(status=401)
save_dir = json_data.get("save_dir", None)
if save_dir not in folder_paths.folder_names_and_paths:
logging.error("Save directory is not valid")
return web.Response(status=401)
filename = json_data.get("filename", url.split("/")[-1])
token = json_data.get("token", None)
save_path = os.path.join(folder_paths.folder_names_and_paths[save_dir][0][0], filename)
tmp_path = save_path + ".tmp"
headers = {"Authorization": f"Bearer {token}"} if token else {}
try:
with requests.get(url, headers=headers, stream=True, timeout=10) as r:
r.raise_for_status()
total_size = int(r.headers.get('content-length', 0))
with open(tmp_path, "wb") as f:
with tqdm(total=total_size, unit='iB', unit_scale=True, desc=filename) as pbar:
for chunk in r.iter_content(chunk_size=1024*1024):
if not chunk:
break
size = f.write(chunk)
pbar.update(size)
os.rename(tmp_path, save_path)
return web.Response(status=200)
except Exception as e:
logging.error(f"Failed to download model: {e}")
if os.path.exists(tmp_path):
os.remove(tmp_path)
return web.Response(status=500)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
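
For reference, a minimal sketch of how a client could exercise the new /download_model endpoint. The host/port (ComfyUI's default 127.0.0.1:8188), the model URL, and the enabled Comfy.ModelDownloadEnabled setting are assumptions for illustration, not part of this diff:

# Hypothetical client call for the /download_model endpoint added above.
# Assumes ComfyUI listens on the default 127.0.0.1:8188 and that the
# Comfy.ModelDownloadEnabled setting has been turned on.
import requests

payload = {
    "url": "https://example.com/models/some-model.safetensors",  # placeholder URL
    "save_dir": "checkpoints",  # must be a key in folder_paths.folder_names_and_paths
    "filename": "some-model.safetensors",  # optional; defaults to the last URL segment
    # "token": "...",  # optional; sent as a Bearer Authorization header
}

resp = requests.post("http://127.0.0.1:8188/download_model", json=payload, timeout=30)
print(resp.status_code)  # 200 on success; 401/403/500 map to the error paths above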

View File

@ -19,7 +19,8 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
import torch
import sys
import platform
@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@ -746,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_loaded_models.insert(0, loaded_model)
return
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
with torch.inference_mode():
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
soft_empty_cache()
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
#Deliberately load models outside of the Aimdo mempool so they can be retained across
#nodes. Use a dummy thread to do it, as PyTorch documents that mempool contexts are
#thread local, so exploit that to escape the context.
if enables_dynamic_vram():
t = threading.Thread(
target=load_models_gpu_thread,
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
)
t.start()
t.join()
else:
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
def load_model_gpu(model):
return load_models_gpu([model])
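
The dummy-thread trick above works because PyTorch mempool contexts are thread-local. A standalone sketch of that behavior, assuming PyTorch 2.5+ with a CUDA device (torch.cuda.MemPool and torch.cuda.use_mem_pool are upstream PyTorch APIs; nothing here is ComfyUI code):

# Illustration of escaping a thread-local mempool context, the same pattern
# load_models_gpu() uses above. Assumes PyTorch >= 2.5 and a CUDA device.
import threading
import torch

def alloc_outside_pool(out):
    # This thread never entered the use_mem_pool() context below, so its
    # allocation lands in the regular caching-allocator pool, not in `pool`.
    out["tensor"] = torch.empty(1024, device="cuda")

pool = torch.cuda.MemPool()
result = {}
with torch.cuda.use_mem_pool(pool):
    inside = torch.empty(1024, device="cuda")  # allocated inside `pool`
    worker = threading.Thread(target=alloc_outside_pool, args=(result,))
    worker.start()
    worker.join()  # run-and-join, like load_models_gpu_thread above
# result["tensor"] persists in the default pool even though it was created
# while the main thread held the mempool context.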
@ -1112,11 +1133,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
return None
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
torch.cuda.synchronize()
synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool can't clear cache
torch.cuda.empty_cache()
soft_empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
@ -1132,9 +1153,7 @@ def reset_cast_buffers():
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
STREAM_CAST_BUFFERS.clear()
if comfy.memory_management.aimdo_allocator is None:
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
torch.cuda.empty_cache()
soft_empty_cache()
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@ -1284,7 +1303,7 @@ def discard_cuda_async_error():
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
torch.cuda.synchronize()
synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
@ -1688,6 +1707,12 @@ def lora_compute_dtype(device):
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
def synchronize():
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
@ -1713,9 +1738,6 @@ def debug_memory_summary():
return torch.cuda.memory.memory_summary()
return ""
#TODO: might be cleaner to put this somewhere else
import threading
class InterruptProcessingException(Exception):
pass

View File

@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None)
self.partially_unload(None, 1e32)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above

View File

@ -201,7 +201,7 @@ class PromptServer():
mimetypes.add_type('image/webp', '.webp')
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
self.model_file_manager = ModelFileManager(is_download_model_enabled=lambda: self.user_manager.settings.get_settings(None).get("Comfy.ModelDownloadEnabled", False))
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
self.internal_routes = InternalRoutes(self)
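
The endpoint is disabled unless the Comfy.ModelDownloadEnabled user setting is true. A sketch of flipping it from a script, assuming a default single-user install whose settings live at user/default/comfy.settings.json (the path is an assumption, not something shown in this diff; the setting can equally be changed wherever ComfyUI settings are edited):

# Hypothetical helper to enable the new /download_model endpoint by editing
# the per-user settings file. The user/default/comfy.settings.json path is an
# assumption about a default install, not part of this diff.
import json
from pathlib import Path

settings_path = Path("user/default/comfy.settings.json")
settings = json.loads(settings_path.read_text()) if settings_path.exists() else {}
settings["Comfy.ModelDownloadEnabled"] = True
settings_path.parent.mkdir(parents=True, exist_ok=True)
settings_path.write_text(json.dumps(settings, indent=4))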