Track model download progress for remote UI polling

This commit is contained in:
djmango 2026-06-14 09:55:10 -07:00
parent 8c3211fdf1
commit bbfb237477

View File

@ -3,7 +3,9 @@ import base64
import json import json
import time import time
import logging import logging
import asyncio
import requests import requests
from threading import Lock
from tqdm.auto import tqdm from tqdm.auto import tqdm
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
from typing import Callable from typing import Callable
@ -20,6 +22,8 @@ class ModelFileManager:
def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None: def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
self.is_download_model_enabled = is_download_model_enabled self.is_download_model_enabled = is_download_model_enabled
self._download_progress: dict[str, dict] = {}
self._download_lock = Lock()
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default) return self.cache.get(key, default)
@ -30,6 +34,39 @@ class ModelFileManager:
def clear_cache(self): def clear_cache(self):
self.cache.clear() self.cache.clear()
def _download_progress_key(self, save_dir: str, filename: str) -> str:
return f"{save_dir}/{filename}"
def _set_download_progress(self, key: str, **fields) -> None:
with self._download_lock:
entry = dict(self._download_progress.get(key, {}))
entry.update(fields)
self._download_progress[key] = entry
def _download_file_sync(self, url: str, headers: dict, tmp_path: str, save_path: str, key: str) -> None:
with requests.get(url, headers=headers, stream=True, timeout=(30, 3600)) as r:
r.raise_for_status()
total_size = int(r.headers.get("content-length", 0) or 0)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
self._set_download_progress(
key,
status="running",
bytes_downloaded=0,
bytes_total=total_size,
filename=os.path.basename(save_path),
)
with open(tmp_path, "wb") as f:
downloaded = 0
for chunk in r.iter_content(chunk_size=1024 * 1024):
if not chunk:
continue
f.write(chunk)
downloaded += len(chunk)
self._set_download_progress(key, bytes_downloaded=downloaded, bytes_total=total_size or downloaded)
os.replace(tmp_path, save_path)
def add_routes(self, routes): def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models` # NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models") @routes.get("/experiment/models")
@ -94,28 +131,61 @@ class ModelFileManager:
token = json_data.get("token") token = json_data.get("token")
headers = {"Authorization": f"Bearer {token}"} if token else {} headers = {"Authorization": f"Bearer {token}"} if token else {}
key = self._download_progress_key(save_dir, filename)
loop = asyncio.get_running_loop()
try: try:
with requests.get(url, headers=headers, stream=True, timeout=60) as r: await loop.run_in_executor(
r.raise_for_status() None,
total_size = int(r.headers.get("content-length", 0)) self._download_file_sync,
os.makedirs(os.path.dirname(save_path), exist_ok=True) url,
with open(tmp_path, "wb") as f: headers,
with tqdm(total=total_size, unit="iB", unit_scale=True, desc=filename) as pbar: tmp_path,
for chunk in r.iter_content(chunk_size=1024 * 1024): save_path,
if not chunk: key,
break )
size = f.write(chunk) self._set_download_progress(key, status="completed")
pbar.update(size)
os.replace(tmp_path, save_path)
logging.info("Downloaded model to %s", save_path) logging.info("Downloaded model to %s", save_path)
return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename}) return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename})
except Exception as e: except Exception as e:
logging.error("Failed to download model: %s", e) logging.error("Failed to download model: %s", e)
self._set_download_progress(key, status="failed", error=str(e))
if os.path.exists(tmp_path): if os.path.exists(tmp_path):
os.remove(tmp_path) os.remove(tmp_path)
return web.json_response({"error": str(e)}, status=500) return web.json_response({"error": str(e)}, status=500)
@routes.get("/download_model/progress")
async def get_download_model_progress(request):
save_dir = request.rel_url.query.get("save_dir")
filename = request.rel_url.query.get("filename")
if not save_dir or not filename:
return web.json_response({"error": "save_dir and filename required"}, status=400)
key = self._download_progress_key(save_dir, filename)
with self._download_lock:
entry = dict(self._download_progress.get(key, {}))
if not entry and save_dir in folder_paths.folder_names_and_paths:
root = folder_paths.folder_names_and_paths[save_dir][0][0]
tmp_path = os.path.join(root, filename + ".tmp")
final_path = os.path.join(root, filename)
if os.path.isfile(final_path):
size = os.path.getsize(final_path)
entry = {"status": "completed", "bytes_downloaded": size, "bytes_total": size}
elif os.path.isfile(tmp_path):
size = os.path.getsize(tmp_path)
entry = {"status": "running", "bytes_downloaded": size, "bytes_total": 0}
bytes_total = int(entry.get("bytes_total") or 0)
bytes_downloaded = int(entry.get("bytes_downloaded") or 0)
progress = (bytes_downloaded / bytes_total) if bytes_total > 0 else 0
return web.json_response({
"status": entry.get("status", "unknown"),
"bytes_downloaded": bytes_downloaded,
"bytes_total": bytes_total,
"progress": progress,
"error": entry.get("error"),
})
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}") @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request): async def get_model_preview(request):
folder_name = request.match_info.get("folder", None) folder_name = request.match_info.get("folder", None)