mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Track model download progress for remote UI polling
This commit is contained in:
parent
8c3211fdf1
commit
bbfb237477
@ -3,7 +3,9 @@ import base64
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import asyncio
|
||||
import requests
|
||||
from threading import Lock
|
||||
from tqdm.auto import tqdm
|
||||
from urllib.parse import unquote, urlparse
|
||||
from typing import Callable
|
||||
@ -20,6 +22,8 @@ class ModelFileManager:
|
||||
def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None:
|
||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||
self.is_download_model_enabled = is_download_model_enabled
|
||||
self._download_progress: dict[str, dict] = {}
|
||||
self._download_lock = Lock()
|
||||
|
||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||
return self.cache.get(key, default)
|
||||
@ -30,6 +34,39 @@ class ModelFileManager:
|
||||
def clear_cache(self):
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def _download_progress_key(self, save_dir: str, filename: str) -> str:
|
||||
return f"{save_dir}/{filename}"
|
||||
|
||||
def _set_download_progress(self, key: str, **fields) -> None:
|
||||
with self._download_lock:
|
||||
entry = dict(self._download_progress.get(key, {}))
|
||||
entry.update(fields)
|
||||
self._download_progress[key] = entry
|
||||
|
||||
def _download_file_sync(self, url: str, headers: dict, tmp_path: str, save_path: str, key: str) -> None:
|
||||
with requests.get(url, headers=headers, stream=True, timeout=(30, 3600)) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get("content-length", 0) or 0)
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
self._set_download_progress(
|
||||
key,
|
||||
status="running",
|
||||
bytes_downloaded=0,
|
||||
bytes_total=total_size,
|
||||
filename=os.path.basename(save_path),
|
||||
)
|
||||
with open(tmp_path, "wb") as f:
|
||||
downloaded = 0
|
||||
for chunk in r.iter_content(chunk_size=1024 * 1024):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
self._set_download_progress(key, bytes_downloaded=downloaded, bytes_total=total_size or downloaded)
|
||||
os.replace(tmp_path, save_path)
|
||||
|
||||
|
||||
def add_routes(self, routes):
|
||||
# NOTE: This is an experiment to replace `/models`
|
||||
@routes.get("/experiment/models")
|
||||
@ -94,28 +131,61 @@ class ModelFileManager:
|
||||
token = json_data.get("token")
|
||||
headers = {"Authorization": f"Bearer {token}"} if token else {}
|
||||
|
||||
key = self._download_progress_key(save_dir, filename)
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
with requests.get(url, headers=headers, stream=True, timeout=60) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get("content-length", 0))
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
with open(tmp_path, "wb") as f:
|
||||
with tqdm(total=total_size, unit="iB", unit_scale=True, desc=filename) as pbar:
|
||||
for chunk in r.iter_content(chunk_size=1024 * 1024):
|
||||
if not chunk:
|
||||
break
|
||||
size = f.write(chunk)
|
||||
pbar.update(size)
|
||||
os.replace(tmp_path, save_path)
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
self._download_file_sync,
|
||||
url,
|
||||
headers,
|
||||
tmp_path,
|
||||
save_path,
|
||||
key,
|
||||
)
|
||||
self._set_download_progress(key, status="completed")
|
||||
logging.info("Downloaded model to %s", save_path)
|
||||
return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename})
|
||||
except Exception as e:
|
||||
logging.error("Failed to download model: %s", e)
|
||||
self._set_download_progress(key, status="failed", error=str(e))
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
|
||||
|
||||
@routes.get("/download_model/progress")
|
||||
async def get_download_model_progress(request):
|
||||
save_dir = request.rel_url.query.get("save_dir")
|
||||
filename = request.rel_url.query.get("filename")
|
||||
if not save_dir or not filename:
|
||||
return web.json_response({"error": "save_dir and filename required"}, status=400)
|
||||
key = self._download_progress_key(save_dir, filename)
|
||||
with self._download_lock:
|
||||
entry = dict(self._download_progress.get(key, {}))
|
||||
if not entry and save_dir in folder_paths.folder_names_and_paths:
|
||||
root = folder_paths.folder_names_and_paths[save_dir][0][0]
|
||||
tmp_path = os.path.join(root, filename + ".tmp")
|
||||
final_path = os.path.join(root, filename)
|
||||
if os.path.isfile(final_path):
|
||||
size = os.path.getsize(final_path)
|
||||
entry = {"status": "completed", "bytes_downloaded": size, "bytes_total": size}
|
||||
elif os.path.isfile(tmp_path):
|
||||
size = os.path.getsize(tmp_path)
|
||||
entry = {"status": "running", "bytes_downloaded": size, "bytes_total": 0}
|
||||
bytes_total = int(entry.get("bytes_total") or 0)
|
||||
bytes_downloaded = int(entry.get("bytes_downloaded") or 0)
|
||||
progress = (bytes_downloaded / bytes_total) if bytes_total > 0 else 0
|
||||
return web.json_response({
|
||||
"status": entry.get("status", "unknown"),
|
||||
"bytes_downloaded": bytes_downloaded,
|
||||
"bytes_total": bytes_total,
|
||||
"progress": progress,
|
||||
"error": entry.get("error"),
|
||||
})
|
||||
|
||||
|
||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||
async def get_model_preview(request):
|
||||
folder_name = request.match_info.get("folder", None)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user