diff --git a/app/model_manager.py b/app/model_manager.py index 074b59213..084112cea 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -3,7 +3,9 @@ import base64 import json import time import logging +import asyncio import requests +from threading import Lock from tqdm.auto import tqdm from urllib.parse import unquote, urlparse from typing import Callable @@ -20,6 +22,8 @@ class ModelFileManager: def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} self.is_download_model_enabled = is_download_model_enabled + self._download_progress: dict[str, dict] = {} + self._download_lock = Lock() def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: return self.cache.get(key, default) @@ -30,6 +34,39 @@ class ModelFileManager: def clear_cache(self): self.cache.clear() + + def _download_progress_key(self, save_dir: str, filename: str) -> str: + return f"{save_dir}/{filename}" + + def _set_download_progress(self, key: str, **fields) -> None: + with self._download_lock: + entry = dict(self._download_progress.get(key, {})) + entry.update(fields) + self._download_progress[key] = entry + + def _download_file_sync(self, url: str, headers: dict, tmp_path: str, save_path: str, key: str) -> None: + with requests.get(url, headers=headers, stream=True, timeout=(30, 3600)) as r: + r.raise_for_status() + total_size = int(r.headers.get("content-length", 0) or 0) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + self._set_download_progress( + key, + status="running", + bytes_downloaded=0, + bytes_total=total_size, + filename=os.path.basename(save_path), + ) + with open(tmp_path, "wb") as f: + downloaded = 0 + for chunk in r.iter_content(chunk_size=1024 * 1024): + if not chunk: + continue + f.write(chunk) + downloaded += len(chunk) + self._set_download_progress(key, bytes_downloaded=downloaded, bytes_total=total_size or downloaded) + os.replace(tmp_path, save_path) + + def add_routes(self, routes): # NOTE: This is an experiment to replace `/models` @routes.get("/experiment/models") @@ -94,28 +131,61 @@ class ModelFileManager: token = json_data.get("token") headers = {"Authorization": f"Bearer {token}"} if token else {} + key = self._download_progress_key(save_dir, filename) + loop = asyncio.get_running_loop() try: - with requests.get(url, headers=headers, stream=True, timeout=60) as r: - r.raise_for_status() - total_size = int(r.headers.get("content-length", 0)) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - with open(tmp_path, "wb") as f: - with tqdm(total=total_size, unit="iB", unit_scale=True, desc=filename) as pbar: - for chunk in r.iter_content(chunk_size=1024 * 1024): - if not chunk: - break - size = f.write(chunk) - pbar.update(size) - os.replace(tmp_path, save_path) + await loop.run_in_executor( + None, + self._download_file_sync, + url, + headers, + tmp_path, + save_path, + key, + ) + self._set_download_progress(key, status="completed") logging.info("Downloaded model to %s", save_path) return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename}) except Exception as e: logging.error("Failed to download model: %s", e) + self._set_download_progress(key, status="failed", error=str(e)) if os.path.exists(tmp_path): os.remove(tmp_path) return web.json_response({"error": str(e)}, status=500) + + @routes.get("/download_model/progress") + async def get_download_model_progress(request): + save_dir = request.rel_url.query.get("save_dir") + filename = request.rel_url.query.get("filename") + if not save_dir or not filename: + return web.json_response({"error": "save_dir and filename required"}, status=400) + key = self._download_progress_key(save_dir, filename) + with self._download_lock: + entry = dict(self._download_progress.get(key, {})) + if not entry and save_dir in folder_paths.folder_names_and_paths: + root = folder_paths.folder_names_and_paths[save_dir][0][0] + tmp_path = os.path.join(root, filename + ".tmp") + final_path = os.path.join(root, filename) + if os.path.isfile(final_path): + size = os.path.getsize(final_path) + entry = {"status": "completed", "bytes_downloaded": size, "bytes_total": size} + elif os.path.isfile(tmp_path): + size = os.path.getsize(tmp_path) + entry = {"status": "running", "bytes_downloaded": size, "bytes_total": 0} + bytes_total = int(entry.get("bytes_total") or 0) + bytes_downloaded = int(entry.get("bytes_downloaded") or 0) + progress = (bytes_downloaded / bytes_total) if bytes_total > 0 else 0 + return web.json_response({ + "status": entry.get("status", "unknown"), + "bytes_downloaded": bytes_downloaded, + "bytes_total": bytes_total, + "progress": progress, + "error": entry.get("error"), + }) + + @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}") async def get_model_preview(request): folder_name = request.match_info.get("folder", None)