diff --git a/app/model_manager.py b/app/model_manager.py index 8f6e34b33..074b59213 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -3,6 +3,10 @@ import base64 import json import time import logging +import requests +from tqdm.auto import tqdm +from urllib.parse import unquote, urlparse +from typing import Callable import folder_paths import glob import comfy.utils @@ -13,8 +17,9 @@ from folder_paths import map_legacy, filter_files_extensions, filter_files_conte class ModelFileManager: - def __init__(self) -> None: + def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} + self.is_download_model_enabled = is_download_model_enabled def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: return self.cache.get(key, default) @@ -47,6 +52,70 @@ class ModelFileManager: files = self.get_model_file_list(folder) return web.json_response(files) + + @routes.post("/download_model") + async def post_download_model(request): + if not self.is_download_model_enabled(): + logging.error("Download Model endpoint is disabled") + return web.Response(status=403) + + json_data = await request.json() + url = json_data.get("url") + if not url: + return web.json_response({"error": "url required"}, status=400) + + save_dir = json_data.get("save_dir") + if save_dir not in folder_paths.folder_names_and_paths: + return web.json_response({"error": "invalid save_dir"}, status=400) + + default_filename = unquote(urlparse(url).path.split("/")[-1].split("?")[0]) + filename = json_data.get("filename") or default_filename + if not filename or filename in (".", "..") or "/" in filename or "\\" in filename: + return web.json_response({"error": "invalid filename"}, status=400) + + allowed_sources = ( + "https://civitai.com/", + "https://civitai.red/", + "https://huggingface.co/", + "https://github.com/", + "http://localhost:", + ) + if not any(url.startswith(src) for src in allowed_sources): + return web.json_response({"error": "url not allowed"}, status=400) + + save_root = folder_paths.folder_names_and_paths[save_dir][0][0] + save_path = os.path.join(save_root, filename) + save_real = os.path.realpath(save_path) + root_real = os.path.realpath(save_root) + if not save_real.startswith(root_real + os.sep) and save_real != root_real: + return web.json_response({"error": "invalid path"}, status=400) + + tmp_path = save_path + ".tmp" + token = json_data.get("token") + headers = {"Authorization": f"Bearer {token}"} if token else {} + + try: + with requests.get(url, headers=headers, stream=True, timeout=60) as r: + r.raise_for_status() + total_size = int(r.headers.get("content-length", 0)) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + with open(tmp_path, "wb") as f: + with tqdm(total=total_size, unit="iB", unit_scale=True, desc=filename) as pbar: + for chunk in r.iter_content(chunk_size=1024 * 1024): + if not chunk: + break + size = f.write(chunk) + pbar.update(size) + os.replace(tmp_path, save_path) + logging.info("Downloaded model to %s", save_path) + return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename}) + except Exception as e: + logging.error("Failed to download model: %s", e) + if os.path.exists(tmp_path): + os.remove(tmp_path) + return web.json_response({"error": str(e)}, status=500) + + @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}") async def get_model_preview(request): folder_name = request.match_info.get("folder", None) diff --git a/server.py b/server.py index 6b0029adf..db8b9ac71 100644 --- a/server.py +++ b/server.py @@ -206,7 +206,21 @@ class PromptServer(): PromptServer.instance = self self.user_manager = UserManager() - self.model_file_manager = ModelFileManager() + def _is_model_download_enabled(): + settings_path = os.path.join(folder_paths.get_user_directory(), "default", "comfy.settings.json") + try: + if os.path.isfile(settings_path): + with open(settings_path) as f: + settings = json.load(f) + return settings.get("Comfy.ModelDownloadEnabled", True) + except Exception: + pass + return True + + self.model_file_manager = ModelFileManager(is_download_model_enabled=_is_model_download_enabled) + if hasattr(self.user_manager.settings, "get_settings") + else True + ) self.custom_node_manager = CustomNodeManager() self.subgraph_manager = SubgraphManager() self.node_replace_manager = NodeReplaceManager()