mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 16:59:29 +08:00
Track model download progress for remote UI polling
This commit is contained in:
parent
8c3211fdf1
commit
bbfb237477
@ -3,7 +3,9 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
import asyncio
|
||||||
import requests
|
import requests
|
||||||
|
from threading import Lock
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from urllib.parse import unquote, urlparse
|
from urllib.parse import unquote, urlparse
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
@ -20,6 +22,8 @@ class ModelFileManager:
|
|||||||
def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None:
|
def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None:
|
||||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||||
self.is_download_model_enabled = is_download_model_enabled
|
self.is_download_model_enabled = is_download_model_enabled
|
||||||
|
self._download_progress: dict[str, dict] = {}
|
||||||
|
self._download_lock = Lock()
|
||||||
|
|
||||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||||
return self.cache.get(key, default)
|
return self.cache.get(key, default)
|
||||||
@ -30,6 +34,39 @@ class ModelFileManager:
|
|||||||
def clear_cache(self):
|
def clear_cache(self):
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _download_progress_key(self, save_dir: str, filename: str) -> str:
|
||||||
|
return f"{save_dir}/{filename}"
|
||||||
|
|
||||||
|
def _set_download_progress(self, key: str, **fields) -> None:
|
||||||
|
with self._download_lock:
|
||||||
|
entry = dict(self._download_progress.get(key, {}))
|
||||||
|
entry.update(fields)
|
||||||
|
self._download_progress[key] = entry
|
||||||
|
|
||||||
|
def _download_file_sync(self, url: str, headers: dict, tmp_path: str, save_path: str, key: str) -> None:
|
||||||
|
with requests.get(url, headers=headers, stream=True, timeout=(30, 3600)) as r:
|
||||||
|
r.raise_for_status()
|
||||||
|
total_size = int(r.headers.get("content-length", 0) or 0)
|
||||||
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
|
self._set_download_progress(
|
||||||
|
key,
|
||||||
|
status="running",
|
||||||
|
bytes_downloaded=0,
|
||||||
|
bytes_total=total_size,
|
||||||
|
filename=os.path.basename(save_path),
|
||||||
|
)
|
||||||
|
with open(tmp_path, "wb") as f:
|
||||||
|
downloaded = 0
|
||||||
|
for chunk in r.iter_content(chunk_size=1024 * 1024):
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
self._set_download_progress(key, bytes_downloaded=downloaded, bytes_total=total_size or downloaded)
|
||||||
|
os.replace(tmp_path, save_path)
|
||||||
|
|
||||||
|
|
||||||
def add_routes(self, routes):
|
def add_routes(self, routes):
|
||||||
# NOTE: This is an experiment to replace `/models`
|
# NOTE: This is an experiment to replace `/models`
|
||||||
@routes.get("/experiment/models")
|
@routes.get("/experiment/models")
|
||||||
@ -94,28 +131,61 @@ class ModelFileManager:
|
|||||||
token = json_data.get("token")
|
token = json_data.get("token")
|
||||||
headers = {"Authorization": f"Bearer {token}"} if token else {}
|
headers = {"Authorization": f"Bearer {token}"} if token else {}
|
||||||
|
|
||||||
|
key = self._download_progress_key(save_dir, filename)
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
with requests.get(url, headers=headers, stream=True, timeout=60) as r:
|
await loop.run_in_executor(
|
||||||
r.raise_for_status()
|
None,
|
||||||
total_size = int(r.headers.get("content-length", 0))
|
self._download_file_sync,
|
||||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
url,
|
||||||
with open(tmp_path, "wb") as f:
|
headers,
|
||||||
with tqdm(total=total_size, unit="iB", unit_scale=True, desc=filename) as pbar:
|
tmp_path,
|
||||||
for chunk in r.iter_content(chunk_size=1024 * 1024):
|
save_path,
|
||||||
if not chunk:
|
key,
|
||||||
break
|
)
|
||||||
size = f.write(chunk)
|
self._set_download_progress(key, status="completed")
|
||||||
pbar.update(size)
|
|
||||||
os.replace(tmp_path, save_path)
|
|
||||||
logging.info("Downloaded model to %s", save_path)
|
logging.info("Downloaded model to %s", save_path)
|
||||||
return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename})
|
return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Failed to download model: %s", e)
|
logging.error("Failed to download model: %s", e)
|
||||||
|
self._set_download_progress(key, status="failed", error=str(e))
|
||||||
if os.path.exists(tmp_path):
|
if os.path.exists(tmp_path):
|
||||||
os.remove(tmp_path)
|
os.remove(tmp_path)
|
||||||
return web.json_response({"error": str(e)}, status=500)
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@routes.get("/download_model/progress")
|
||||||
|
async def get_download_model_progress(request):
|
||||||
|
save_dir = request.rel_url.query.get("save_dir")
|
||||||
|
filename = request.rel_url.query.get("filename")
|
||||||
|
if not save_dir or not filename:
|
||||||
|
return web.json_response({"error": "save_dir and filename required"}, status=400)
|
||||||
|
key = self._download_progress_key(save_dir, filename)
|
||||||
|
with self._download_lock:
|
||||||
|
entry = dict(self._download_progress.get(key, {}))
|
||||||
|
if not entry and save_dir in folder_paths.folder_names_and_paths:
|
||||||
|
root = folder_paths.folder_names_and_paths[save_dir][0][0]
|
||||||
|
tmp_path = os.path.join(root, filename + ".tmp")
|
||||||
|
final_path = os.path.join(root, filename)
|
||||||
|
if os.path.isfile(final_path):
|
||||||
|
size = os.path.getsize(final_path)
|
||||||
|
entry = {"status": "completed", "bytes_downloaded": size, "bytes_total": size}
|
||||||
|
elif os.path.isfile(tmp_path):
|
||||||
|
size = os.path.getsize(tmp_path)
|
||||||
|
entry = {"status": "running", "bytes_downloaded": size, "bytes_total": 0}
|
||||||
|
bytes_total = int(entry.get("bytes_total") or 0)
|
||||||
|
bytes_downloaded = int(entry.get("bytes_downloaded") or 0)
|
||||||
|
progress = (bytes_downloaded / bytes_total) if bytes_total > 0 else 0
|
||||||
|
return web.json_response({
|
||||||
|
"status": entry.get("status", "unknown"),
|
||||||
|
"bytes_downloaded": bytes_downloaded,
|
||||||
|
"bytes_total": bytes_total,
|
||||||
|
"progress": progress,
|
||||||
|
"error": entry.get("error"),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||||
async def get_model_preview(request):
|
async def get_model_preview(request):
|
||||||
folder_name = request.match_info.get("folder", None)
|
folder_name = request.match_info.get("folder", None)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user