This commit is contained in:
Simon Pinfold 2026-06-22 22:50:29 +08:00 committed by GitHub
commit ec26ba212c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 856 additions and 176 deletions

View File

@ -0,0 +1,107 @@
"""
Allow case-sensitive tag names.
Revision ID: 0005_allow_case_sensitive_tags
Revises: 0004_drop_tag_type
Create Date: 2026-06-16
"""
import sqlalchemy as sa
from alembic import op
revision = "0005_allow_case_sensitive_tags"
down_revision = "0004_drop_tag_type"
branch_labels = None
depends_on = None
def upgrade() -> None:
bind = op.get_bind()
if bind.dialect.name == "sqlite":
# SQLite cannot ALTER/DROP CHECK constraints. Recreate the small tag
# vocabulary table without the lowercase constraint while preserving
# existing tag names.
op.execute("PRAGMA foreign_keys=OFF")
try:
op.execute(
"CREATE TABLE tags_new ("
"name VARCHAR(512) NOT NULL, "
"CONSTRAINT pk_tags PRIMARY KEY (name)"
")"
)
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
op.execute("DROP TABLE tags")
op.execute("ALTER TABLE tags_new RENAME TO tags")
finally:
op.execute("PRAGMA foreign_keys=ON")
return
op.drop_constraint("ck_tags_ck_tags_lowercase", "tags", type_="check")
def downgrade() -> None:
# Existing mixed-case tags cannot satisfy the old constraint. Lowercase them
# before restoring it, merging duplicate vocabulary/link rows that collide.
bind = op.get_bind()
tag_names = [row[0] for row in bind.execute(sa.text("SELECT name FROM tags"))]
existing_names = set(tag_names)
lowercase_names = sorted({name.lower() for name in tag_names})
missing_lowercase_rows = [
{"name": name} for name in lowercase_names if name not in existing_names
]
if missing_lowercase_rows:
bind.execute(sa.text("INSERT INTO tags(name) VALUES (:name)"), missing_lowercase_rows)
link_rows = bind.execute(
sa.text(
"SELECT asset_reference_id, tag_name, origin, added_at "
"FROM asset_reference_tags "
"ORDER BY asset_reference_id, tag_name"
)
).mappings()
deduped_links = {}
for row in link_rows:
key = (row["asset_reference_id"], row["tag_name"].lower())
deduped_links.setdefault(
key,
{
"asset_reference_id": row["asset_reference_id"],
"tag_name": row["tag_name"].lower(),
"origin": row["origin"],
"added_at": row["added_at"],
},
)
op.execute("DELETE FROM asset_reference_tags")
if deduped_links:
bind.execute(
sa.text(
"INSERT INTO asset_reference_tags "
"(asset_reference_id, tag_name, origin, added_at) "
"VALUES (:asset_reference_id, :tag_name, :origin, :added_at)"
),
list(deduped_links.values()),
)
op.execute("DELETE FROM tags WHERE name != lower(name)")
if bind.dialect.name == "sqlite":
op.execute("PRAGMA foreign_keys=OFF")
try:
op.execute(
"CREATE TABLE tags_new ("
"name VARCHAR(512) NOT NULL, "
"CONSTRAINT pk_tags PRIMARY KEY (name), "
"CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))"
")"
)
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
op.execute("DROP TABLE tags")
op.execute("ALTER TABLE tags_new RENAME TO tags")
finally:
op.execute("PRAGMA foreign_keys=ON")
return
op.create_check_constraint(
"ck_tags_ck_tags_lowercase", "tags", "name = lower(name)"
)

View File

@ -10,7 +10,6 @@ from typing import Any
from aiohttp import web from aiohttp import web
from pydantic import ValidationError from pydantic import ValidationError
import folder_paths
from app import user_manager from app import user_manager
from app.assets.api import schemas_in, schemas_out from app.assets.api import schemas_in, schemas_out
from app.assets.services import schemas from app.assets.services import schemas
@ -408,6 +407,7 @@ async def upload_asset(request: web.Request) -> web.Response:
"hash": parsed.provided_hash, "hash": parsed.provided_hash,
"mime_type": parsed.provided_mime_type, "mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id, "preview_id": parsed.provided_preview_id,
"subfolder": parsed.provided_subfolder,
} }
) )
except ValidationError as ve: except ValidationError as ve:
@ -416,17 +416,6 @@ async def upload_asset(request: web.Request) -> web.Response:
400, "INVALID_BODY", f"Validation failed: {ve.json()}" 400, "INVALID_BODY", f"Validation failed: {ve.json()}"
) )
if spec.tags and spec.tags[0] == "models":
if (
len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths
):
delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)
try: try:
# Fast path: hash exists, create AssetReference without writing anything # Fast path: hash exists, create AssetReference without writing anything
if spec.hash and parsed.provided_hash_exists is True: if spec.hash and parsed.provided_hash_exists is True:
@ -464,13 +453,14 @@ async def upload_asset(request: web.Request) -> web.Response:
expected_hash=spec.hash, expected_hash=spec.hash,
mime_type=spec.mime_type, mime_type=spec.mime_type,
preview_id=spec.preview_id, preview_id=spec.preview_id,
subfolder=spec.subfolder,
) )
except AssetValidationError as e: except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, e.code, str(e)) return _build_error_response(400, e.code, str(e))
except ValueError as e: except ValueError as e:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "BAD_REQUEST", str(e)) return _build_error_response(400, "INVALID_BODY", str(e))
except HashMismatchError as e: except HashMismatchError as e:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "HASH_MISMATCH", str(e)) return _build_error_response(400, "HASH_MISMATCH", str(e))

View File

@ -47,6 +47,7 @@ class ParsedUpload:
provided_hash_exists: bool | None provided_hash_exists: bool | None
provided_mime_type: str | None = None provided_mime_type: str | None = None
provided_preview_id: str | None = None provided_preview_id: str | None = None
provided_subfolder: str | None = None
class ListAssetsQuery(BaseModel): class ListAssetsQuery(BaseModel):
@ -140,7 +141,7 @@ class CreateFromHashBody(BaseModel):
if v is None: if v is None:
return [] return []
if isinstance(v, list): if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()] out = [str(t).strip() for t in v if str(t).strip()]
seen = set() seen = set()
dedup = [] dedup = []
for t in out: for t in out:
@ -149,7 +150,7 @@ class CreateFromHashBody(BaseModel):
dedup.append(t) dedup.append(t)
return dedup return dedup
if isinstance(v, str): if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()] return list(dict.fromkeys(t.strip() for t in v.split(",") if t.strip()))
return [] return []
@ -206,7 +207,7 @@ class TagsListQuery(BaseModel):
if v is None: if v is None:
return v return v
v = v.strip() v = v.strip()
return v.lower() or None return v or None
class TagsAdd(BaseModel): class TagsAdd(BaseModel):
@ -220,7 +221,7 @@ class TagsAdd(BaseModel):
for t in v: for t in v:
if not isinstance(t, str): if not isinstance(t, str):
raise TypeError("tags must be strings") raise TypeError("tags must be strings")
tnorm = t.strip().lower() tnorm = t.strip()
if tnorm: if tnorm:
out.append(tnorm) out.append(tnorm)
seen = set() seen = set()
@ -239,8 +240,9 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel): class UploadAssetSpec(BaseModel):
"""Upload Asset operation. """Upload Asset operation.
- tags: optional list; if provided, first is root ('models'|'input'|'output'); - tags: labels plus one destination role ('models'|'input'|'output') for new bytes;
if root == 'models', second must be a valid category if role == 'models', exactly one model_type:<folder_name> tag is required
- subfolder: optional destination subfolder for new bytes
- name: display name - name: display name
- user_metadata: arbitrary JSON object (optional) - user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path - hash: optional canonical 'blake3:<hex>' for validation / fast-path
@ -258,6 +260,7 @@ class UploadAssetSpec(BaseModel):
hash: str | None = Field(default=None) hash: str | None = Field(default=None)
mime_type: str | None = Field(default=None) mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None) # references an asset_reference id preview_id: str | None = Field(default=None) # references an asset_reference id
subfolder: str | None = Field(default=None, max_length=1024)
@field_validator("hash", mode="before") @field_validator("hash", mode="before")
@classmethod @classmethod
@ -309,12 +312,20 @@ class UploadAssetSpec(BaseModel):
norm = [] norm = []
seen = set() seen = set()
for t in items: for t in items:
tnorm = str(t).strip().lower() tnorm = str(t).strip()
if tnorm and tnorm not in seen: if tnorm and tnorm not in seen:
seen.add(tnorm) seen.add(tnorm)
norm.append(tnorm) norm.append(tnorm)
return norm return norm
@field_validator("subfolder", mode="before")
@classmethod
def _parse_subfolder(cls, v):
if v is None:
return None
s = str(v).strip()
return s or None
@field_validator("user_metadata", mode="before") @field_validator("user_metadata", mode="before")
@classmethod @classmethod
def _parse_metadata_json(cls, v): def _parse_metadata_json(cls, v):
@ -335,14 +346,4 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def _validate_order(self): def _validate_order(self):
if not self.tags:
raise ValueError("at least one tag is required for uploads")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self return self

