mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-25 01:09:24 +08:00
Add server-side model download for remote web UI.
Expose POST /download_model so browser clients fetch models onto the host instead of the user's laptop. Enabled by default via Comfy.ModelDownloadEnabled.
This commit is contained in:
parent
e1b9366898
commit
e95e06ac1f
@ -3,6 +3,10 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
import requests
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
|
from typing import Callable
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import glob
|
import glob
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -13,8 +17,9 @@ from folder_paths import map_legacy, filter_files_extensions, filter_files_conte
|
|||||||
|
|
||||||
|
|
||||||
class ModelFileManager:
|
class ModelFileManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None:
|
||||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||||
|
self.is_download_model_enabled = is_download_model_enabled
|
||||||
|
|
||||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||||
return self.cache.get(key, default)
|
return self.cache.get(key, default)
|
||||||
@ -47,6 +52,70 @@ class ModelFileManager:
|
|||||||
files = self.get_model_file_list(folder)
|
files = self.get_model_file_list(folder)
|
||||||
return web.json_response(files)
|
return web.json_response(files)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.post("/download_model")
|
||||||
|
async def post_download_model(request):
|
||||||
|
if not self.is_download_model_enabled():
|
||||||
|
logging.error("Download Model endpoint is disabled")
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
json_data = await request.json()
|
||||||
|
url = json_data.get("url")
|
||||||
|
if not url:
|
||||||
|
return web.json_response({"error": "url required"}, status=400)
|
||||||
|
|
||||||
|
save_dir = json_data.get("save_dir")
|
||||||
|
if save_dir not in folder_paths.folder_names_and_paths:
|
||||||
|
return web.json_response({"error": "invalid save_dir"}, status=400)
|
||||||
|
|
||||||
|
default_filename = unquote(urlparse(url).path.split("/")[-1].split("?")[0])
|
||||||
|
filename = json_data.get("filename") or default_filename
|
||||||
|
if not filename or filename in (".", "..") or "/" in filename or "\\" in filename:
|
||||||
|
return web.json_response({"error": "invalid filename"}, status=400)
|
||||||
|
|
||||||
|
allowed_sources = (
|
||||||
|
"https://civitai.com/",
|
||||||
|
"https://civitai.red/",
|
||||||
|
"https://huggingface.co/",
|
||||||
|
"https://github.com/",
|
||||||
|
"http://localhost:",
|
||||||
|
)
|
||||||
|
if not any(url.startswith(src) for src in allowed_sources):
|
||||||
|
return web.json_response({"error": "url not allowed"}, status=400)
|
||||||
|
|
||||||
|
save_root = folder_paths.folder_names_and_paths[save_dir][0][0]
|
||||||
|
save_path = os.path.join(save_root, filename)
|
||||||
|
save_real = os.path.realpath(save_path)
|
||||||
|
root_real = os.path.realpath(save_root)
|
||||||
|
if not save_real.startswith(root_real + os.sep) and save_real != root_real:
|
||||||
|
return web.json_response({"error": "invalid path"}, status=400)
|
||||||
|
|
||||||
|
tmp_path = save_path + ".tmp"
|
||||||
|
token = json_data.get("token")
|
||||||
|
headers = {"Authorization": f"Bearer {token}"} if token else {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with requests.get(url, headers=headers, stream=True, timeout=60) as r:
|
||||||
|
r.raise_for_status()
|
||||||
|
total_size = int(r.headers.get("content-length", 0))
|
||||||
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
|
with open(tmp_path, "wb") as f:
|
||||||
|
with tqdm(total=total_size, unit="iB", unit_scale=True, desc=filename) as pbar:
|
||||||
|
for chunk in r.iter_content(chunk_size=1024 * 1024):
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
size = f.write(chunk)
|
||||||
|
pbar.update(size)
|
||||||
|
os.replace(tmp_path, save_path)
|
||||||
|
logging.info("Downloaded model to %s", save_path)
|
||||||
|
return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename})
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Failed to download model: %s", e)
|
||||||
|
if os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||||
async def get_model_preview(request):
|
async def get_model_preview(request):
|
||||||
folder_name = request.match_info.get("folder", None)
|
folder_name = request.match_info.get("folder", None)
|
||||||
|
|||||||
16
server.py
16
server.py
@ -206,7 +206,21 @@ class PromptServer():
|
|||||||
PromptServer.instance = self
|
PromptServer.instance = self
|
||||||
|
|
||||||
self.user_manager = UserManager()
|
self.user_manager = UserManager()
|
||||||
self.model_file_manager = ModelFileManager()
|
def _is_model_download_enabled():
|
||||||
|
settings_path = os.path.join(folder_paths.get_user_directory(), "default", "comfy.settings.json")
|
||||||
|
try:
|
||||||
|
if os.path.isfile(settings_path):
|
||||||
|
with open(settings_path) as f:
|
||||||
|
settings = json.load(f)
|
||||||
|
return settings.get("Comfy.ModelDownloadEnabled", True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.model_file_manager = ModelFileManager(is_download_model_enabled=_is_model_download_enabled)
|
||||||
|
if hasattr(self.user_manager.settings, "get_settings")
|
||||||
|
else True
|
||||||
|
)
|
||||||
self.custom_node_manager = CustomNodeManager()
|
self.custom_node_manager = CustomNodeManager()
|
||||||
self.subgraph_manager = SubgraphManager()
|
self.subgraph_manager = SubgraphManager()
|
||||||
self.node_replace_manager = NodeReplaceManager()
|
self.node_replace_manager = NodeReplaceManager()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user