Track model download progress for remote UI polling

This commit is contained in:
djmango 2026-06-14 09:55:10 -07:00
parent 8c3211fdf1
commit bbfb237477

View File

@ -3,7 +3,9 @@ import base64
import json
import time
import logging
import asyncio
import requests
from threading import Lock
from tqdm.auto import tqdm
from urllib.parse import unquote, urlparse
from typing import Callable
@ -20,6 +22,8 @@ class ModelFileManager:
def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
self.is_download_model_enabled = is_download_model_enabled
self._download_progress: dict[str, dict] = {}
self._download_lock = Lock()
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
@ -30,6 +34,39 @@ class ModelFileManager:
def clear_cache(self):
self.cache.clear()
def _download_progress_key(self, save_dir: str, filename: str) -> str:
return f"{save_dir}/{filename}"
def _set_download_progress(self, key: str, **fields) -> None:
with self._download_lock:
entry = dict(self._download_progress.get(key, {}))
entry.update(fields)
self._download_progress[key] = entry
def _download_file_sync(self, url: str, headers: dict, tmp_path: str, save_path: str, key: str) -> None:
with requests.get(url, headers=headers, stream=True, timeout=(30, 3600)) as r:
r.raise_for_status()
total_size = int(r.headers.get("content-length", 0) or 0)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
self._set_download_progress(
key,
status="running",
bytes_downloaded=0,
bytes_total=total_size,
filename=os.path.basename(save_path),
)
with open(tmp_path, "wb") as f:
downloaded = 0
for chunk in r.iter_content(chunk_size=1024 * 1024):
if not chunk:
continue
f.write(chunk)
downloaded += len(chunk)
self._set_download_progress(key, bytes_downloaded=downloaded, bytes_total=total_size or downloaded)
os.replace(tmp_path, save_path)
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
@ -94,28 +131,61 @@ class ModelFileManager:
token = json_data.get("token")
headers = {"Authorization": f"Bearer {token}"} if token else {}
key = self._download_progress_key(save_dir, filename)
loop = asyncio.get_running_loop()
try:
with requests.get(url, headers=headers, stream=True, timeout=60) as r:
r.raise_for_status()
total_size = int(r.headers.get("content-length", 0))
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(tmp_path, "wb") as f:
with tqdm(total=total_size, unit="iB", unit_scale=True, desc=filename) as pbar:
for chunk in r.iter_content(chunk_size=1024 * 1024):
if not chunk:
break
size = f.write(chunk)
pbar.update(size)
os.replace(tmp_path, save_path)
await loop.run_in_executor(
None,
self._download_file_sync,
url,
headers,
tmp_path,
save_path,
key,
)
self._set_download_progress(key, status="completed")
logging.info("Downloaded model to %s", save_path)
return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename})
except Exception as e:
logging.error("Failed to download model: %s", e)
self._set_download_progress(key, status="failed", error=str(e))
if os.path.exists(tmp_path):
os.remove(tmp_path)
return web.json_response({"error": str(e)}, status=500)
@routes.get("/download_model/progress")
async def get_download_model_progress(request):
save_dir = request.rel_url.query.get("save_dir")
filename = request.rel_url.query.get("filename")
if not save_dir or not filename:
return web.json_response({"error": "save_dir and filename required"}, status=400)
key = self._download_progress_key(save_dir, filename)
with self._download_lock:
entry = dict(self._download_progress.get(key, {}))
if not entry and save_dir in folder_paths.folder_names_and_paths:
root = folder_paths.folder_names_and_paths[save_dir][0][0]
tmp_path = os.path.join(root, filename + ".tmp")
final_path = os.path.join(root, filename)
if os.path.isfile(final_path):
size = os.path.getsize(final_path)
entry = {"status": "completed", "bytes_downloaded": size, "bytes_total": size}
elif os.path.isfile(tmp_path):
size = os.path.getsize(tmp_path)
entry = {"status": "running", "bytes_downloaded": size, "bytes_total": 0}
bytes_total = int(entry.get("bytes_total") or 0)
bytes_downloaded = int(entry.get("bytes_downloaded") or 0)
progress = (bytes_downloaded / bytes_total) if bytes_total > 0 else 0
return web.json_response({
"status": entry.get("status", "unknown"),
"bytes_downloaded": bytes_downloaded,
"bytes_total": bytes_total,
"progress": progress,
"error": entry.get("error"),
})
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)