View File

@ -54,6 +54,7 @@ async def parse_multipart_upload(
provided_hash_exists: bool | None = None provided_hash_exists: bool | None = None
provided_mime_type: str | None = None provided_mime_type: str | None = None
provided_preview_id: str | None = None provided_preview_id: str | None = None
provided_subfolder: str | None = None
file_written = 0 file_written = 0
tmp_path: str | None = None tmp_path: str | None = None
@ -140,6 +141,8 @@ async def parse_multipart_upload(
provided_mime_type = ((await field.text()) or "").strip() or None provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id": elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None provided_preview_id = ((await field.text()) or "").strip() or None
elif fname == "subfolder":
provided_subfolder = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists): if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError( raise UploadError(
@ -166,6 +169,7 @@ async def parse_multipart_upload(
provided_hash_exists=provided_hash_exists, provided_hash_exists=provided_hash_exists,
provided_mime_type=provided_mime_type, provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id, provided_preview_id=provided_preview_id,
provided_subfolder=provided_subfolder,
) )

View File

@ -265,6 +265,8 @@ def list_tags_with_usage(
order: str = "count_desc", order: str = "count_desc",
owner_id: str = "", owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]: ) -> tuple[list[tuple[str, str, int]], int]:
prefix_filter = prefix.strip() if prefix else ""
counts_sq = ( counts_sq = (
select( select(
AssetReferenceTag.tag_name.label("tag_name"), AssetReferenceTag.tag_name.label("tag_name"),
@ -293,9 +295,8 @@ def list_tags_with_usage(
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
) )
if prefix: if prefix_filter:
escaped, esc = escape_sql_like_string(prefix.strip().lower()) q = q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter)
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero: if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
@ -306,9 +307,8 @@ def list_tags_with_usage(
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag) total_q = select(func.count()).select_from(Tag)
if prefix: if prefix_filter:
escaped, esc = escape_sql_like_string(prefix.strip().lower()) total_q = total_q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter)
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero: if not include_zero:
visible_tags_sq = ( visible_tags_sq = (
select(AssetReferenceTag.tag_name) select(AssetReferenceTag.tag_name)

View File

@ -41,10 +41,10 @@ def get_utc_now() -> datetime:
def normalize_tags(tags: list[str] | None) -> list[str]: def normalize_tags(tags: list[str] | None) -> list[str]:
""" """
Normalize a list of tags by: Normalize a list of tags by:
- Stripping whitespace and converting to lowercase. - Stripping whitespace.
- Removing duplicates. - Removing exact duplicates while preserving order and case.
""" """
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) return list(dict.fromkeys(t.strip() for t in (tags or []) if (t or "").strip()))
def validate_blake3_hash(s: str) -> str: def validate_blake3_hash(s: str) -> str:

View File

@ -34,6 +34,7 @@ from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.path_utils import ( from app.assets.services.path_utils import (
compute_relative_filename, compute_relative_filename,
get_backend_system_tags_from_path,
get_name_and_tags_from_asset_path, get_name_and_tags_from_asset_path,
resolve_destination_from_tags, resolve_destination_from_tags,
validate_path_within_base, validate_path_within_base,
@ -101,7 +102,11 @@ def _ingest_file_from_path(
if preview_id and ref.preview_id != preview_id: if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id ref.preview_id = preview_id
norm = normalize_tags(list(tags)) try:
backend_tags = get_backend_system_tags_from_path(locator)
except ValueError:
backend_tags = []
norm = normalize_tags([*list(tags), *backend_tags])
if norm: if norm:
if require_existing_tags: if require_existing_tags:
validate_tags_exist(session, norm) validate_tags_exist(session, norm)
@ -458,6 +463,7 @@ def upload_from_temp_path(
expected_hash: str | None = None, expected_hash: str | None = None,
mime_type: str | None = None, mime_type: str | None = None,
preview_id: str | None = None, preview_id: str | None = None,
subfolder: str | None = None,
) -> UploadResult: ) -> UploadResult:
try: try:
digest, _ = hashing.compute_blake3_hash(temp_path) digest, _ = hashing.compute_blake3_hash(temp_path)
@ -474,6 +480,10 @@ def upload_from_temp_path(
existing = get_asset_by_hash(session, asset_hash=asset_hash) existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None: if existing is not None:
# Once content is already known, duplicate byte uploads are treated as
# reference-only creation. Request tags are labels only here: do not
# require upload destination tags, do not move bytes, and do not
# synthesize path-derived classification or uploaded provenance.
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path): if temp_path and os.path.exists(temp_path):
os.remove(temp_path) os.remove(temp_path)
@ -498,7 +508,7 @@ def upload_from_temp_path(
if not tags: if not tags:
raise ValueError("tags are required for new asset uploads") raise ValueError("tags are required for new asset uploads")
base_dir, subdirs = resolve_destination_from_tags(tags) base_dir, subdirs = resolve_destination_from_tags(tags, subfolder=subfolder)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True) os.makedirs(dest_dir, exist_ok=True)
@ -535,7 +545,7 @@ def upload_from_temp_path(
owner_id=owner_id, owner_id=owner_id,
preview_id=preview_id, preview_id=preview_id,
user_metadata=user_metadata or {}, user_metadata=user_metadata or {},
tags=tags, tags=[*(tags or []), "uploaded"],
tag_origin="manual", tag_origin="manual",
require_existing_tags=False, require_existing_tags=False,
) )
@ -569,15 +579,19 @@ def register_file_in_place(
) -> UploadResult: ) -> UploadResult:
"""Register an already-saved file in the asset database without moving it. """Register an already-saved file in the asset database without moving it.
Tags are derived from the filesystem path (root category + subfolder names), This helper is used by upload paths that have already written bytes before
merged with any caller-provided tags, matching the behavior of the scanner. registering the file, so it records the same ``uploaded`` tag as the
multipart byte-upload path.
Tags are derived from trusted filesystem classification and merged with any
caller-provided tags, matching the behavior of the scanner.
If the path is not under a known root, only the caller-provided tags are used. If the path is not under a known root, only the caller-provided tags are used.
""" """
try: try:
_, path_tags = get_name_and_tags_from_asset_path(abs_path) _, path_tags = get_name_and_tags_from_asset_path(abs_path)
except ValueError: except ValueError:
path_tags = [] path_tags = []
merged_tags = normalize_tags([*path_tags, *tags]) merged_tags = normalize_tags([*path_tags, *tags, "uploaded"])
try: try:
digest, _ = hashing.compute_blake3_hash(abs_path) digest, _ = hashing.compute_blake3_hash(abs_path)

View File

@ -1,12 +1,11 @@
import os import os
from pathlib import Path from pathlib import Path, PureWindowsPath
from typing import Literal from typing import Literal
import folder_paths import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) _NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"})
def get_comfy_models_folders() -> list[tuple[str, list[str]]]: def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
@ -14,7 +13,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
Includes every category registered in folder_names_and_paths, Includes every category registered in folder_names_and_paths,
regardless of whether its paths are under the main models_dir, regardless of whether its paths are under the main models_dir,
but excludes non-model entries like custom_nodes. but excludes non-model entries like configs and custom_nodes.
""" """
targets: list[tuple[str, list[str]]] = [] targets: list[tuple[str, list[str]]] = []
for name, values in folder_paths.folder_names_and_paths.items(): for name, values in folder_paths.folder_names_and_paths.items():
@ -26,36 +25,60 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
return targets return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: def _validate_subfolder(subfolder: str | None) -> list[str]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)""" if not subfolder:
if not tags: return []
raise ValueError("tags must not be empty")
root = tags[0].lower() if "\\" in subfolder:
raise ValueError("invalid subfolder path")
windows_path = PureWindowsPath(subfolder)
if windows_path.drive or windows_path.root:
raise ValueError("invalid subfolder path")
parts = Path(subfolder).parts
invalid = {"", ".", ".."}
if Path(subfolder).is_absolute() or any(part in invalid for part in parts):
raise ValueError("invalid subfolder path")
if any("/" in part or "\\" in part for part in parts):
raise ValueError("invalid subfolder path")
return list(parts)
def resolve_destination_from_tags(
tags: list[str], subfolder: str | None = None
) -> tuple[str, list[str]]:
"""Validates and maps upload routing tags -> (base_dir, subdirs_for_fs).
The request tags are only used to choose the write destination. Extra tags
remain labels; they do not become path components or trusted classification.
Explicit subfolder is the only request field that can add path components.
"""
destination_roles = [t for t in tags if t in {"input", "models", "output"}]
if len(destination_roles) != 1:
raise ValueError("uploads require exactly one destination role: input, models, or output")
root = destination_roles[0]
if root == "models": if root == "models":
if len(tags) < 2: model_type_tags = [t for t in tags if t.startswith("model_type:")]
raise ValueError("at least two tags required for model asset") if len(model_type_tags) != 1:
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
folder_name = model_type_tags[0].split(":", 1)[1]
if not folder_name:
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
model_folder_paths = dict(get_comfy_models_folders())
try: try:
bases = folder_paths.folder_names_and_paths[tags[1]][0] bases = model_folder_paths[folder_name]
except KeyError: except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'") raise ValueError(f"unknown model category '{folder_name}'")
if not bases: if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'") raise ValueError(f"no base path configured for category '{folder_name}'")
base_dir = os.path.abspath(bases[0]) base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
elif root == "input": elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory()) base_dir = os.path.abspath(folder_paths.get_input_directory())
raw_subdirs = tags[1:]
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = tags[1:]
else: else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") base_dir = os.path.abspath(folder_paths.get_output_directory())
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs:
if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else [] return base_dir, _validate_subfolder(subfolder)
def validate_path_within_base(candidate: str, base: str) -> None: def validate_path_within_base(candidate: str, base: str) -> None:
@ -156,18 +179,55 @@ def get_asset_category_and_relative_path(
) )
def get_backend_system_tags_from_path(path: str) -> list[str]:
"""Return trusted backend tags derived from current filesystem facts.
The returned tags are only the backend-generated system tags: ``models``,
``model_type:<folder_name>``, ``input``, ``output``, and ``temp``. Model
type tags are based on registered folder names, not path components.
"""
fp_abs = os.path.abspath(path)
fp_path = Path(fp_abs)
tags: list[str] = []
def _add(tag: str) -> None:
if tag not in tags:
tags.append(tag)
for role, base in (
("input", folder_paths.get_input_directory()),
("output", folder_paths.get_output_directory()),
("temp", folder_paths.get_temp_directory()),
):
if fp_path.is_relative_to(os.path.abspath(base)):
_add(role)
model_types: list[str] = []
for folder_name, bases in get_comfy_models_folders():
for base in bases:
if fp_path.is_relative_to(os.path.abspath(base)):
model_types.append(folder_name)
break
if model_types:
_add("models")
for folder_name in model_types:
_add(f"model_type:{folder_name}")
if not tags:
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {path}"
)
return tags
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path. """Return (name, tags) derived from a filesystem path.
- name: base filename with extension - name: base filename with extension
- tags: [root_category] + parent folder names in order - tags: trusted backend classification tags derived from the path
Raises: Raises:
ValueError: path does not belong to any known root. ValueError: path does not belong to any known root.
""" """
root_category, some_path = get_asset_category_and_relative_path(file_path) return Path(file_path).name, get_backend_system_tags_from_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))

