diff --git a/alembic_db/versions/0005_allow_case_sensitive_tags.py b/alembic_db/versions/0005_allow_case_sensitive_tags.py
new file mode 100644
index 000000000..8646b350a
--- /dev/null
+++ b/alembic_db/versions/0005_allow_case_sensitive_tags.py
@@ -0,0 +1,103 @@
+"""
+Allow case-sensitive tag names.
+
+Revision ID: 0005_allow_case_sensitive_tags
+Revises: 0004_drop_tag_type
+Create Date: 2026-06-16
+"""
+
+from alembic import op
+
+revision = "0005_allow_case_sensitive_tags"
+down_revision = "0004_drop_tag_type"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    """Drop the lowercase-only CHECK constraint on ``tags.name``."""
+    bind = op.get_bind()
+    if bind.dialect.name == "sqlite":
+        # SQLite cannot ALTER/DROP CHECK constraints. Recreate the small tag
+        # vocabulary table without the lowercase constraint while preserving
+        # existing tag names.
+        op.execute("PRAGMA foreign_keys=OFF")
+        op.execute(
+            "CREATE TABLE tags_new ("
+            "name VARCHAR(512) NOT NULL, "
+            "CONSTRAINT pk_tags PRIMARY KEY (name)"
+            ")"
+        )
+        op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
+        op.execute("DROP TABLE tags")
+        op.execute("ALTER TABLE tags_new RENAME TO tags")
+        op.execute("PRAGMA foreign_keys=ON")
+        return
+
+    op.drop_constraint("ck_tags_ck_tags_lowercase", "tags", type_="check")
+
+
+def downgrade() -> None:
+    """Restore the lowercase-only CHECK constraint on ``tags.name``.
+
+    Existing mixed-case tags cannot satisfy the old constraint. Lowercase them
+    before restoring it, merging duplicate vocabulary/link rows that collide.
+    """
+    bind = op.get_bind()
+    if bind.dialect.name == "sqlite":
+        # SQLite treats PRAGMA foreign_keys as a no-op inside a transaction and
+        # the DML below may open one, so switch enforcement off first. If it
+        # stayed on, DROP TABLE tags would run an implicit DELETE against rows
+        # that asset_reference_tags presumably still references.
+        op.execute("PRAGMA foreign_keys=OFF")
+        op.execute("INSERT OR IGNORE INTO tags(name) SELECT lower(name) FROM tags")
+        op.execute(
+            "DELETE FROM asset_reference_tags "
+            "WHERE rowid NOT IN ("
+            " SELECT MIN(rowid) FROM asset_reference_tags "
+            " GROUP BY asset_reference_id, lower(tag_name)"
+            ")"
+        )
+        op.execute("UPDATE asset_reference_tags SET tag_name = lower(tag_name)")
+        op.execute("DELETE FROM tags WHERE name != lower(name)")
+        op.execute(
+            "CREATE TABLE tags_new ("
+            "name VARCHAR(512) NOT NULL, "
+            "CONSTRAINT pk_tags PRIMARY KEY (name), "
+            "CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))"
+            ")"
+        )
+        op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
+        op.execute("DROP TABLE tags")
+        op.execute("ALTER TABLE tags_new RENAME TO tags")
+        op.execute("PRAGMA foreign_keys=ON")
+        return
+
+    # Other dialects: INSERT OR IGNORE and rowid are SQLite-only, so use
+    # portable statements for the same merge.
+    op.execute(
+        "INSERT INTO tags(name) "
+        "SELECT DISTINCT lower(t.name) FROM tags AS t "
+        "WHERE t.name <> lower(t.name) "
+        "AND NOT EXISTS (SELECT 1 FROM tags AS x WHERE x.name = lower(t.name))"
+    )
+    # Keep one link row per (asset_reference_id, lower(tag_name)): an already
+    # lowercase row wins, otherwise the smallest tag_name survives.
+    # NOTE(review): assumes (asset_reference_id, tag_name) is unique - confirm.
+    op.execute(
+        "DELETE FROM asset_reference_tags "
+        "WHERE tag_name <> lower(tag_name) "
+        "AND EXISTS ("
+        " SELECT 1 FROM asset_reference_tags AS o "
+        " WHERE o.asset_reference_id = asset_reference_tags.asset_reference_id "
+        " AND lower(o.tag_name) = lower(asset_reference_tags.tag_name) "
+        " AND o.tag_name <> asset_reference_tags.tag_name "
+        " AND (o.tag_name = lower(o.tag_name) "
+        " OR o.tag_name < asset_reference_tags.tag_name)"
+        ")"
+    )
+    op.execute("UPDATE asset_reference_tags SET tag_name = lower(tag_name)")
+    op.execute("DELETE FROM tags WHERE name <> lower(name)")
+    op.create_check_constraint(
+        "ck_tags_ck_tags_lowercase", "tags", "name = lower(name)"
+    )
diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 7ef462f5c..bd53552d4 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -10,7 +10,6 @@ from typing import Any from aiohttp import web from pydantic import ValidationError -import folder_paths from app import user_manager from app.assets.api import schemas_in, schemas_out from app.assets.services import schemas @@ -40,6 +39,7 @@ from app.assets.services import ( upload_from_temp_path, ) from app.assets.services.cursor import InvalidCursorError +from app.assets.services.path_utils import compute_api_file_path from app.assets.services.tagging import list_tag_histogram ROUTES = web.RouteTableDef() @@ -169,6 +169,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu asset_hash=asset_content_hash, size=int(result.asset.size_bytes) if result.asset else None, mime_type=result.asset.mime_type if result.asset else None, + file_path=compute_api_file_path(result.ref.file_path), 
tags=result.tags, preview_url=preview_url, preview_id=result.ref.preview_id, @@ -416,17 +417,6 @@ async def upload_asset(request: web.Request) -> web.Response: 400, "INVALID_BODY", f"Validation failed: {ve.json()}" ) - if spec.tags and spec.tags[0] == "models": - if ( - len(spec.tags) < 2 - or spec.tags[1] not in folder_paths.folder_names_and_paths - ): - delete_temp_file_if_exists(parsed.tmp_path) - category = spec.tags[1] if len(spec.tags) >= 2 else "" - return _build_error_response( - 400, "INVALID_BODY", f"unknown models category '{category}'" - ) - try: # Fast path: hash exists, create AssetReference without writing anything if spec.hash and parsed.provided_hash_exists is True: @@ -470,7 +460,7 @@ async def upload_asset(request: web.Request) -> web.Response: return _build_error_response(400, e.code, str(e)) except ValueError as e: delete_temp_file_if_exists(parsed.tmp_path) - return _build_error_response(400, "BAD_REQUEST", str(e)) + return _build_error_response(400, "INVALID_BODY", str(e)) except HashMismatchError as e: delete_temp_file_if_exists(parsed.tmp_path) return _build_error_response(400, "HASH_MISMATCH", str(e)) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index af666746d..33565cad6 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -140,7 +140,7 @@ class CreateFromHashBody(BaseModel): if v is None: return [] if isinstance(v, list): - out = [str(t).strip().lower() for t in v if str(t).strip()] + out = [str(t).strip() for t in v if str(t).strip()] seen = set() dedup = [] for t in out: @@ -149,7 +149,7 @@ class CreateFromHashBody(BaseModel): dedup.append(t) return dedup if isinstance(v, str): - return [t.strip().lower() for t in v.split(",") if t.strip()] + return list(dict.fromkeys(t.strip() for t in v.split(",") if t.strip())) return [] @@ -206,7 +206,7 @@ class TagsListQuery(BaseModel): if v is None: return v v = v.strip() - return v.lower() or None + return v or None class TagsAdd(BaseModel): 
@@ -220,7 +220,7 @@ class TagsAdd(BaseModel): for t in v: if not isinstance(t, str): raise TypeError("tags must be strings") - tnorm = t.strip().lower() + tnorm = t.strip() if tnorm: out.append(tnorm) seen = set() @@ -309,7 +309,7 @@ class UploadAssetSpec(BaseModel): norm = [] seen = set() for t in items: - tnorm = str(t).strip().lower() + tnorm = str(t).strip() if tnorm and tnorm not in seen: seen.add(tnorm) norm.append(tnorm) @@ -335,14 +335,4 @@ class UploadAssetSpec(BaseModel): @model_validator(mode="after") def _validate_order(self): - if not self.tags: - raise ValueError("at least one tag is required for uploads") - root = self.tags[0] - if root not in {"models", "input", "output"}: - raise ValueError("first tag must be one of: models, input, output") - if root == "models": - if len(self.tags) < 2: - raise ValueError( - "models uploads require a category tag as the second tag" - ) return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index 4e38e19d1..4214aeb0e 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -14,6 +14,7 @@ class Asset(BaseModel): asset_hash: str | None = None size: int | None = None mime_type: str | None = None + file_path: str | None = None tags: list[str] = Field(default_factory=list) preview_url: str | None = None preview_id: str | None = None # references an asset_reference id, not an asset id diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index d41d73a10..6e041d637 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -294,7 +294,7 @@ def list_tags_with_usage( ) if prefix: - escaped, esc = escape_sql_like_string(prefix.strip().lower()) + escaped, esc = escape_sql_like_string(prefix.strip()) q = q.where(Tag.name.like(escaped + "%", escape=esc)) if not include_zero: @@ -307,7 +307,7 @@ def list_tags_with_usage( total_q = select(func.count()).select_from(Tag) if prefix: - escaped, esc = 
escape_sql_like_string(prefix.strip().lower()) + escaped, esc = escape_sql_like_string(prefix.strip()) total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) if not include_zero: visible_tags_sq = ( diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 3798f3933..87734d0dc 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -41,10 +41,10 @@ def get_utc_now() -> datetime: def normalize_tags(tags: list[str] | None) -> list[str]: """ Normalize a list of tags by: - - Stripping whitespace and converting to lowercase. - - Removing duplicates. + - Stripping whitespace. + - Removing exact duplicates while preserving order and case. """ - return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) + return list(dict.fromkeys(t.strip() for t in (tags or []) if (t or "").strip())) def validate_blake3_hash(s: str) -> str: diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index 3b6dc237c..8b2021a61 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -34,6 +34,7 @@ from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.path_utils import ( compute_relative_filename, + get_backend_system_tags_from_path, get_name_and_tags_from_asset_path, resolve_destination_from_tags, validate_path_within_base, @@ -101,7 +102,11 @@ def _ingest_file_from_path( if preview_id and ref.preview_id != preview_id: ref.preview_id = preview_id - norm = normalize_tags(list(tags)) + try: + backend_tags = get_backend_system_tags_from_path(locator) + except ValueError: + backend_tags = [] + norm = normalize_tags([*list(tags), *backend_tags]) if norm: if require_existing_tags: validate_tags_exist(session, norm) @@ -474,6 +479,10 @@ def upload_from_temp_path( existing = get_asset_by_hash(session, asset_hash=asset_hash) if existing is not None: + # Once content is already known, duplicate 
byte uploads are treated as + # reference-only creation. Request tags are labels only here: do not + # require upload destination tags, do not move bytes, and do not + # synthesize path-derived classification or uploaded provenance. with contextlib.suppress(Exception): if temp_path and os.path.exists(temp_path): os.remove(temp_path) @@ -535,7 +544,7 @@ def upload_from_temp_path( owner_id=owner_id, preview_id=preview_id, user_metadata=user_metadata or {}, - tags=tags, + tags=[*(tags or []), "uploaded"], tag_origin="manual", require_existing_tags=False, ) @@ -569,15 +578,19 @@ def register_file_in_place( ) -> UploadResult: """Register an already-saved file in the asset database without moving it. - Tags are derived from the filesystem path (root category + subfolder names), - merged with any caller-provided tags, matching the behavior of the scanner. + This helper is used by upload paths that have already written bytes before + registering the file, so it records the same ``uploaded`` tag as the + multipart byte-upload path. + + Tags are derived from trusted filesystem classification and merged with any + caller-provided tags, matching the behavior of the scanner. If the path is not under a known root, only the caller-provided tags are used. 
""" try: _, path_tags = get_name_and_tags_from_asset_path(abs_path) except ValueError: path_tags = [] - merged_tags = normalize_tags([*path_tags, *tags]) + merged_tags = normalize_tags([*path_tags, *tags, "uploaded"]) try: digest, _ = hashing.compute_blake3_hash(abs_path) diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py index 892140ffb..0e4656fe7 100644 --- a/app/assets/services/path_utils.py +++ b/app/assets/services/path_utils.py @@ -3,10 +3,9 @@ from pathlib import Path from typing import Literal import folder_paths -from app.assets.helpers import normalize_tags -_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) +_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"}) def get_comfy_models_folders() -> list[tuple[str, list[str]]]: @@ -14,7 +13,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: Includes every category registered in folder_names_and_paths, regardless of whether its paths are under the main models_dir, - but excludes non-model entries like custom_nodes. + but excludes non-model entries like configs and custom_nodes. """ targets: list[tuple[str, list[str]]] = [] for name, values in folder_paths.folder_names_and_paths.items(): @@ -27,35 +26,37 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: - """Validates and maps tags -> (base_dir, subdirs_for_fs)""" - if not tags: - raise ValueError("tags must not be empty") - root = tags[0].lower() + """Validates and maps upload routing tags -> (base_dir, subdirs_for_fs). + + The request tags are only used to choose the write destination. Extra tags + remain labels; they do not become path components or trusted classification. 
+ """ + destination_roles = [t for t in tags if t in {"input", "models", "output"}] + if len(destination_roles) != 1: + raise ValueError("uploads require exactly one destination role: input, models, or output") + + root = destination_roles[0] if root == "models": - if len(tags) < 2: - raise ValueError("at least two tags required for model asset") + model_type_tags = [t for t in tags if t.startswith("model_type:")] + if len(model_type_tags) != 1: + raise ValueError("models uploads require exactly one model_type: tag") + folder_name = model_type_tags[0].split(":", 1)[1] + if not folder_name: + raise ValueError("models uploads require exactly one model_type: tag") + model_folder_paths = dict(get_comfy_models_folders()) try: - bases = folder_paths.folder_names_and_paths[tags[1]][0] + bases = model_folder_paths[folder_name] except KeyError: - raise ValueError(f"unknown model category '{tags[1]}'") + raise ValueError(f"unknown model category '{folder_name}'") if not bases: - raise ValueError(f"no base path configured for category '{tags[1]}'") + raise ValueError(f"no base path configured for category '{folder_name}'") base_dir = os.path.abspath(bases[0]) - raw_subdirs = tags[2:] elif root == "input": base_dir = os.path.abspath(folder_paths.get_input_directory()) - raw_subdirs = tags[1:] - elif root == "output": - base_dir = os.path.abspath(folder_paths.get_output_directory()) - raw_subdirs = tags[1:] else: - raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") - _sep_chars = frozenset(("/", "\\", os.sep)) - for i in raw_subdirs: - if i in (".", "..") or _sep_chars & set(i): - raise ValueError("invalid path component in tags") + base_dir = os.path.abspath(folder_paths.get_output_directory()) - return base_dir, raw_subdirs if raw_subdirs else [] + return base_dir, [] def validate_path_within_base(candidate: str, base: str) -> None: @@ -91,6 +92,25 @@ def compute_relative_filename(file_path: str) -> str | None: return "/".join(parts) # 
input/output: keep all parts +def compute_api_file_path(file_path: str | None) -> str | None: + """Return a stable API-visible path relative to a known asset root. + + Examples: + /.../input/foo.png -> "input/foo.png" + /.../models/checkpoints/foo.safetensors -> "models/checkpoints/foo.safetensors" + + Returns None for references without a filesystem path or paths outside + known asset roots. + """ + if not file_path: + return None + try: + root_category, rel_path = get_asset_category_and_relative_path(file_path) + except ValueError: + return None + return "/".join([root_category, *Path(rel_path).parts]) + + def get_asset_category_and_relative_path( file_path: str, ) -> tuple[Literal["input", "output", "temp", "models"], str]: @@ -156,18 +176,55 @@ def get_asset_category_and_relative_path( ) +def get_backend_system_tags_from_path(path: str) -> list[str]: + """Return trusted backend tags derived from current filesystem facts. + + The returned tags are only the backend-generated system tags: ``models``, + ``model_type:<folder_name>``, ``input``, ``output``, and ``temp``. Model + type tags are based on registered folder names, not path components. 
+ """ + fp_abs = os.path.abspath(path) + fp_path = Path(fp_abs) + tags: list[str] = [] + + def _add(tag: str) -> None: + if tag not in tags: + tags.append(tag) + + for role, base in ( + ("input", folder_paths.get_input_directory()), + ("output", folder_paths.get_output_directory()), + ("temp", folder_paths.get_temp_directory()), + ): + if fp_path.is_relative_to(os.path.abspath(base)): + _add(role) + + model_types: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + for base in bases: + if fp_path.is_relative_to(os.path.abspath(base)): + model_types.append(folder_name) + break + + if model_types: + _add("models") + for folder_name in model_types: + _add(f"model_type:{folder_name}") + + if not tags: + raise ValueError( + f"Path is not within input, output, temp, or configured model bases: {path}" + ) + return tags + + def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: """Return (name, tags) derived from a filesystem path. - name: base filename with extension - - tags: [root_category] + parent folder names in order + - tags: trusted backend classification tags derived from the path Raises: ValueError: path does not belong to any known root. 
""" - root_category, some_path = get_asset_category_and_relative_path(file_path) - p = Path(some_path) - parent_parts = [ - part for part in p.parent.parts if part not in (".", "..", p.anchor) - ] - return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) + return Path(file_path).name, get_backend_system_tags_from_path(file_path) diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py index 4aa20372f..44416e8c5 100644 --- a/tests-unit/assets_test/conftest.py +++ b/tests-unit/assets_test/conftest.py @@ -234,7 +234,7 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas p = getattr(request, "param", {}) or {} tags: Optional[list[str]] = p.get("tags") if tags is None: - tags = ["models", "checkpoints", "unit-tests", "alpha"] + tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"] meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} # Unique content per test so the seed always creates a fresh asset (201). 
# Delete is now always a soft delete, so content from a prior test survives diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index fe510e342..74dfb8a37 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -133,6 +133,66 @@ class TestListReferencesPage: assert total == 1 assert refs[0].name == "tagged" + def test_include_tags_filter_ands_persisted_model_tags(self, session: Session): + asset = _make_asset(session, "hash-model-tags") + checkpoint = _make_reference(session, asset, name="checkpoint") + lora = _make_reference(session, asset, name="lora") + input_ref = _make_reference(session, asset, name="input") + ensure_tags_exist( + session, + ["models", "model_type:checkpoints", "model_type:loras", "unit-tests"], + ) + add_tags_to_reference( + session, + reference_id=checkpoint.id, + tags=["models", "model_type:checkpoints", "unit-tests"], + origin="automatic", + ) + add_tags_to_reference( + session, + reference_id=lora.id, + tags=["models", "model_type:loras", "unit-tests"], + origin="automatic", + ) + add_tags_to_reference( + session, + reference_id=input_ref.id, + tags=["unit-tests"], + ) + session.commit() + + refs, _, total = list_references_page( + session, + include_tags=["models", "model_type:checkpoints", "unit-tests"], + ) + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_include_tags_filter_preserves_model_type_case(self, session: Session): + asset = _make_asset(session, "hash-model-case") + ref = _make_reference(session, asset, name="llm") + ensure_tags_exist(session, ["models", "model_type:LLM"]) + add_tags_to_reference( + session, + reference_id=ref.id, + tags=["models", "model_type:LLM"], + origin="automatic", + ) + session.commit() + + refs, _, total = list_references_page( + session, include_tags=["models", "model_type:LLM"] + ) + refs_lower, _, total_lower = list_references_page( + session, 
include_tags=["models", "model_type:llm"] + ) + + assert total == 1 + assert refs[0].id == ref.id + assert total_lower == 0 + assert refs_lower == [] + def test_exclude_tags_filter(self, session: Session): asset = _make_asset(session, "hash1") _make_reference(session, asset, name="keep") diff --git a/tests-unit/assets_test/services/test_path_utils.py b/tests-unit/assets_test/services/test_path_utils.py index 3fa905f9a..5c4871c1f 100644 --- a/tests-unit/assets_test/services/test_path_utils.py +++ b/tests-unit/assets_test/services/test_path_utils.py @@ -6,7 +6,11 @@ from unittest.mock import patch import pytest -from app.assets.services.path_utils import get_asset_category_and_relative_path +from app.assets.services.path_utils import ( + get_asset_category_and_relative_path, + get_name_and_tags_from_asset_path, + resolve_destination_from_tags, +) @pytest.fixture @@ -76,6 +80,121 @@ class TestGetAssetCategoryAndRelativePath: cat, rel = get_asset_category_and_relative_path(str(f)) assert cat == "models" + def test_model_path_tags_include_registered_model_type_only(self, fake_dirs): + f = fake_dirs["models"] / "subdir" / "model.safetensors" + f.parent.mkdir() + f.touch() + + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:checkpoints" in tags + assert "checkpoints" not in tags + assert "subdir" not in tags + + def test_model_type_preserves_registered_folder_case(self, fake_dirs): + llm_dir = fake_dirs["models"].parent / "LLM" + llm_dir.mkdir() + f = llm_dir / "model.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("LLM", [str(llm_dir)])], + ): + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:LLM" in tags + assert "model_type:llm" not in tags + + def test_path_components_do_not_create_model_type_tags(self, fake_dirs): + f = fake_dirs["models"] / "loras" / "model.safetensors" + 
f.parent.mkdir() + f.touch() + + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:checkpoints" in tags + assert "loras" not in tags + assert "model_type:loras" not in tags + + def test_shared_root_returns_all_matching_model_type_tags(self, fake_dirs): + shared_root = fake_dirs["models"].parent / "shared" + shared_root.mkdir() + f = shared_root / "foo.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[ + ("checkpoints", [str(shared_root)]), + ("loras", [str(shared_root)]), + ], + ): + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:checkpoints" in tags + assert "model_type:loras" in tags + + def test_output_backed_registered_folder_gets_model_and_output_tags(self, fake_dirs): + output_checkpoints_dir = fake_dirs["output"] / "checkpoints" + output_checkpoints_dir.mkdir() + f = output_checkpoints_dir / "saved.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("checkpoints", [str(output_checkpoints_dir)])], + ): + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:checkpoints" in tags + assert "output" in tags + + def test_temp_path_tags_include_temp_not_output_or_preview(self, fake_dirs): + f = fake_dirs["temp"] / "preview.png" + f.touch() + + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "temp" in tags + assert "output" not in tags + assert "preview:true" not in tags + def test_unknown_path_raises(self, fake_dirs): with pytest.raises(ValueError, match="not within"): get_asset_category_and_relative_path("/some/random/path.png") + + +class TestResolveDestinationFromTags: + def test_model_upload_rejects_non_writable_registered_folders(self): + with tempfile.TemporaryDirectory() as root: + root_path = Path(root) + checkpoints_dir = root_path / "models" / 
"checkpoints" + configs_dir = root_path / "models" / "configs" + custom_nodes_dir = root_path / "custom_nodes" + for path in (checkpoints_dir, configs_dir, custom_nodes_dir): + path.mkdir(parents=True) + + with patch("app.assets.services.path_utils.folder_paths") as mock_fp: + mock_fp.folder_names_and_paths = { + "checkpoints": ([str(checkpoints_dir)], set()), + "configs": ([str(configs_dir)], set()), + "custom_nodes": ([str(custom_nodes_dir)], set()), + } + + base_dir, subdirs = resolve_destination_from_tags( + ["models", "model_type:checkpoints"] + ) + assert base_dir == os.path.abspath(checkpoints_dir) + assert subdirs == [] + + for folder_name in ("configs", "custom_nodes"): + with pytest.raises(ValueError, match="unknown model category"): + resolve_destination_from_tags( + ["models", f"model_type:{folder_name}"] + ) diff --git a/tests-unit/assets_test/test_tags_api.py b/tests-unit/assets_test/test_tags_api.py index 9729b7d03..93786696f 100644 --- a/tests-unit/assets_test/test_tags_api.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -10,9 +10,9 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict) body1 = r1.json() assert r1.status_code == 200 names = [t["name"] for t in body1["tags"]] - # A few system tags from migration should exist: + # A few selected contract tags should exist. 
assert "models" in names - assert "checkpoints" in names + assert "model_type:checkpoints" in names # Only used tags before we add anything new from this test cycle r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120) @@ -21,7 +21,7 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict) # We already seeded one asset via fixture, so used tags must be non-empty used_names = [t["name"] for t in body2["tags"]] assert "models" in used_names - assert "checkpoints" in used_names + assert "model_type:checkpoints" in used_names # Prefix filter should refine the list r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120) @@ -45,7 +45,7 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, body1 = r1.json() assert r1.status_code == 200 names = [t["name"] for t in body1["tags"]] - assert "models" in names and "checkpoints" in names + assert "models" in names and "model_type:checkpoints" in names # Create a short-lived asset under input with a unique custom tag scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}" @@ -89,28 +89,28 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict): aid = seeded_asset["id"] - # Add tags with duplicates and mixed case - payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]} + # Add tags with duplicates while preserving source case. 
+ payload_add = {"tags": ["NewTag", "unit-tests", "NewTag", "BETA"]} r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120) b1 = r1.json() assert r1.status_code == 200, b1 - # normalized, deduplicated; 'unit-tests' was already present from the seed - assert set(b1["added"]) == {"newtag", "beta"} + # stripped, deduplicated; 'unit-tests' was already present from the seed + assert set(b1["added"]) == {"NewTag", "BETA"} assert set(b1["already_present"]) == {"unit-tests"} - assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] + assert "NewTag" in b1["total_tags"] and "BETA" in b1["total_tags"] rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) g = rg.json() assert rg.status_code == 200 tags_now = set(g["tags"]) - assert {"newtag", "beta"}.issubset(tags_now) + assert {"NewTag", "BETA"}.issubset(tags_now) # Remove a tag and a non-existent tag - payload_del = {"tags": ["newtag", "does-not-exist"]} + payload_del = {"tags": ["NewTag", "does-not-exist"]} r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120) b2 = r2.json() assert r2.status_code == 200 - assert set(b2["removed"]) == {"newtag"} + assert set(b2["removed"]) == {"NewTag"} assert set(b2["not_present"]) == {"does-not-exist"} # Verify remaining tags after deletion @@ -118,8 +118,44 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset g2 = rg2.json() assert rg2.status_code == 200 tags_later = set(g2["tags"]) - assert "newtag" not in tags_later - assert "beta" in tags_later # still present + assert "NewTag" not in tags_later + assert "BETA" in tags_later # still present + + +def test_add_system_looking_tags_allowed_as_labels( + http: requests.Session, api_base: str, seeded_asset: dict +): + aid = seeded_asset["id"] + + response = http.post( + f"{api_base}/api/assets/{aid}/tags", + json={ + "tags": [ + "models", + "model_type:manual", + "model:true", + "models:foo", + "input:true", + "output:true", + 
"uploaded:true", + "temp:true", + "temporary", + ] + }, + timeout=120, + ) + body = response.json() + + assert response.status_code == 200, body + assert "models" in body["total_tags"] + assert "model_type:manual" in body["total_tags"] + assert "model:true" in body["total_tags"] + assert "models:foo" in body["total_tags"] + assert "input:true" in body["total_tags"] + assert "output:true" in body["total_tags"] + assert "uploaded:true" in body["total_tags"] + assert "temp:true" in body["total_tags"] + assert "temporary" in body["total_tags"] def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict): diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py index 427a417cc..123be298b 100644 --- a/tests-unit/assets_test/test_uploads.py +++ b/tests-unit/assets_test/test_uploads.py @@ -1,11 +1,13 @@ import json import uuid from concurrent.futures import ThreadPoolExecutor +from pathlib import Path import requests import pytest from app.assets.api.schemas_out import Asset, AssetCreated +from helpers import get_asset_filename def test_asset_created_inherits_hash_field(): @@ -22,7 +24,7 @@ def test_asset_created_inherits_hash_field(): def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes): name = "dup_a.safetensors" - tags = ["models", "checkpoints", "unit-tests", "alpha"] + tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"] meta = {"purpose": "dup"} data = make_asset_bytes(name) files = {"file": (name, data, "application/octet-stream")} @@ -58,7 +60,7 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): # Seed a small file first name = "fastpath_seed.safetensors" - tags = ["models", "checkpoints", "unit-tests"] + tags = ["input", "unit-tests"] meta = {} files = {"file": (name, b"B" * 1024, "application/octet-stream")} form = 
{"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} @@ -69,9 +71,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_ assert b1["hash"] == h # Now POST /api/assets with only hash and no file + hash_only_tags = ["models", "checkpoints", "unit-tests", "hash-labels"] files = [ ("hash", (None, h)), - ("tags", (None, json.dumps(tags))), + ("tags", (None, json.dumps(hash_only_tags))), ("name", (None, "fastpath_copy.safetensors")), ("user_metadata", (None, json.dumps({"purpose": "copy"}))), ] @@ -81,6 +84,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_ assert b2["created_new"] is False assert b2["asset_hash"] == h assert b2["hash"] == h + assert "models" in b2["tags"] + assert "checkpoints" in b2["tags"] + assert "uploaded" not in b2["tags"] + assert not any(tag.startswith("model_type:") for tag in b2["tags"]) def test_upload_fastpath_with_known_hash_and_file( @@ -88,7 +95,7 @@ def test_upload_fastpath_with_known_hash_and_file( ): # Seed files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")} - form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})} + form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})} r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) b1 = r1.json() assert r1.status_code == 201, b1 @@ -104,11 +111,47 @@ def test_upload_fastpath_with_known_hash_and_file( assert b2["created_new"] is False assert b2["asset_hash"] == h assert b2["hash"] == h + assert "checkpoints" in b2["tags"] + assert "uploaded" not in b2["tags"] + assert not any(tag == "model_type:checkpoints" for tag in b2["tags"]) + + +def test_duplicate_byte_upload_is_reference_only_and_does_not_need_destination( + http: requests.Session, api_base: str +): + data = 
b"duplicate-reference-only" * 64 + seed_files = {"file": ("duplicate-seed.bin", data, "application/octet-stream")} + seed_form = { + "tags": json.dumps(["input", "unit-tests", "duplicate-seed"]), + "name": "duplicate-seed.bin", + "user_metadata": json.dumps({}), + } + seed_response = http.post(api_base + "/api/assets", data=seed_form, files=seed_files, timeout=120) + seed = seed_response.json() + assert seed_response.status_code == 201, seed + + duplicate_files = {"file": ("duplicate-copy.bin", data, "application/octet-stream")} + duplicate_form = { + "tags": json.dumps(["not-a-destination", "unit-tests", "duplicate-copy"]), + "name": "duplicate-copy.bin", + "user_metadata": json.dumps({}), + } + duplicate_response = http.post( + api_base + "/api/assets", data=duplicate_form, files=duplicate_files, timeout=120 + ) + duplicate = duplicate_response.json() + + assert duplicate_response.status_code == 200, duplicate + assert duplicate["created_new"] is False + assert duplicate["asset_hash"] == seed["asset_hash"] + assert "not-a-destination" in duplicate["tags"] + assert "uploaded" not in duplicate["tags"] + assert "input" not in duplicate["tags"] def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str): data = [ - ("tags", "models,checkpoints"), + ("tags", "models,model_type:checkpoints"), ("tags", json.dumps(["unit-tests", "alpha"])), ("name", "merge.safetensors"), ("user_metadata", json.dumps({"u": 1})), @@ -124,7 +167,7 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base detail = rg.json() assert rg.status_code == 200, detail tags = set(detail["tags"]) - assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) + assert {"models", "model_type:checkpoints", "unit-tests", "alpha"}.issubset(tags) @pytest.mark.parametrize("root", ["input", "output"]) @@ -192,16 +235,55 @@ def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str): assert body["error"]["code"] == 
"ASSET_NOT_FOUND" +def test_create_from_hash_accepts_arbitrary_system_looking_tags( + http: requests.Session, api_base: str +): + files = {"file": ("hash-seed.bin", b"hash-seed" * 64, "application/octet-stream")} + form = { + "tags": json.dumps(["input", "unit-tests", "hash-seed"]), + "name": "hash-seed.bin", + "user_metadata": json.dumps({}), + } + seed_response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + seed = seed_response.json() + assert seed_response.status_code == 201, seed + + response = http.post( + api_base + "/api/assets/from-hash", + json={ + "hash": seed["asset_hash"], + "name": "hash-copy.bin", + "tags": [ + "models", + "model:true", + "models:foo", + "temporary:true", + "unit-tests", + "hash-copy", + ], + }, + timeout=120, + ) + body = response.json() + + assert response.status_code == 201, body + assert "models" in body["tags"] + assert "model:true" in body["tags"] + assert "models:foo" in body["tags"] + assert "temporary:true" in body["tags"] + assert "uploaded" not in body["tags"] + + def test_upload_zero_byte_rejected(http: requests.Session, api_base: str): files = {"file": ("empty.safetensors", b"", "application/octet-stream")} - form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})} + form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) body = r.json() assert r.status_code == 400 assert body["error"]["code"] == "EMPTY_UPLOAD" -def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str): +def test_upload_rejects_arbitrary_labels_without_required_destination_role(http: requests.Session, api_base: str): files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")} form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": 
"badroot.bin", "user_metadata": json.dumps({})} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) @@ -212,7 +294,7 @@ def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str) def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str): files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")} - form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"} + form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) body = r.json() assert r.status_code == 400 @@ -228,7 +310,7 @@ def test_upload_requires_multipart(http: requests.Session, api_base: str): def test_upload_missing_file_and_hash(http: requests.Session, api_base: str): files = [ - ("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))), + ("tags", (None, json.dumps(["models", "model_type:checkpoints", "unit-tests"]))), ("name", (None, "x.safetensors")), ] r = http.post(api_base + "/api/assets", files=files, timeout=120) @@ -237,17 +319,33 @@ def test_upload_missing_file_and_hash(http: requests.Session, api_base: str): assert body["error"]["code"] == "MISSING_FILE" -def test_upload_models_unknown_category(http: requests.Session, api_base: str): +def test_upload_models_unknown_model_type(http: requests.Session, api_base: str): files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")} - form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"} + form = {"tags": json.dumps(["models", "model_type:no_such_category", "unit-tests"]), "name": "m.safetensors"} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) body = r.json() - assert r.status_code == 400 + assert r.status_code == 400, body assert 
body["error"]["code"] == "INVALID_BODY" - assert body["error"]["message"].startswith("unknown models category") -def test_upload_models_requires_category(http: requests.Session, api_base: str): +@pytest.mark.parametrize("model_type", ["configs", "custom_nodes"]) +def test_upload_models_rejects_non_model_registered_folder( + model_type: str, http: requests.Session, api_base: str +): + files = {"file": ("not-a-model.py", b"A" * 128, "application/octet-stream")} + form = { + "tags": json.dumps(["models", f"model_type:{model_type}", "unit-tests"]), + "name": "not-a-model.py", + } + + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 400, body + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_models_requires_model_type(http: requests.Session, api_base: str): files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")} form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) @@ -256,13 +354,118 @@ def test_upload_models_requires_category(http: requests.Session, api_base: str): assert body["error"]["code"] == "INVALID_BODY" -def test_upload_tags_traversal_guard(http: requests.Session, api_base: str): +def test_upload_extra_tags_are_labels_not_path_components(http: requests.Session, api_base: str): files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")} - form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"} + form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) body = r.json() - assert r.status_code == 400 - assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") + assert r.status_code == 201, 
body + assert ".." in body["tags"] + assert "zzz" in body["tags"] + assert "models" in body["tags"] + assert "model_type:checkpoints" in body["tags"] + + +def test_multipart_upload_accepts_system_looking_extra_labels( + http: requests.Session, api_base: str +): + files = {"file": ("relaxed-labels.bin", b"relaxed" * 64, "application/octet-stream")} + form = { + "tags": json.dumps( + [ + "input", + "unit-tests", + "model:true", + "models:foo", + "temporary", + "uploaded:true", + ] + ), + "name": "relaxed-labels.bin", + "user_metadata": json.dumps({}), + } + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 201, body + assert "input" in body["tags"] + assert "model:true" in body["tags"] + assert "models:foo" in body["tags"] + assert "temporary" in body["tags"] + assert "uploaded:true" in body["tags"] + + +def test_multipart_upload_rejects_ambiguous_destination_roles( + http: requests.Session, api_base: str +): + files = {"file": ("ambiguous.bin", b"ambiguous" * 64, "application/octet-stream")} + form = { + "tags": json.dumps(["input", "output", "unit-tests"]), + "name": "ambiguous.bin", + "user_metadata": json.dumps({}), + } + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 400, body + assert body["error"]["code"] == "INVALID_BODY" + + +def test_multipart_upload_rejects_multiple_model_types_for_models_destination( + http: requests.Session, api_base: str +): + files = {"file": ("ambiguous-model.safetensors", b"ambiguous-model" * 64, "application/octet-stream")} + form = { + "tags": json.dumps( + ["models", "model_type:checkpoints", "model_type:loras", "unit-tests"] + ), + "name": "ambiguous-model.safetensors", + "user_metadata": json.dumps({}), + } + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert 
response.status_code == 400, body + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.parametrize( + ("tags", "expected_root", "extension"), + [ + (["input", "unit-tests", "upload-location-input"], "input", ".bin"), + (["output", "unit-tests", "upload-location-output"], "output", ".bin"), + ( + ["models", "model_type:checkpoints", "unit-tests", "upload-location-model"], + "models/checkpoints", + ".safetensors", + ), + ], +) +def test_multipart_upload_role_selects_write_location( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + tags: list[str], + expected_root: str, + extension: str, +): + role = next(tag for tag in tags if tag in {"input", "models", "output"}) + name = f"{role}-role-upload{extension}" + files = {"file": (name, f"{role}-role-bytes".encode() * 64, "application/octet-stream")} + form = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps({}), + } + + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 201, body + stored_name = get_asset_filename(body["asset_hash"], extension) + expected_disk_path = comfy_tmp_base_dir / expected_root / stored_name + assert expected_disk_path.exists() + assert body["file_path"] == f"{expected_root}/{stored_name}" def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):