mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 03:50:50 +08:00
add upload asset endpoint
This commit is contained in:
parent
09dabf95bc
commit
a763cbd39d
@ -97,3 +97,36 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
|||||||
|
|
||||||
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
|
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
|
||||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||||
|
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||||
|
root = tags[0]
|
||||||
|
if root == "models":
|
||||||
|
if len(tags) < 2:
|
||||||
|
raise ValueError("at least two tags required for model asset")
|
||||||
|
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||||
|
if not bases:
|
||||||
|
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||||
|
base_dir = os.path.abspath(bases[0])
|
||||||
|
raw_subdirs = tags[2:]
|
||||||
|
else:
|
||||||
|
base_dir = os.path.abspath(
|
||||||
|
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||||
|
)
|
||||||
|
raw_subdirs = tags[1:]
|
||||||
|
for i in raw_subdirs:
|
||||||
|
if i in (".", ".."):
|
||||||
|
raise ValueError("invalid path component in tags")
|
||||||
|
|
||||||
|
return base_dir, raw_subdirs if raw_subdirs else []
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_within_base(candidate: str, base: str) -> None:
|
||||||
|
cand_abs = os.path.abspath(candidate)
|
||||||
|
base_abs = os.path.abspath(base)
|
||||||
|
try:
|
||||||
|
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||||
|
raise ValueError("destination escapes base directory")
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("invalid destination path")
|
||||||
|
|||||||
@ -1,9 +1,13 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
from .. import assets_manager, assets_scanner
|
from .. import assets_manager, assets_scanner
|
||||||
from . import schemas_in
|
from . import schemas_in
|
||||||
|
|
||||||
@ -42,7 +46,6 @@ async def list_assets(request: web.Request) -> web.Response:
|
|||||||
return web.json_response(payload.model_dump(mode="json"))
|
return web.json_response(payload.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get("/api/assets/{id}/content")
|
@ROUTES.get("/api/assets/{id}/content")
|
||||||
async def download_asset_content(request: web.Request) -> web.Response:
|
async def download_asset_content(request: web.Request) -> web.Response:
|
||||||
asset_info_id_raw = request.match_info.get("id")
|
asset_info_id_raw = request.match_info.get("id")
|
||||||
@ -75,6 +78,118 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets/from-hash")
|
||||||
|
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
result = await assets_manager.create_asset_from_hash(
|
||||||
|
hash_str=body.hash,
|
||||||
|
name=body.name,
|
||||||
|
tags=body.tags,
|
||||||
|
user_metadata=body.user_metadata,
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets")
|
||||||
|
async def upload_asset(request: web.Request) -> web.Response:
|
||||||
|
"""Multipart/form-data endpoint for Asset uploads."""
|
||||||
|
|
||||||
|
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||||
|
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||||
|
|
||||||
|
reader = await request.multipart()
|
||||||
|
|
||||||
|
file_field = None
|
||||||
|
file_client_name: Optional[str] = None
|
||||||
|
tags_raw: list[str] = []
|
||||||
|
provided_name: Optional[str] = None
|
||||||
|
user_metadata_raw: Optional[str] = None
|
||||||
|
file_written = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
field = await reader.next()
|
||||||
|
if field is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
fname = getattr(field, "name", None) or ""
|
||||||
|
if fname == "file":
|
||||||
|
# Save to temp
|
||||||
|
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||||
|
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||||
|
os.makedirs(unique_dir, exist_ok=True)
|
||||||
|
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||||
|
|
||||||
|
file_field = field
|
||||||
|
file_client_name = (field.filename or "").strip()
|
||||||
|
try:
|
||||||
|
with open(tmp_path, "wb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
f.write(chunk)
|
||||||
|
file_written += len(chunk)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
if os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||||
|
elif fname == "tags":
|
||||||
|
tags_raw.append((await field.text()) or "")
|
||||||
|
elif fname == "name":
|
||||||
|
provided_name = (await field.text()) or None
|
||||||
|
elif fname == "user_metadata":
|
||||||
|
user_metadata_raw = (await field.text()) or None
|
||||||
|
|
||||||
|
if file_field is None:
|
||||||
|
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part.")
|
||||||
|
|
||||||
|
if file_written == 0:
|
||||||
|
try:
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||||
|
"tags": tags_raw,
|
||||||
|
"name": provided_name,
|
||||||
|
"user_metadata": user_metadata_raw,
|
||||||
|
})
|
||||||
|
except ValidationError as ve:
|
||||||
|
try:
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
|
|
||||||
|
if spec.tags[0] == "models" and spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||||
|
return _error_response(400, "INVALID_BODY", f"unknown models category '{spec.tags[1]}'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
created = await assets_manager.upload_asset_from_temp_path(
|
||||||
|
spec,
|
||||||
|
temp_path=tmp_path,
|
||||||
|
client_filename=file_client_name,
|
||||||
|
)
|
||||||
|
return web.json_response(created.model_dump(mode="json"), status=201)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
if os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.put("/api/assets/{id}")
|
@ROUTES.put("/api/assets/{id}")
|
||||||
async def update_asset(request: web.Request) -> web.Response:
|
async def update_asset(request: web.Request) -> web.Response:
|
||||||
asset_info_id_raw = request.match_info.get("id")
|
asset_info_id_raw = request.match_info.get("id")
|
||||||
@ -104,27 +219,6 @@ async def update_asset(request: web.Request) -> web.Response:
|
|||||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.post("/api/assets/from-hash")
|
|
||||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
|
||||||
try:
|
|
||||||
payload = await request.json()
|
|
||||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
|
||||||
except ValidationError as ve:
|
|
||||||
return _validation_error_response("INVALID_BODY", ve)
|
|
||||||
except Exception:
|
|
||||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
|
||||||
|
|
||||||
result = await assets_manager.create_asset_from_hash(
|
|
||||||
hash_str=body.hash,
|
|
||||||
name=body.name,
|
|
||||||
tags=body.tags,
|
|
||||||
user_metadata=body.user_metadata,
|
|
||||||
)
|
|
||||||
if result is None:
|
|
||||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
|
||||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.delete("/api/assets/{id}")
|
@ROUTES.delete("/api/assets/{id}")
|
||||||
async def delete_asset(request: web.Request) -> web.Response:
|
async def delete_asset(request: web.Request) -> web.Response:
|
||||||
asset_info_id_raw = request.match_info.get("id")
|
asset_info_id_raw = request.match_info.get("id")
|
||||||
|
|||||||
@ -172,3 +172,93 @@ class ScheduleAssetScanBody(BaseModel):
|
|||||||
out.append(r)
|
out.append(r)
|
||||||
seen.add(r)
|
seen.add(r)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UploadAssetSpec(BaseModel):
|
||||||
|
"""Upload Asset operation.
|
||||||
|
- tags: ordered; first is root ('models'|'input'|'output');
|
||||||
|
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||||
|
- name: desired filename (optional); fallback will be the file hash
|
||||||
|
- user_metadata: arbitrary JSON object (optional)
|
||||||
|
"""
|
||||||
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
|
tags: list[str] = Field(..., min_length=1)
|
||||||
|
name: Optional[str] = Field(default=None, max_length=512)
|
||||||
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
@field_validator("tags", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_tags(cls, v):
|
||||||
|
"""
|
||||||
|
Accepts a list of strings (possibly multiple form fields),
|
||||||
|
where each string can be:
|
||||||
|
- JSON array (e.g., '["models","loras","foo"]')
|
||||||
|
- comma-separated ('models, loras, foo')
|
||||||
|
- single token ('models')
|
||||||
|
Returns a normalized, deduplicated, ordered list.
|
||||||
|
"""
|
||||||
|
items: list[str] = []
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, str):
|
||||||
|
v = [v]
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
for item in v:
|
||||||
|
if item is None:
|
||||||
|
continue
|
||||||
|
s = str(item).strip()
|
||||||
|
if not s:
|
||||||
|
continue
|
||||||
|
if s.startswith("["):
|
||||||
|
try:
|
||||||
|
arr = json.loads(s)
|
||||||
|
if isinstance(arr, list):
|
||||||
|
items.extend(str(x) for x in arr)
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass # fallback to CSV parse below
|
||||||
|
items.extend([p for p in s.split(",") if p.strip()])
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# normalize + dedupe
|
||||||
|
norm = []
|
||||||
|
seen = set()
|
||||||
|
for t in items:
|
||||||
|
tnorm = str(t).strip().lower()
|
||||||
|
if tnorm and tnorm not in seen:
|
||||||
|
seen.add(tnorm)
|
||||||
|
norm.append(tnorm)
|
||||||
|
return norm
|
||||||
|
|
||||||
|
@field_validator("user_metadata", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_metadata_json(cls, v):
|
||||||
|
if v is None or isinstance(v, dict):
|
||||||
|
return v or {}
|
||||||
|
if isinstance(v, str):
|
||||||
|
s = v.strip()
|
||||||
|
if not s:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(s)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise ValueError("user_metadata must be a JSON object")
|
||||||
|
return parsed
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _validate_order(self):
|
||||||
|
if not self.tags:
|
||||||
|
raise ValueError("tags must be provided and non-empty")
|
||||||
|
root = self.tags[0]
|
||||||
|
if root not in {"models", "input", "output"}:
|
||||||
|
raise ValueError("first tag must be one of: models, input, output")
|
||||||
|
if root == "models":
|
||||||
|
if len(self.tags) < 2:
|
||||||
|
raise ValueError("models uploads require a category tag as the second tag")
|
||||||
|
return self
|
||||||
|
|||||||
@ -25,8 +25,8 @@ from .database.services import (
|
|||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
create_asset_info_for_existing_asset,
|
create_asset_info_for_existing_asset,
|
||||||
)
|
)
|
||||||
from .api import schemas_out
|
from .api import schemas_in, schemas_out
|
||||||
from ._assets_helpers import get_name_and_tags_from_asset_path
|
from ._assets_helpers import get_name_and_tags_from_asset_path, resolve_destination_from_tags, ensure_within_base
|
||||||
|
|
||||||
|
|
||||||
async def asset_exists(*, asset_hash: str) -> bool:
|
async def asset_exists(*, asset_hash: str) -> bool:
|
||||||
@ -173,6 +173,97 @@ async def resolve_asset_content_for_download(
|
|||||||
return abs_path, ctype, download_name
|
return abs_path, ctype, download_name
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_asset_from_temp_path(
|
||||||
|
spec: schemas_in.UploadAssetSpec,
|
||||||
|
*,
|
||||||
|
temp_path: str,
|
||||||
|
client_filename: Optional[str] = None,
|
||||||
|
) -> schemas_out.AssetCreated:
|
||||||
|
"""
|
||||||
|
Finalize an uploaded temp file:
|
||||||
|
- compute blake3 hash
|
||||||
|
- resolve destination from tags
|
||||||
|
- decide filename (spec.name or client filename or hash)
|
||||||
|
- move file atomically
|
||||||
|
- ingest into DB (assets, locator state, asset_info + tags)
|
||||||
|
Returns a populated AssetCreated payload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
digest = await hashing.blake3_hash(temp_path)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||||
|
asset_hash = "blake3:" + digest
|
||||||
|
|
||||||
|
# Resolve destination
|
||||||
|
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||||
|
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||||
|
os.makedirs(dest_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Decide filename
|
||||||
|
desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||||
|
dest_abs = os.path.abspath(os.path.join(dest_dir, desired_name))
|
||||||
|
ensure_within_base(dest_abs, base_dir)
|
||||||
|
|
||||||
|
# Content type based on final name
|
||||||
|
content_type = mimetypes.guess_type(desired_name, strict=False)[0] or "application/octet-stream"
|
||||||
|
|
||||||
|
# Atomic move into place
|
||||||
|
try:
|
||||||
|
os.replace(temp_path, dest_abs)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||||
|
|
||||||
|
# Stat final file
|
||||||
|
try:
|
||||||
|
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||||
|
except OSError as e:
|
||||||
|
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||||
|
|
||||||
|
# Ingest + build response
|
||||||
|
async with await create_session() as session:
|
||||||
|
result = await ingest_fs_asset(
|
||||||
|
session,
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
abs_path=dest_abs,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mtime_ns=mtime_ns,
|
||||||
|
mime_type=content_type,
|
||||||
|
info_name=os.path.basename(dest_abs),
|
||||||
|
owner_id="",
|
||||||
|
preview_hash=None,
|
||||||
|
user_metadata=spec.user_metadata or {},
|
||||||
|
tags=spec.tags,
|
||||||
|
tag_origin="manual",
|
||||||
|
added_by=None,
|
||||||
|
require_existing_tags=False,
|
||||||
|
)
|
||||||
|
info_id = result.get("asset_info_id")
|
||||||
|
if not info_id:
|
||||||
|
raise RuntimeError("failed to create asset metadata")
|
||||||
|
|
||||||
|
pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id))
|
||||||
|
if not pair:
|
||||||
|
raise RuntimeError("inconsistent DB state after ingest")
|
||||||
|
info, asset = pair
|
||||||
|
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return schemas_out.AssetCreated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=info.asset_hash,
|
||||||
|
size=int(asset.size_bytes),
|
||||||
|
mime_type=asset.mime_type,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_hash=info.preview_hash,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
created_new=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def update_asset(
|
async def update_asset(
|
||||||
*,
|
*,
|
||||||
asset_info_id: int,
|
asset_info_id: int,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user