View File

@ -100,6 +100,7 @@ def _parse_cli_feature_flags() -> dict[str, Any]:
# Default server capabilities # Default server capabilities
_CORE_FEATURE_FLAGS: dict[str, Any] = { _CORE_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True, "supports_preview_metadata": True,
"supports_model_type_tags": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}}, "extension": {"manager": {"supports_v4": True}},
"node_replacements": True, "node_replacements": True,

View File

@ -7,14 +7,6 @@ components:
description: Timestamp when the asset was created description: Timestamp when the asset was created
format: date-time format: date-time
type: string type: string
display_name:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash: hash:
description: Blake3 hash of the asset content. description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$ pattern: ^blake3:[a-f0-9]{64}$
@ -144,14 +136,6 @@ components:
AssetUpdated: AssetUpdated:
description: Response returned when an existing asset is successfully updated. description: Response returned when an existing asset is successfully updated.
properties: properties:
display_name:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash: hash:
description: Blake3 hash of the asset content. description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$ pattern: ^blake3:[a-f0-9]{64}$
@ -2454,6 +2438,9 @@ paths:
supports_preview_metadata: supports_preview_metadata:
description: Whether the server supports preview metadata description: Whether the server supports preview metadata
type: boolean type: boolean
supports_model_type_tags:
description: Whether the server supports namespaced model type asset tags
type: boolean
type: object type: object
description: Success description: Success
headers: headers:

View File

