diff --git a/app/model_manager.py b/app/model_manager.py index f124d1117..09a80f58b 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -5,8 +5,10 @@ import base64 import json import time import logging +import requests import folder_paths import glob +from tqdm.auto import tqdm import comfy.utils from aiohttp import web from PIL import Image @@ -15,8 +17,9 @@ from folder_paths import map_legacy, filter_files_extensions, filter_files_conte class ModelFileManager: - def __init__(self) -> None: + def __init__(self, is_download_model_enabled: lambda: bool= lambda: False) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} + self.is_download_model_enabled = is_download_model_enabled def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: return self.cache.get(key, default) @@ -76,6 +79,45 @@ class ModelFileManager: except: return web.Response(status=404) + @routes.post("/download_model") + async def post_download_model(request): + if not self.is_download_model_enabled(): + logging.error("Download Model endpoint is disabled") + return web.Response(status=403) + json_data = await request.json() + url = json_data.get("url", None) + if url is None: + logging.error("URL is not provided") + return web.Response(status=401) + save_dir = json_data.get("save_dir", None) + if save_dir not in folder_paths.folder_names_and_paths: + logging.error("Save directory is not valid") + return web.Response(status=401) + filename = json_data.get("filename", url.split("/")[-1]) + token = json_data.get("token", None) + + save_path = os.path.join(folder_paths.folder_names_and_paths[save_dir][0][0], filename) + tmp_path = save_path + ".tmp" + headers = {"Authorization": f"Bearer {token}"} if token else {} + try: + with requests.get(url, headers=headers,stream=True,timeout=10) as r: + r.raise_for_status() + total_size = int(r.headers.get('content-length', 0)) + with open(tmp_path, "wb") as f: + with tqdm(total=total_size, unit='iB', unit_scale=True, desc=filename) as pbar: + for chunk in r.iter_content(chunk_size=1024*1024): + if not chunk: + break + size = f.write(chunk) + pbar.update(size) + os.rename(tmp_path, save_path) + return web.Response(status=200) + except Exception as e: + logging.error(f"Failed to download model: {e}") + if os.path.exists(tmp_path): + os.remove(tmp_path) + return web.Response(status=500) + def get_model_file_list(self, folder_name: str): folder_name = map_legacy(folder_name) folders = folder_paths.folder_names_and_paths[folder_name] diff --git a/server.py b/server.py index 2300393b2..932e79e68 100644 --- a/server.py +++ b/server.py @@ -201,7 +201,7 @@ class PromptServer(): mimetypes.add_type('image/webp', '.webp') self.user_manager = UserManager() - self.model_file_manager = ModelFileManager() + self.model_file_manager = ModelFileManager(is_download_model_enabled=lambda: self.user_manager.settings.get_settings(None).get("Comfy.ModelDownloadEnabled", False)) self.custom_node_manager = CustomNodeManager() self.subgraph_manager = SubgraphManager() self.internal_routes = InternalRoutes(self)