@ -440,7 +440,10 @@ class PromptServer():
if args.enable_assets: if args.enable_assets:
try: try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input" tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag]) tags = [tag]
if subfolder in {"3d", "pasted", "painter", "threed", "webcam"}:
tags.append(subfolder)
result = register_file_in_place(abs_path=filepath, name=filename, tags=tags)
resp["asset"] = { resp["asset"] = {
"id": result.ref.id, "id": result.ref.id,
"name": result.ref.name, "name": result.ref.name,

View File

@ -8,6 +8,7 @@ upgrade/downgrade for 0003+.
""" """
import os import os
import sqlite3
import pytest import pytest
from alembic import command from alembic import command
@ -30,6 +31,12 @@ def _make_config(db_path: str) -> Config:
return cfg return cfg
def _sqlite_path(cfg: Config) -> str:
url = cfg.get_main_option("sqlalchemy.url")
assert url is not None and url.startswith("sqlite:///")
return url.removeprefix("sqlite:///")
@pytest.fixture @pytest.fixture
def migration_db(tmp_path): def migration_db(tmp_path):
"""Yield an alembic Config pre-upgraded to the baseline revision.""" """Yield an alembic Config pre-upgraded to the baseline revision."""
@ -55,3 +62,26 @@ def test_upgrade_downgrade_cycle(migration_db):
command.upgrade(migration_db, "head") command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE) command.downgrade(migration_db, _BASELINE)
command.upgrade(migration_db, "head") command.upgrade(migration_db, "head")
def test_case_sensitive_tags_downgrade_normalizes_existing_tags(migration_db):
"""Downgrading 0005 folds mixed-case tag vocabulary before restoring CHECK."""
command.upgrade(migration_db, "0005_allow_case_sensitive_tags")
db_path = _sqlite_path(migration_db)
with sqlite3.connect(db_path) as conn:
conn.execute("INSERT INTO tags(name) VALUES (?)", ("NewTag",))
conn.execute("INSERT INTO tags(name) VALUES (?)", ("newtag",))
conn.execute("INSERT INTO tags(name) VALUES (?)", ("model_type:LLM",))
command.downgrade(migration_db, "0004_drop_tag_type")
with sqlite3.connect(db_path) as conn:
tags = {row[0] for row in conn.execute("SELECT name FROM tags")}
assert "newtag" in tags
assert "model_type:llm" in tags
assert "NewTag" not in tags
assert "model_type:LLM" not in tags
with pytest.raises(sqlite3.IntegrityError):
conn.execute("INSERT INTO tags(name) VALUES (?)", ("Upper",))

View File

@ -234,7 +234,7 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
p = getattr(request, "param", {}) or {} p = getattr(request, "param", {}) or {}
tags: Optional[list[str]] = p.get("tags") tags: Optional[list[str]] = p.get("tags")
if tags is None: if tags is None:
tags = ["models", "checkpoints", "unit-tests", "alpha"] tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
# Unique content per test so the seed always creates a fresh asset (201). # Unique content per test so the seed always creates a fresh asset (201).
# Delete is now always a soft delete, so content from a prior test survives # Delete is now always a soft delete, so content from a prior test survives

View File

@ -133,6 +133,66 @@ class TestListReferencesPage:
assert total == 1 assert total == 1
assert refs[0].name == "tagged" assert refs[0].name == "tagged"
def test_include_tags_filter_ands_persisted_model_tags(self, session: Session):
asset = _make_asset(session, "hash-model-tags")
checkpoint = _make_reference(session, asset, name="checkpoint")
lora = _make_reference(session, asset, name="lora")
input_ref = _make_reference(session, asset, name="input")
ensure_tags_exist(
session,
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"],
)
add_tags_to_reference(
session,
reference_id=checkpoint.id,
tags=["models", "model_type:checkpoints", "unit-tests"],
origin="automatic",
)
add_tags_to_reference(
session,
reference_id=lora.id,
tags=["models", "model_type:loras", "unit-tests"],
origin="automatic",
)
add_tags_to_reference(
session,
reference_id=input_ref.id,
tags=["unit-tests"],
)
session.commit()
refs, _, total = list_references_page(
session,
include_tags=["models", "model_type:checkpoints", "unit-tests"],
)
assert total == 1
assert refs[0].id == checkpoint.id
def test_include_tags_filter_preserves_model_type_case(self, session: Session):
asset = _make_asset(session, "hash-model-case")
ref = _make_reference(session, asset, name="llm")
ensure_tags_exist(session, ["models", "model_type:LLM"])
add_tags_to_reference(
session,
reference_id=ref.id,
tags=["models", "model_type:LLM"],
origin="automatic",
)
session.commit()
refs, _, total = list_references_page(
session, include_tags=["models", "model_type:LLM"]
)
refs_lower, _, total_lower = list_references_page(
session, include_tags=["models", "model_type:llm"]
)
assert total == 1
assert refs[0].id == ref.id
assert total_lower == 0
assert refs_lower == []
def test_exclude_tags_filter(self, session: Session): def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="keep") _make_reference(session, asset, name="keep")

View File

@ -58,7 +58,7 @@ class TestEnsureTagsExist:
session.commit() session.commit()
tags = session.query(Tag).all() tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"} assert {t.name for t in tags} == {"ALPHA", "Beta", "alpha"}
def test_empty_list_is_noop(self, session: Session): def test_empty_list_is_noop(self, session: Session):
ensure_tags_exist(session, []) ensure_tags_exist(session, [])
@ -258,6 +258,16 @@ class TestListTagsWithUsage:
tag_names = {name for name, _ in rows} tag_names = {name for name, _ in rows}
assert tag_names == {"alpha", "alphabet"} assert tag_names == {"alpha", "alphabet"}
def test_prefix_filter_is_case_sensitive(self, session: Session):
ensure_tags_exist(session, ["model_type:LLM", "model_type:llm"])
session.commit()
rows, total = list_tags_with_usage(session, prefix="model_type:L")
tag_names = {name for name, _ in rows}
assert tag_names == {"model_type:LLM"}
assert total == 1
def test_order_by_name(self, session: Session): def test_order_by_name(self, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"]) ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit() session.commit()

View File

@ -6,7 +6,11 @@ from unittest.mock import patch
import pytest import pytest
from app.assets.services.path_utils import get_asset_category_and_relative_path from app.assets.services.path_utils import (
get_asset_category_and_relative_path,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
@pytest.fixture @pytest.fixture
@ -76,6 +80,137 @@ class TestGetAssetCategoryAndRelativePath:
cat, rel = get_asset_category_and_relative_path(str(f)) cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "models" assert cat == "models"
def test_model_path_tags_include_registered_model_type_only(self, fake_dirs):
f = fake_dirs["models"] / "subdir" / "model.safetensors"
f.parent.mkdir()
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "checkpoints" not in tags
assert "subdir" not in tags
def test_model_type_preserves_registered_folder_case(self, fake_dirs):
llm_dir = fake_dirs["models"].parent / "LLM"
llm_dir.mkdir()
f = llm_dir / "model.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("LLM", [str(llm_dir)])],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:LLM" in tags
assert "model_type:llm" not in tags
def test_path_components_do_not_create_model_type_tags(self, fake_dirs):
f = fake_dirs["models"] / "loras" / "model.safetensors"
f.parent.mkdir()
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "loras" not in tags
assert "model_type:loras" not in tags
def test_shared_root_returns_all_matching_model_type_tags(self, fake_dirs):
shared_root = fake_dirs["models"].parent / "shared"
shared_root.mkdir()
f = shared_root / "foo.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(shared_root)]),
("loras", [str(shared_root)]),
],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "model_type:loras" in tags
def test_output_backed_registered_folder_gets_model_and_output_tags(self, fake_dirs):
output_checkpoints_dir = fake_dirs["output"] / "checkpoints"
output_checkpoints_dir.mkdir()
f = output_checkpoints_dir / "saved.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(output_checkpoints_dir)])],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "output" in tags
def test_temp_path_tags_include_temp_not_output_or_preview(self, fake_dirs):
f = fake_dirs["temp"] / "preview.png"
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "temp" in tags
assert "output" not in tags
assert "preview:true" not in tags
def test_unknown_path_raises(self, fake_dirs): def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"): with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png") get_asset_category_and_relative_path("/some/random/path.png")
class TestResolveDestinationFromTags:
def test_explicit_subfolder_is_path_component(self, fake_dirs):
base_dir, subdirs = resolve_destination_from_tags(
["input", "unit-tests", "foo"], subfolder="foo/bar"
)
assert base_dir == os.path.abspath(fake_dirs["input"])
assert subdirs == ["foo", "bar"]
@pytest.mark.parametrize(
"subfolder",
["../escape", "foo/../bar", "/abs", "foo\\bar", "C:/escape", "C:escape"],
)
def test_explicit_subfolder_rejects_unsafe_paths(self, fake_dirs, subfolder: str):
with pytest.raises(ValueError, match="invalid subfolder"):
resolve_destination_from_tags(["input", "unit-tests"], subfolder=subfolder)
def test_model_upload_rejects_non_writable_registered_folders(self):
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
checkpoints_dir = root_path / "models" / "checkpoints"
configs_dir = root_path / "models" / "configs"
custom_nodes_dir = root_path / "custom_nodes"
for path in (checkpoints_dir, configs_dir, custom_nodes_dir):
path.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.folder_names_and_paths = {
"checkpoints": ([str(checkpoints_dir)], set()),
"configs": ([str(configs_dir)], set()),
"custom_nodes": ([str(custom_nodes_dir)], set()),
}
base_dir, subdirs = resolve_destination_from_tags(
["models", "model_type:checkpoints"]
)
assert base_dir == os.path.abspath(checkpoints_dir)
assert subdirs == []
for folder_name in ("configs", "custom_nodes"):
with pytest.raises(ValueError, match="unknown model category"):
resolve_destination_from_tags(
["models", f"model_type:{folder_name}"]
)

View File

@ -19,7 +19,8 @@ def test_seed_asset_removed_when_file_is_deleted(
"""Asset without hash (seed) whose file disappears: """Asset without hash (seed) whose file disappears:
after triggering sync_seed_assets, Asset + AssetInfo disappear. after triggering sync_seed_assets, Asset + AssetInfo disappear.
""" """
# Create a file directly under input/unit-tests/<case> so tags include "unit-tests" # Create a file directly under input/unit-tests/<case>. Backend tags only
# classify the root; nested path components are not exposed as tags.
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed" case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
case_dir.mkdir(parents=True, exist_ok=True) case_dir.mkdir(parents=True, exist_ok=True)
name = f"seed_{uuid.uuid4().hex[:8]}.bin" name = f"seed_{uuid.uuid4().hex[:8]}.bin"
@ -32,7 +33,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# Verify it is visible via API and carries no hash (seed) # Verify it is visible via API and carries no hash (seed)
r1 = http.get( r1 = http.get(
api_base + "/api/assets", api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name}, params={"include_tags": root, "name_contains": name},
timeout=120, timeout=120,
) )
body1 = r1.json() body1 = r1.json()
@ -54,7 +55,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# It should disappear (AssetInfo and seed Asset gone) # It should disappear (AssetInfo and seed Asset gone)
r2 = http.get( r2 = http.get(
api_base + "/api/assets", api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name}, params={"include_tags": root, "name_contains": name},
timeout=120, timeout=120,
) )
body2 = r2.json() body2 = r2.json()
@ -132,7 +133,7 @@ def test_hashed_asset_two_asset_infos_both_get_missing(
second_id = b2["id"] second_id = b2["id"]
# Remove the single underlying file # Remove the single underlying file
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png") p = comfy_tmp_base_dir / "input" / get_asset_filename(created["asset_hash"], ".png")
assert p.exists() assert p.exists()
p.unlink() p.unlink()
@ -250,8 +251,7 @@ def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
a = asset_factory(name, [root, "unit-tests", scope], {}, data) a = asset_factory(name, [root, "unit-tests", scope], {}, data)
aid = a["id"] aid = a["id"]
base = comfy_tmp_base_dir / root / "unit-tests" / scope p = comfy_tmp_base_dir / root / get_asset_filename(a["asset_hash"], ".bin")
p = base / get_asset_filename(a["asset_hash"], ".bin")
st0 = p.stat() st0 = p.stat()
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000)) orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))

View File

@ -290,7 +290,7 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
r1 = http.get( r1 = http.get(
api_base + "/api/assets", api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, params={"include_tags": root, "name_contains": name},
timeout=120, timeout=120,
) )
body = r1.json() body = r1.json()

View File

@ -95,7 +95,7 @@ def test_download_chooses_existing_state_and_updates_access_time(
assert t1 > t0 assert t1 > t0
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True) @pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "model_type:checkpoints"]}], indirect=True)
def test_download_missing_file_returns_404( def test_download_missing_file_returns_404(
http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
): ):

View File

@ -13,7 +13,7 @@ def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
for n in names: for n in names:
asset_factory( asset_factory(
n, n,
["models", "checkpoints", "unit-tests", tag], ["models", "model_type:checkpoints", "unit-tests", tag],
{}, {},
make_asset_bytes(n, size=2048), make_asset_bytes(n, size=2048),
) )
@ -208,7 +208,7 @@ def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api
names = [] names = []
for i in range(4): for i in range(4):
n = f"cursor_{sort_field}_{i:02d}.safetensors" n = f"cursor_{sort_field}_{i:02d}.safetensors"
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i)) asset_factory(n, ["models", "model_type:checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
names.append(n) names.append(n)
params = { params = {

View File

@ -11,7 +11,7 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
for n in names: for n in names:
asset_factory( asset_factory(
n, n,
["models", "checkpoints", "unit-tests", "paging"], ["models", "model_type:checkpoints", "unit-tests", "paging"],
{"epoch": 1}, {"epoch": 1},
make_asset_bytes(n, size=2048), make_asset_bytes(n, size=2048),
) )
@ -45,8 +45,8 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory): def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory):
a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024) a = asset_factory("inc_a.safetensors", ["models", "model_type:checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024) b = asset_factory("inc_b.safetensors", ["models", "model_type:checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
r = http.get( r = http.get(
api_base + "/api/assets", api_base + "/api/assets",
@ -81,7 +81,7 @@ def test_list_assets_include_exclude_and_name_contains(http: requests.Session, a
def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes): def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-size"] t = ["models", "model_type:checkpoints", "unit-tests", "lf-size"]
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors" n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
asset_factory(n1, t, {}, make_asset_bytes(n1, 1024)) asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
asset_factory(n2, t, {}, make_asset_bytes(n2, 2048)) asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
@ -108,7 +108,7 @@ def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, mak
def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes): def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-upd"] t = ["models", "model_type:checkpoints", "unit-tests", "lf-upd"]
a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200)) a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200)) a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
@ -131,7 +131,7 @@ def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make
def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes): def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-access"] t = ["models", "model_type:checkpoints", "unit-tests", "lf-access"]
asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100)) asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
time.sleep(0.02) time.sleep(0.02)
a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100)) a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
@ -154,14 +154,14 @@ def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory
def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes): def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-include"] t = ["models", "model_type:checkpoints", "unit-tests", "lf-include"]
a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva")) a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb")) asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
# CSV + case-insensitive # CSV tag filters are whitespace-trimmed and case-sensitive.
r1 = http.get( r1 = http.get(
api_base + "/api/assets", api_base + "/api/assets",
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"}, params={"include_tags": "unit-tests,lf-include,alpha"},
timeout=120, timeout=120,
) )
b1 = r1.json() b1 = r1.json()
@ -196,14 +196,14 @@ def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factor
def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes): def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-exclude"] t = ["models", "model_type:checkpoints", "unit-tests", "lf-exclude"]
a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900)) a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900)) asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
# Exclude uppercase should work # Exclude filters are case-sensitive.
r1 = http.get( r1 = http.get(
api_base + "/api/assets", api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"}, params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "beta"},
timeout=120, timeout=120,
) )
b1 = r1.json() b1 = r1.json()
@ -225,7 +225,7 @@ def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory,
def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes): def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-name"] t = ["models", "model_type:checkpoints", "unit-tests", "lf-name"]
a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800)) a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800)) a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
@ -261,7 +261,7 @@ def test_list_assets_name_contains_case_and_specials(http, api_base, asset_facto
def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes): def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"] t = ["models", "model_type:checkpoints", "unit-tests", "lf-pagelimits"]
asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600)) asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600)) asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600)) asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
@ -319,7 +319,7 @@ def test_list_assets_name_contains_literal_underscore(
- foobar.safetensors (must NOT match) - foobar.safetensors (must NOT match)
""" """
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}" scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
tags = ["models", "checkpoints", "unit-tests", scope] tags = ["models", "model_type:checkpoints", "unit-tests", scope]
a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700)) a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700)) b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))

View File

@ -5,7 +5,7 @@ def test_meta_and_across_keys_and_types(
http, api_base: str, asset_factory, make_asset_bytes http, api_base: str, asset_factory, make_asset_bytes
): ):
name = "mf_and_mix.safetensors" name = "mf_and_mix.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-and"] tags = ["models", "model_type:checkpoints", "unit-tests", "mf-and"]
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
asset_factory(name, tags, meta, make_asset_bytes(name, 4096)) asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
@ -41,7 +41,7 @@ def test_meta_and_across_keys_and_types(
def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes): def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
name = "mf_types.safetensors" name = "mf_types.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-types"] tags = ["models", "model_type:checkpoints", "unit-tests", "mf-types"]
meta = {"epoch": 1, "active": True} meta = {"epoch": 1, "active": True}
asset_factory(name, tags, meta, make_asset_bytes(name)) asset_factory(name, tags, meta, make_asset_bytes(name))
@ -95,7 +95,7 @@ def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory,
def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes): def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
name = "mf_list_scalars.safetensors" name = "mf_list_scalars.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-list"] tags = ["models", "model_type:checkpoints", "unit-tests", "mf-list"]
meta = {"flags": ["red", "green"]} meta = {"flags": ["red", "green"]}
asset_factory(name, tags, meta, make_asset_bytes(name, 3000)) asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
@ -134,7 +134,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
http, api_base, asset_factory, make_asset_bytes http, api_base, asset_factory, make_asset_bytes
): ):
# a1: key missing; a2: explicit null; a3: concrete value # a1: key missing; a2: explicit null; a3: concrete value
t = ["models", "checkpoints", "unit-tests", "mf-none"] t = ["models", "model_type:checkpoints", "unit-tests", "mf-none"]
a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2")) a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3")) a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
@ -166,7 +166,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes): def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
name = "mf_nested_json.safetensors" name = "mf_nested_json.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-nested"] tags = ["models", "model_type:checkpoints", "unit-tests", "mf-nested"]
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}} cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200)) asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
@ -197,7 +197,7 @@ def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_as
def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes): def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
name = "mf_list_objects.safetensors" name = "mf_list_objects.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"] tags = ["models", "model_type:checkpoints", "unit-tests", "mf-objlist"]
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}] transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048)) asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
@ -228,7 +228,7 @@ def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_b
def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes): def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
name = "mf_keys_unicode.safetensors" name = "mf_keys_unicode.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-keys"] tags = ["models", "model_type:checkpoints", "unit-tests", "mf-keys"]
meta = { meta = {
"weird.key": "v1", "weird.key": "v1",
"path/like": 7, "path/like": 7,
@ -259,7 +259,7 @@ def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_
def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes): def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"] t = ["models", "model_type:checkpoints", "unit-tests", "mf-zero-bool"]
a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025)) a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026)) a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
@ -286,7 +286,7 @@ def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_as
def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes): def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
name = "mf_mixed_list.safetensors" name = "mf_mixed_list.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"] tags = ["models", "model_type:checkpoints", "unit-tests", "mf-mixed"]
meta = {"mix": ["1", 1, True, None]} meta = {"mix": ["1", 1, True, None]}
asset_factory(name, tags, meta, make_asset_bytes(name, 1999)) asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
@ -311,7 +311,7 @@ def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, mak
def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes): def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
# Use a unique scope tag to avoid interference # Use a unique scope tag to avoid interference
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"] t = ["models", "model_type:checkpoints", "unit-tests", "mf-unknown-scope"]
x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua")) x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub")) y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
@ -340,13 +340,13 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_
# alpha matches epoch=1; beta has epoch=2 # alpha matches epoch=1; beta has epoch=2
a = asset_factory( a = asset_factory(
"mf_tag_alpha.safetensors", "mf_tag_alpha.safetensors",
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"], ["models", "model_type:checkpoints", "unit-tests", "mf-tag", "alpha"],
{"epoch": 1}, {"epoch": 1},
make_asset_bytes("alpha"), make_asset_bytes("alpha"),
) )
b = asset_factory( b = asset_factory(
"mf_tag_beta.safetensors", "mf_tag_beta.safetensors",
["models", "checkpoints", "unit-tests", "mf-tag", "beta"], ["models", "model_type:checkpoints", "unit-tests", "mf-tag", "beta"],
{"epoch": 2}, {"epoch": 2},
make_asset_bytes("beta"), make_asset_bytes("beta"),
) )
@ -367,7 +367,7 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_
def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes): def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
# Three assets in same scope with different sizes and a common filter key # Three assets in same scope with different sizes and a common filter key
t = ["models", "checkpoints", "unit-tests", "mf-sort"] t = ["models", "model_type:checkpoints", "unit-tests", "mf-sort"]
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors" n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024)) asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048)) asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))

View File

@ -29,7 +29,7 @@ def create_seed_file(comfy_tmp_base_dir: Path):
def find_asset(http: requests.Session, api_base: str): def find_asset(http: requests.Session, api_base: str):
"""Query API for assets matching scope and optional name.""" """Query API for assets matching scope and optional name."""
def _find(scope: str, name: str | None = None) -> list[dict]: def _find(scope: str, name: str | None = None) -> list[dict]:
params = {"include_tags": f"unit-tests,{scope}"} params = {"limit": "500"}
if name: if name:
params["name_contains"] = name params["name_contains"] = name
r = http.get(f"{api_base}/api/assets", params=params, timeout=120) r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
@ -91,7 +91,7 @@ def test_hashed_asset_not_pruned_when_file_missing(
data = make_asset_bytes("test", 2048) data = make_asset_bytes("test", 2048)
a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data) a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data)
path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin") path = comfy_tmp_base_dir / "input" / get_asset_filename(a["asset_hash"], ".bin")
path.unlink() path.unlink()
trigger_sync_seed_assets(http, api_base) trigger_sync_seed_assets(http, api_base)
@ -108,18 +108,20 @@ def test_prune_across_multiple_roots(
): ):
"""Prune correctly handles assets across input and output roots.""" """Prune correctly handles assets across input and output roots."""
scope = f"multi-{uuid.uuid4().hex[:6]}" scope = f"multi-{uuid.uuid4().hex[:6]}"
input_fp = create_seed_file("input", scope, "input.bin") input_name = f"{scope}-input.bin"
create_seed_file("output", scope, "output.bin") output_name = f"{scope}-output.bin"
input_fp = create_seed_file("input", scope, input_name)
create_seed_file("output", scope, output_name)
trigger_sync_seed_assets(http, api_base) trigger_sync_seed_assets(http, api_base)
assert len(find_asset(scope)) == 2 assert find_asset(scope, input_name)
assert find_asset(scope, output_name)
input_fp.unlink() input_fp.unlink()
trigger_sync_seed_assets(http, api_base) trigger_sync_seed_assets(http, api_base)
remaining = find_asset(scope) assert not find_asset(scope, input_name)
assert len(remaining) == 1 assert find_asset(scope, output_name)
assert remaining[0]["name"] == "output.bin"
@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"]) @pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])

View File

@ -10,9 +10,9 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
body1 = r1.json() body1 = r1.json()
assert r1.status_code == 200 assert r1.status_code == 200
names = [t["name"] for t in body1["tags"]] names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist: # A few selected contract tags should exist.
assert "models" in names assert "models" in names
assert "checkpoints" in names assert "model_type:checkpoints" in names
# Only used tags before we add anything new from this test cycle # Only used tags before we add anything new from this test cycle
r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120) r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
@ -21,7 +21,7 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
# We already seeded one asset via fixture, so used tags must be non-empty # We already seeded one asset via fixture, so used tags must be non-empty
used_names = [t["name"] for t in body2["tags"]] used_names = [t["name"] for t in body2["tags"]]
assert "models" in used_names assert "models" in used_names
assert "checkpoints" in used_names assert "model_type:checkpoints" in used_names
# Prefix filter should refine the list # Prefix filter should refine the list
r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120) r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
@ -45,7 +45,7 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
body1 = r1.json() body1 = r1.json()
assert r1.status_code == 200 assert r1.status_code == 200
names = [t["name"] for t in body1["tags"]] names = [t["name"] for t in body1["tags"]]
assert "models" in names and "checkpoints" in names assert "models" in names and "model_type:checkpoints" in names
# Create a short-lived asset under input with a unique custom tag # Create a short-lived asset under input with a unique custom tag
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}" scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
@ -89,28 +89,28 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict): def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"] aid = seeded_asset["id"]
# Add tags with duplicates and mixed case # Add tags with duplicates while preserving source case.
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]} payload_add = {"tags": ["NewTag", "unit-tests", "NewTag", "BETA"]}
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120) r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
b1 = r1.json() b1 = r1.json()
assert r1.status_code == 200, b1 assert r1.status_code == 200, b1
# normalized, deduplicated; 'unit-tests' was already present from the seed # stripped, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"newtag", "beta"} assert set(b1["added"]) == {"NewTag", "BETA"}
assert set(b1["already_present"]) == {"unit-tests"} assert set(b1["already_present"]) == {"unit-tests"}
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] assert "NewTag" in b1["total_tags"] and "BETA" in b1["total_tags"]
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
g = rg.json() g = rg.json()
assert rg.status_code == 200 assert rg.status_code == 200
tags_now = set(g["tags"]) tags_now = set(g["tags"])
assert {"newtag", "beta"}.issubset(tags_now) assert {"NewTag", "BETA"}.issubset(tags_now)
# Remove a tag and a non-existent tag # Remove a tag and a non-existent tag
payload_del = {"tags": ["newtag", "does-not-exist"]} payload_del = {"tags": ["NewTag", "does-not-exist"]}
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120) r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
b2 = r2.json() b2 = r2.json()
assert r2.status_code == 200 assert r2.status_code == 200
assert set(b2["removed"]) == {"newtag"} assert set(b2["removed"]) == {"NewTag"}
assert set(b2["not_present"]) == {"does-not-exist"} assert set(b2["not_present"]) == {"does-not-exist"}
# Verify remaining tags after deletion # Verify remaining tags after deletion
@ -118,8 +118,44 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset
g2 = rg2.json() g2 = rg2.json()
assert rg2.status_code == 200 assert rg2.status_code == 200
tags_later = set(g2["tags"]) tags_later = set(g2["tags"])
assert "newtag" not in tags_later assert "NewTag" not in tags_later
assert "beta" in tags_later # still present assert "BETA" in tags_later # still present
def test_add_system_looking_tags_allowed_as_labels(
http: requests.Session, api_base: str, seeded_asset: dict
):
aid = seeded_asset["id"]
response = http.post(
f"{api_base}/api/assets/{aid}/tags",
json={
"tags": [
"models",
"model_type:manual",
"model:true",
"models:foo",
"input:true",
"output:true",
"uploaded:true",
"temp:true",
"temporary",
]
},
timeout=120,
)
body = response.json()
assert response.status_code == 200, body
assert "models" in body["total_tags"]
assert "model_type:manual" in body["total_tags"]
assert "model:true" in body["total_tags"]
assert "models:foo" in body["total_tags"]
assert "input:true" in body["total_tags"]
assert "output:true" in body["total_tags"]
assert "uploaded:true" in body["total_tags"]
assert "temp:true" in body["total_tags"]
assert "temporary" in body["total_tags"]
def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict): def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):

View File

@ -1,11 +1,13 @@
import json import json
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import requests import requests
import pytest import pytest
from app.assets.api.schemas_out import Asset, AssetCreated from app.assets.api.schemas_out import Asset, AssetCreated
from helpers import get_asset_filename
def test_asset_created_inherits_hash_field(): def test_asset_created_inherits_hash_field():
@ -22,7 +24,7 @@ def test_asset_created_inherits_hash_field():
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes): def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
name = "dup_a.safetensors" name = "dup_a.safetensors"
tags = ["models", "checkpoints", "unit-tests", "alpha"] tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "dup"} meta = {"purpose": "dup"}
data = make_asset_bytes(name) data = make_asset_bytes(name)
files = {"file": (name, data, "application/octet-stream")} files = {"file": (name, data, "application/octet-stream")}
@ -58,7 +60,7 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
# Seed a small file first # Seed a small file first
name = "fastpath_seed.safetensors" name = "fastpath_seed.safetensors"
tags = ["models", "checkpoints", "unit-tests"] tags = ["input", "unit-tests"]
meta = {} meta = {}
files = {"file": (name, b"B" * 1024, "application/octet-stream")} files = {"file": (name, b"B" * 1024, "application/octet-stream")}
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
@ -69,9 +71,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
assert b1["hash"] == h assert b1["hash"] == h
# Now POST /api/assets with only hash and no file # Now POST /api/assets with only hash and no file
hash_only_tags = ["models", "checkpoints", "unit-tests", "hash-labels"]
files = [ files = [
("hash", (None, h)), ("hash", (None, h)),
("tags", (None, json.dumps(tags))), ("tags", (None, json.dumps(hash_only_tags))),
("name", (None, "fastpath_copy.safetensors")), ("name", (None, "fastpath_copy.safetensors")),
("user_metadata", (None, json.dumps({"purpose": "copy"}))), ("user_metadata", (None, json.dumps({"purpose": "copy"}))),
] ]
@ -81,6 +84,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
assert b2["created_new"] is False assert b2["created_new"] is False
assert b2["asset_hash"] == h assert b2["asset_hash"] == h
assert b2["hash"] == h assert b2["hash"] == h
assert "models" in b2["tags"]
assert "checkpoints" in b2["tags"]
assert "uploaded" not in b2["tags"]
assert not any(tag.startswith("model_type:") for tag in b2["tags"])
def test_upload_fastpath_with_known_hash_and_file( def test_upload_fastpath_with_known_hash_and_file(
@ -88,7 +95,7 @@ def test_upload_fastpath_with_known_hash_and_file(
): ):
# Seed # Seed
files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")} files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})} form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
b1 = r1.json() b1 = r1.json()
assert r1.status_code == 201, b1 assert r1.status_code == 201, b1
@ -104,11 +111,47 @@ def test_upload_fastpath_with_known_hash_and_file(
assert b2["created_new"] is False assert b2["created_new"] is False
assert b2["asset_hash"] == h assert b2["asset_hash"] == h
assert b2["hash"] == h assert b2["hash"] == h
assert "checkpoints" in b2["tags"]
assert "uploaded" not in b2["tags"]
assert not any(tag == "model_type:checkpoints" for tag in b2["tags"])
def test_duplicate_byte_upload_is_reference_only_and_does_not_need_destination(
http: requests.Session, api_base: str
):
data = b"duplicate-reference-only" * 64
seed_files = {"file": ("duplicate-seed.bin", data, "application/octet-stream")}
seed_form = {
"tags": json.dumps(["input", "unit-tests", "duplicate-seed"]),
"name": "duplicate-seed.bin",
"user_metadata": json.dumps({}),
}
seed_response = http.post(api_base + "/api/assets", data=seed_form, files=seed_files, timeout=120)
seed = seed_response.json()
assert seed_response.status_code == 201, seed
duplicate_files = {"file": ("duplicate-copy.bin", data, "application/octet-stream")}
duplicate_form = {
"tags": json.dumps(["not-a-destination", "unit-tests", "duplicate-copy"]),
"name": "duplicate-copy.bin",
"user_metadata": json.dumps({}),
}
duplicate_response = http.post(
api_base + "/api/assets", data=duplicate_form, files=duplicate_files, timeout=120
)
duplicate = duplicate_response.json()
assert duplicate_response.status_code == 200, duplicate
assert duplicate["created_new"] is False
assert duplicate["asset_hash"] == seed["asset_hash"]
assert "not-a-destination" in duplicate["tags"]
assert "uploaded" not in duplicate["tags"]
assert "input" not in duplicate["tags"]
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str): def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
data = [ data = [
("tags", "models,checkpoints"), ("tags", "models,model_type:checkpoints"),
("tags", json.dumps(["unit-tests", "alpha"])), ("tags", json.dumps(["unit-tests", "alpha"])),
("name", "merge.safetensors"), ("name", "merge.safetensors"),
("user_metadata", json.dumps({"u": 1})), ("user_metadata", json.dumps({"u": 1})),
@ -124,7 +167,7 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base
detail = rg.json() detail = rg.json()
assert rg.status_code == 200, detail assert rg.status_code == 200, detail
tags = set(detail["tags"]) tags = set(detail["tags"])
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) assert {"models", "model_type:checkpoints", "unit-tests", "alpha"}.issubset(tags)
@pytest.mark.parametrize("root", ["input", "output"]) @pytest.mark.parametrize("root", ["input", "output"])
@ -192,16 +235,55 @@ def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
assert body["error"]["code"] == "ASSET_NOT_FOUND" assert body["error"]["code"] == "ASSET_NOT_FOUND"
def test_create_from_hash_accepts_arbitrary_system_looking_tags(
http: requests.Session, api_base: str
):
files = {"file": ("hash-seed.bin", b"hash-seed" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(["input", "unit-tests", "hash-seed"]),
"name": "hash-seed.bin",
"user_metadata": json.dumps({}),
}
seed_response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
seed = seed_response.json()
assert seed_response.status_code == 201, seed
response = http.post(
api_base + "/api/assets/from-hash",
json={
"hash": seed["asset_hash"],
"name": "hash-copy.bin",
"tags": [
"models",
"model:true",
"models:foo",
"temporary:true",
"unit-tests",
"hash-copy",
],
},
timeout=120,
)
body = response.json()
assert response.status_code == 201, body
assert "models" in body["tags"]
assert "model:true" in body["tags"]
assert "models:foo" in body["tags"]
assert "temporary:true" in body["tags"]
assert "uploaded" not in body["tags"]
def test_upload_zero_byte_rejected(http: requests.Session, api_base: str): def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
files = {"file": ("empty.safetensors", b"", "application/octet-stream")} files = {"file": ("empty.safetensors", b"", "application/octet-stream")}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})} form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json() body = r.json()
assert r.status_code == 400 assert r.status_code == 400
assert body["error"]["code"] == "EMPTY_UPLOAD" assert body["error"]["code"] == "EMPTY_UPLOAD"
def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str): def test_upload_rejects_arbitrary_labels_without_required_destination_role(http: requests.Session, api_base: str):
files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")} files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})} form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
@ -212,7 +294,7 @@ def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str)
def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str): def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")} files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"} form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json() body = r.json()
assert r.status_code == 400 assert r.status_code == 400
@ -228,7 +310,7 @@ def test_upload_requires_multipart(http: requests.Session, api_base: str):
def test_upload_missing_file_and_hash(http: requests.Session, api_base: str): def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
files = [ files = [
("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))), ("tags", (None, json.dumps(["models", "model_type:checkpoints", "unit-tests"]))),
("name", (None, "x.safetensors")), ("name", (None, "x.safetensors")),
] ]
r = http.post(api_base + "/api/assets", files=files, timeout=120) r = http.post(api_base + "/api/assets", files=files, timeout=120)
@ -237,17 +319,33 @@ def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
assert body["error"]["code"] == "MISSING_FILE" assert body["error"]["code"] == "MISSING_FILE"
def test_upload_models_unknown_category(http: requests.Session, api_base: str): def test_upload_models_unknown_model_type(http: requests.Session, api_base: str):
files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")} files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"} form = {"tags": json.dumps(["models", "model_type:no_such_category", "unit-tests"]), "name": "m.safetensors"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json() body = r.json()
assert r.status_code == 400 assert r.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY" assert body["error"]["code"] == "INVALID_BODY"
assert body["error"]["message"].startswith("unknown models category")
def test_upload_models_requires_category(http: requests.Session, api_base: str): @pytest.mark.parametrize("model_type", ["configs", "custom_nodes"])
def test_upload_models_rejects_non_model_registered_folder(
model_type: str, http: requests.Session, api_base: str
):
files = {"file": ("not-a-model.py", b"A" * 128, "application/octet-stream")}
form = {
"tags": json.dumps(["models", f"model_type:{model_type}", "unit-tests"]),
"name": "not-a-model.py",
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
def test_upload_models_requires_model_type(http: requests.Session, api_base: str):
files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")} files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})} form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
@ -256,13 +354,149 @@ def test_upload_models_requires_category(http: requests.Session, api_base: str):
assert body["error"]["code"] == "INVALID_BODY" assert body["error"]["code"] == "INVALID_BODY"
def test_upload_tags_traversal_guard(http: requests.Session, api_base: str): def test_upload_extra_tags_are_labels_not_path_components(http: requests.Session, api_base: str):
files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")} files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"} form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json() body = r.json()
assert r.status_code == 400 assert r.status_code == 201, body
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") assert ".." in body["tags"]
assert "zzz" in body["tags"]
assert "models" in body["tags"]
assert "model_type:checkpoints" in body["tags"]
def test_upload_subfolder_is_explicit_path_component(
http: requests.Session, api_base: str, comfy_tmp_base_dir: Path
):
files = {"file": ("nested.bin", b"nested" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(["input", "unit-tests", "foo"]),
"subfolder": "foo/bar",
"name": "nested.bin",
}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 201, body
stored_name = get_asset_filename(body["asset_hash"], ".bin")
assert (comfy_tmp_base_dir / "input" / "foo" / "bar" / stored_name).exists()
assert "foo" in body["tags"]
def test_upload_rejects_unsafe_subfolder(http: requests.Session, api_base: str):
files = {"file": ("escape.bin", b"escape" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(["input", "unit-tests"]),
"subfolder": "../escape",
"name": "escape.bin",
}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
def test_multipart_upload_accepts_system_looking_extra_labels(
http: requests.Session, api_base: str
):
files = {"file": ("relaxed-labels.bin", b"relaxed" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(
[
"input",
"unit-tests",
"model:true",
"models:foo",
"temporary",
"uploaded:true",
]
),
"name": "relaxed-labels.bin",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 201, body
assert "input" in body["tags"]
assert "model:true" in body["tags"]
assert "models:foo" in body["tags"]
assert "temporary" in body["tags"]
assert "uploaded:true" in body["tags"]
def test_multipart_upload_rejects_ambiguous_destination_roles(
http: requests.Session, api_base: str
):
files = {"file": ("ambiguous.bin", b"ambiguous" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(["input", "output", "unit-tests"]),
"name": "ambiguous.bin",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
def test_multipart_upload_rejects_multiple_model_types_for_models_destination(
http: requests.Session, api_base: str
):
files = {"file": ("ambiguous-model.safetensors", b"ambiguous-model" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"]
),
"name": "ambiguous-model.safetensors",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.parametrize(
("tags", "expected_root", "extension"),
[
(["input", "unit-tests", "upload-location-input"], "input", ".bin"),
(["output", "unit-tests", "upload-location-output"], "output", ".bin"),
(
["models", "model_type:checkpoints", "unit-tests", "upload-location-model"],
"models/checkpoints",
".safetensors",
),
],
)
def test_multipart_upload_role_selects_write_location(
http: requests.Session,
api_base: str,
comfy_tmp_base_dir: Path,
tags: list[str],
expected_root: str,
extension: str,
):
role = next(tag for tag in tags if tag in {"input", "models", "output"})
name = f"{role}-role-upload{extension}"
files = {"file": (name, f"{role}-role-bytes".encode() * 64, "application/octet-stream")}
form = {
"tags": json.dumps(tags),
"name": name,
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 201, body
stored_name = get_asset_filename(body["asset_hash"], extension)
expected_disk_path = comfy_tmp_base_dir / expected_root / stored_name
assert expected_disk_path.exists()
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str): def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):

View File

@ -29,6 +29,8 @@ class TestFeatureFlags:
features = get_server_features() features = get_server_features()
assert "supports_preview_metadata" in features assert "supports_preview_metadata" in features
assert features["supports_preview_metadata"] is True assert features["supports_preview_metadata"] is True
assert "supports_model_type_tags" in features
assert features["supports_model_type_tags"] is True
assert "max_upload_size" in features assert "max_upload_size" in features
assert isinstance(features["max_upload_size"], (int, float)) assert isinstance(features["max_upload_size"], (int, float))

View File

@ -12,6 +12,8 @@ class TestWebSocketFeatureFlags:
# Check expected server features # Check expected server features
assert "supports_preview_metadata" in features assert "supports_preview_metadata" in features
assert features["supports_preview_metadata"] is True assert features["supports_preview_metadata"] is True
assert "supports_model_type_tags" in features
assert features["supports_model_type_tags"] is True
assert "max_upload_size" in features assert "max_upload_size" in features
assert isinstance(features["max_upload_size"], (int, float)) assert isinstance(features["max_upload_size"], (int, float))
@ -75,3 +77,5 @@ class TestWebSocketFeatureFlags:
assert server_message["type"] == "feature_flags" assert server_message["type"] == "feature_flags"
assert "supports_preview_metadata" in server_message["data"] assert "supports_preview_metadata" in server_message["data"]
assert server_message["data"]["supports_preview_metadata"] is True assert server_message["data"]["supports_preview_metadata"] is True
assert "supports_model_type_tags" in server_message["data"]
assert server_message["data"]["supports_model_type_tags"] is True