diff --git a/alembic_db/versions/0005_allow_case_sensitive_tags.py b/alembic_db/versions/0005_allow_case_sensitive_tags.py new file mode 100644 index 000000000..bd5f864db --- /dev/null +++ b/alembic_db/versions/0005_allow_case_sensitive_tags.py @@ -0,0 +1,107 @@ +""" +Allow case-sensitive tag names. + +Revision ID: 0005_allow_case_sensitive_tags +Revises: 0004_drop_tag_type +Create Date: 2026-06-16 +""" + +import sqlalchemy as sa +from alembic import op + +revision = "0005_allow_case_sensitive_tags" +down_revision = "0004_drop_tag_type" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + bind = op.get_bind() + if bind.dialect.name == "sqlite": + # SQLite cannot ALTER/DROP CHECK constraints. Recreate the small tag + # vocabulary table without the lowercase constraint while preserving + # existing tag names. + op.execute("PRAGMA foreign_keys=OFF") + try: + op.execute( + "CREATE TABLE tags_new (" + "name VARCHAR(512) NOT NULL, " + "CONSTRAINT pk_tags PRIMARY KEY (name)" + ")" + ) + op.execute("INSERT INTO tags_new(name) SELECT name FROM tags") + op.execute("DROP TABLE tags") + op.execute("ALTER TABLE tags_new RENAME TO tags") + finally: + op.execute("PRAGMA foreign_keys=ON") + return + + op.drop_constraint("ck_tags_ck_tags_lowercase", "tags", type_="check") + + +def downgrade() -> None: + # Existing mixed-case tags cannot satisfy the old constraint. Lowercase them + # before restoring it, merging duplicate vocabulary/link rows that collide. + bind = op.get_bind() + + tag_names = [row[0] for row in bind.execute(sa.text("SELECT name FROM tags"))] + existing_names = set(tag_names) + lowercase_names = sorted({name.lower() for name in tag_names}) + missing_lowercase_rows = [ + {"name": name} for name in lowercase_names if name not in existing_names + ] + if missing_lowercase_rows: + bind.execute(sa.text("INSERT INTO tags(name) VALUES (:name)"), missing_lowercase_rows) + + link_rows = bind.execute( + sa.text( + "SELECT asset_reference_id, tag_name, origin, added_at " + "FROM asset_reference_tags " + "ORDER BY asset_reference_id, tag_name" + ) + ).mappings() + deduped_links = {} + for row in link_rows: + key = (row["asset_reference_id"], row["tag_name"].lower()) + deduped_links.setdefault( + key, + { + "asset_reference_id": row["asset_reference_id"], + "tag_name": row["tag_name"].lower(), + "origin": row["origin"], + "added_at": row["added_at"], + }, + ) + + op.execute("DELETE FROM asset_reference_tags") + if deduped_links: + bind.execute( + sa.text( + "INSERT INTO asset_reference_tags " + "(asset_reference_id, tag_name, origin, added_at) " + "VALUES (:asset_reference_id, :tag_name, :origin, :added_at)" + ), + list(deduped_links.values()), + ) + op.execute("DELETE FROM tags WHERE name != lower(name)") + + if bind.dialect.name == "sqlite": + op.execute("PRAGMA foreign_keys=OFF") + try: + op.execute( + "CREATE TABLE tags_new (" + "name VARCHAR(512) NOT NULL, " + "CONSTRAINT pk_tags PRIMARY KEY (name), " + "CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))" + ")" + ) + op.execute("INSERT INTO tags_new(name) SELECT name FROM tags") + op.execute("DROP TABLE tags") + op.execute("ALTER TABLE tags_new RENAME TO tags") + finally: + op.execute("PRAGMA foreign_keys=ON") + return + + op.create_check_constraint( + "ck_tags_ck_tags_lowercase", "tags", "name = lower(name)" + ) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 7ef462f5c..f40211f6c 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -10,7 +10,6 @@ from typing import Any from aiohttp import web from pydantic import ValidationError -import folder_paths from app import user_manager from app.assets.api import schemas_in, schemas_out from app.assets.services import schemas @@ -408,6 +407,7 @@ async def upload_asset(request: web.Request) -> web.Response: "hash": parsed.provided_hash, "mime_type": parsed.provided_mime_type, "preview_id": parsed.provided_preview_id, + "subfolder": parsed.provided_subfolder, } ) except ValidationError as ve: @@ -416,17 +416,6 @@ async def upload_asset(request: web.Request) -> web.Response: 400, "INVALID_BODY", f"Validation failed: {ve.json()}" ) - if spec.tags and spec.tags[0] == "models": - if ( - len(spec.tags) < 2 - or spec.tags[1] not in folder_paths.folder_names_and_paths - ): - delete_temp_file_if_exists(parsed.tmp_path) - category = spec.tags[1] if len(spec.tags) >= 2 else "" - return _build_error_response( - 400, "INVALID_BODY", f"unknown models category '{category}'" - ) - try: # Fast path: hash exists, create AssetReference without writing anything if spec.hash and parsed.provided_hash_exists is True: @@ -464,13 +453,14 @@ async def upload_asset(request: web.Request) -> web.Response: expected_hash=spec.hash, mime_type=spec.mime_type, preview_id=spec.preview_id, + subfolder=spec.subfolder, ) except AssetValidationError as e: delete_temp_file_if_exists(parsed.tmp_path) return _build_error_response(400, e.code, str(e)) except ValueError as e: delete_temp_file_if_exists(parsed.tmp_path) - return _build_error_response(400, "BAD_REQUEST", str(e)) + return _build_error_response(400, "INVALID_BODY", str(e)) except HashMismatchError as e: delete_temp_file_if_exists(parsed.tmp_path) return _build_error_response(400, "HASH_MISMATCH", str(e)) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index af666746d..e588fd63a 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -47,6 +47,7 @@ class ParsedUpload: provided_hash_exists: bool | None provided_mime_type: str | None = None provided_preview_id: str | None = None + provided_subfolder: str | None = None class ListAssetsQuery(BaseModel): @@ -140,7 +141,7 @@ class CreateFromHashBody(BaseModel): if v is None: return [] if isinstance(v, list): - out = [str(t).strip().lower() for t in v if str(t).strip()] + out = [str(t).strip() for t in v if str(t).strip()] seen = set() dedup = [] for t in out: @@ -149,7 +150,7 @@ class CreateFromHashBody(BaseModel): dedup.append(t) return dedup if isinstance(v, str): - return [t.strip().lower() for t in v.split(",") if t.strip()] + return list(dict.fromkeys(t.strip() for t in v.split(",") if t.strip())) return [] @@ -206,7 +207,7 @@ class TagsListQuery(BaseModel): if v is None: return v v = v.strip() - return v.lower() or None + return v or None class TagsAdd(BaseModel): @@ -220,7 +221,7 @@ class TagsAdd(BaseModel): for t in v: if not isinstance(t, str): raise TypeError("tags must be strings") - tnorm = t.strip().lower() + tnorm = t.strip() if tnorm: out.append(tnorm) seen = set() @@ -239,8 +240,9 @@ class TagsRemove(TagsAdd): class UploadAssetSpec(BaseModel): """Upload Asset operation. - - tags: optional list; if provided, first is root ('models'|'input'|'output'); - if root == 'models', second must be a valid category + - tags: labels plus one destination role ('models'|'input'|'output') for new bytes; + if role == 'models', exactly one model_type: tag is required + - subfolder: optional destination subfolder for new bytes - name: display name - user_metadata: arbitrary JSON object (optional) - hash: optional canonical 'blake3:' for validation / fast-path @@ -258,6 +260,7 @@ class UploadAssetSpec(BaseModel): hash: str | None = Field(default=None) mime_type: str | None = Field(default=None) preview_id: str | None = Field(default=None) # references an asset_reference id + subfolder: str | None = Field(default=None, max_length=1024) @field_validator("hash", mode="before") @classmethod @@ -309,12 +312,20 @@ class UploadAssetSpec(BaseModel): norm = [] seen = set() for t in items: - tnorm = str(t).strip().lower() + tnorm = str(t).strip() if tnorm and tnorm not in seen: seen.add(tnorm) norm.append(tnorm) return norm + @field_validator("subfolder", mode="before") + @classmethod + def _parse_subfolder(cls, v): + if v is None: + return None + s = str(v).strip() + return s or None + @field_validator("user_metadata", mode="before") @classmethod def _parse_metadata_json(cls, v): @@ -335,14 +346,4 @@ class UploadAssetSpec(BaseModel): @model_validator(mode="after") def _validate_order(self): - if not self.tags: - raise ValueError("at least one tag is required for uploads") - root = self.tags[0] - if root not in {"models", "input", "output"}: - raise ValueError("first tag must be one of: models, input, output") - if root == "models": - if len(self.tags) < 2: - raise ValueError( - "models uploads require a category tag as the second tag" - ) return self diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py index 13d3d372c..ce4a807a6 100644 --- a/app/assets/api/upload.py +++ b/app/assets/api/upload.py @@ -54,6 +54,7 @@ async def parse_multipart_upload( provided_hash_exists: bool | None = None provided_mime_type: str | None = None provided_preview_id: str | None = None + provided_subfolder: str | None = None file_written = 0 tmp_path: str | None = None @@ -140,6 +141,8 @@ async def parse_multipart_upload( provided_mime_type = ((await field.text()) or "").strip() or None elif fname == "preview_id": provided_preview_id = ((await field.text()) or "").strip() or None + elif fname == "subfolder": + provided_subfolder = ((await field.text()) or "").strip() or None if not file_present and not (provided_hash and provided_hash_exists): raise UploadError( @@ -166,6 +169,7 @@ async def parse_multipart_upload( provided_hash_exists=provided_hash_exists, provided_mime_type=provided_mime_type, provided_preview_id=provided_preview_id, + provided_subfolder=provided_subfolder, ) diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index d41d73a10..148f34801 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -265,6 +265,8 @@ def list_tags_with_usage( order: str = "count_desc", owner_id: str = "", ) -> tuple[list[tuple[str, str, int]], int]: + prefix_filter = prefix.strip() if prefix else "" + counts_sq = ( select( AssetReferenceTag.tag_name.label("tag_name"), @@ -293,9 +295,8 @@ def list_tags_with_usage( .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) ) - if prefix: - escaped, esc = escape_sql_like_string(prefix.strip().lower()) - q = q.where(Tag.name.like(escaped + "%", escape=esc)) + if prefix_filter: + q = q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter) if not include_zero: q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) @@ -306,9 +307,8 @@ def list_tags_with_usage( q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) total_q = select(func.count()).select_from(Tag) - if prefix: - escaped, esc = escape_sql_like_string(prefix.strip().lower()) - total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if prefix_filter: + total_q = total_q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter) if not include_zero: visible_tags_sq = ( select(AssetReferenceTag.tag_name) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 3798f3933..87734d0dc 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -41,10 +41,10 @@ def get_utc_now() -> datetime: def normalize_tags(tags: list[str] | None) -> list[str]: """ Normalize a list of tags by: - - Stripping whitespace and converting to lowercase. - - Removing duplicates. + - Stripping whitespace. + - Removing exact duplicates while preserving order and case. """ - return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) + return list(dict.fromkeys(t.strip() for t in (tags or []) if (t or "").strip())) def validate_blake3_hash(s: str) -> str: diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index 3b6dc237c..28aac33d5 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -34,6 +34,7 @@ from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.path_utils import ( compute_relative_filename, + get_backend_system_tags_from_path, get_name_and_tags_from_asset_path, resolve_destination_from_tags, validate_path_within_base, @@ -101,7 +102,11 @@ def _ingest_file_from_path( if preview_id and ref.preview_id != preview_id: ref.preview_id = preview_id - norm = normalize_tags(list(tags)) + try: + backend_tags = get_backend_system_tags_from_path(locator) + except ValueError: + backend_tags = [] + norm = normalize_tags([*list(tags), *backend_tags]) if norm: if require_existing_tags: validate_tags_exist(session, norm) @@ -458,6 +463,7 @@ def upload_from_temp_path( expected_hash: str | None = None, mime_type: str | None = None, preview_id: str | None = None, + subfolder: str | None = None, ) -> UploadResult: try: digest, _ = hashing.compute_blake3_hash(temp_path) @@ -474,6 +480,10 @@ def upload_from_temp_path( existing = get_asset_by_hash(session, asset_hash=asset_hash) if existing is not None: + # Once content is already known, duplicate byte uploads are treated as + # reference-only creation. Request tags are labels only here: do not + # require upload destination tags, do not move bytes, and do not + # synthesize path-derived classification or uploaded provenance. with contextlib.suppress(Exception): if temp_path and os.path.exists(temp_path): os.remove(temp_path) @@ -498,7 +508,7 @@ def upload_from_temp_path( if not tags: raise ValueError("tags are required for new asset uploads") - base_dir, subdirs = resolve_destination_from_tags(tags) + base_dir, subdirs = resolve_destination_from_tags(tags, subfolder=subfolder) dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir os.makedirs(dest_dir, exist_ok=True) @@ -535,7 +545,7 @@ def upload_from_temp_path( owner_id=owner_id, preview_id=preview_id, user_metadata=user_metadata or {}, - tags=tags, + tags=[*(tags or []), "uploaded"], tag_origin="manual", require_existing_tags=False, ) @@ -569,15 +579,19 @@ def register_file_in_place( ) -> UploadResult: """Register an already-saved file in the asset database without moving it. - Tags are derived from the filesystem path (root category + subfolder names), - merged with any caller-provided tags, matching the behavior of the scanner. + This helper is used by upload paths that have already written bytes before + registering the file, so it records the same ``uploaded`` tag as the + multipart byte-upload path. + + Tags are derived from trusted filesystem classification and merged with any + caller-provided tags, matching the behavior of the scanner. If the path is not under a known root, only the caller-provided tags are used. """ try: _, path_tags = get_name_and_tags_from_asset_path(abs_path) except ValueError: path_tags = [] - merged_tags = normalize_tags([*path_tags, *tags]) + merged_tags = normalize_tags([*path_tags, *tags, "uploaded"]) try: digest, _ = hashing.compute_blake3_hash(abs_path) diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py index 892140ffb..3c64f0bef 100644 --- a/app/assets/services/path_utils.py +++ b/app/assets/services/path_utils.py @@ -1,12 +1,11 @@ import os -from pathlib import Path +from pathlib import Path, PureWindowsPath from typing import Literal import folder_paths -from app.assets.helpers import normalize_tags -_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) +_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"}) def get_comfy_models_folders() -> list[tuple[str, list[str]]]: @@ -14,7 +13,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: Includes every category registered in folder_names_and_paths, regardless of whether its paths are under the main models_dir, - but excludes non-model entries like custom_nodes. + but excludes non-model entries like configs and custom_nodes. """ targets: list[tuple[str, list[str]]] = [] for name, values in folder_paths.folder_names_and_paths.items(): @@ -26,36 +25,60 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: return targets -def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: - """Validates and maps tags -> (base_dir, subdirs_for_fs)""" - if not tags: - raise ValueError("tags must not be empty") - root = tags[0].lower() +def _validate_subfolder(subfolder: str | None) -> list[str]: + if not subfolder: + return [] + + if "\\" in subfolder: + raise ValueError("invalid subfolder path") + windows_path = PureWindowsPath(subfolder) + if windows_path.drive or windows_path.root: + raise ValueError("invalid subfolder path") + + parts = Path(subfolder).parts + invalid = {"", ".", ".."} + if Path(subfolder).is_absolute() or any(part in invalid for part in parts): + raise ValueError("invalid subfolder path") + if any("/" in part or "\\" in part for part in parts): + raise ValueError("invalid subfolder path") + return list(parts) + + +def resolve_destination_from_tags( + tags: list[str], subfolder: str | None = None +) -> tuple[str, list[str]]: + """Validates and maps upload routing tags -> (base_dir, subdirs_for_fs). + + The request tags are only used to choose the write destination. Extra tags + remain labels; they do not become path components or trusted classification. + Explicit subfolder is the only request field that can add path components. + """ + destination_roles = [t for t in tags if t in {"input", "models", "output"}] + if len(destination_roles) != 1: + raise ValueError("uploads require exactly one destination role: input, models, or output") + + root = destination_roles[0] if root == "models": - if len(tags) < 2: - raise ValueError("at least two tags required for model asset") + model_type_tags = [t for t in tags if t.startswith("model_type:")] + if len(model_type_tags) != 1: + raise ValueError("models uploads require exactly one model_type: tag") + folder_name = model_type_tags[0].split(":", 1)[1] + if not folder_name: + raise ValueError("models uploads require exactly one model_type: tag") + model_folder_paths = dict(get_comfy_models_folders()) try: - bases = folder_paths.folder_names_and_paths[tags[1]][0] + bases = model_folder_paths[folder_name] except KeyError: - raise ValueError(f"unknown model category '{tags[1]}'") + raise ValueError(f"unknown model category '{folder_name}'") if not bases: - raise ValueError(f"no base path configured for category '{tags[1]}'") + raise ValueError(f"no base path configured for category '{folder_name}'") base_dir = os.path.abspath(bases[0]) - raw_subdirs = tags[2:] elif root == "input": base_dir = os.path.abspath(folder_paths.get_input_directory()) - raw_subdirs = tags[1:] - elif root == "output": - base_dir = os.path.abspath(folder_paths.get_output_directory()) - raw_subdirs = tags[1:] else: - raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") - _sep_chars = frozenset(("/", "\\", os.sep)) - for i in raw_subdirs: - if i in (".", "..") or _sep_chars & set(i): - raise ValueError("invalid path component in tags") + base_dir = os.path.abspath(folder_paths.get_output_directory()) - return base_dir, raw_subdirs if raw_subdirs else [] + return base_dir, _validate_subfolder(subfolder) def validate_path_within_base(candidate: str, base: str) -> None: @@ -156,18 +179,55 @@ def get_asset_category_and_relative_path( ) +def get_backend_system_tags_from_path(path: str) -> list[str]: + """Return trusted backend tags derived from current filesystem facts. + + The returned tags are only the backend-generated system tags: ``models``, + ``model_type:``, ``input``, ``output``, and ``temp``. Model + type tags are based on registered folder names, not path components. + """ + fp_abs = os.path.abspath(path) + fp_path = Path(fp_abs) + tags: list[str] = [] + + def _add(tag: str) -> None: + if tag not in tags: + tags.append(tag) + + for role, base in ( + ("input", folder_paths.get_input_directory()), + ("output", folder_paths.get_output_directory()), + ("temp", folder_paths.get_temp_directory()), + ): + if fp_path.is_relative_to(os.path.abspath(base)): + _add(role) + + model_types: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + for base in bases: + if fp_path.is_relative_to(os.path.abspath(base)): + model_types.append(folder_name) + break + + if model_types: + _add("models") + for folder_name in model_types: + _add(f"model_type:{folder_name}") + + if not tags: + raise ValueError( + f"Path is not within input, output, temp, or configured model bases: {path}" + ) + return tags + + def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: """Return (name, tags) derived from a filesystem path. - name: base filename with extension - - tags: [root_category] + parent folder names in order + - tags: trusted backend classification tags derived from the path Raises: ValueError: path does not belong to any known root. """ - root_category, some_path = get_asset_category_and_relative_path(file_path) - p = Path(some_path) - parent_parts = [ - part for part in p.parent.parts if part not in (".", "..", p.anchor) - ] - return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) + return Path(file_path).name, get_backend_system_tags_from_path(file_path) diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 0f30608a9..cb14a5be0 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -100,6 +100,7 @@ def _parse_cli_feature_flags() -> dict[str, Any]: # Default server capabilities _CORE_FEATURE_FLAGS: dict[str, Any] = { "supports_preview_metadata": True, + "supports_model_type_tags": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, diff --git a/openapi.yaml b/openapi.yaml index 380e4476e..55163b098 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -7,14 +7,6 @@ components: description: Timestamp when the asset was created format: date-time type: string - display_name: - description: Display name of the asset. Mirrors name for backwards compatibility. - nullable: true - type: string - file_path: - description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors") - nullable: true - type: string hash: description: Blake3 hash of the asset content. pattern: ^blake3:[a-f0-9]{64}$ @@ -144,14 +136,6 @@ components: AssetUpdated: description: Response returned when an existing asset is successfully updated. properties: - display_name: - description: Display name of the asset. Mirrors name for backwards compatibility. - nullable: true - type: string - file_path: - description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors") - nullable: true - type: string hash: description: Blake3 hash of the asset content. pattern: ^blake3:[a-f0-9]{64}$ @@ -2454,6 +2438,9 @@ paths: supports_preview_metadata: description: Whether the server supports preview metadata type: boolean + supports_model_type_tags: + description: Whether the server supports namespaced model type asset tags + type: boolean type: object description: Success headers: diff --git a/server.py b/server.py index 361850f38..faf40e501 100644 --- a/server.py +++ b/server.py @@ -440,7 +440,10 @@ class PromptServer(): if args.enable_assets: try: tag = image_upload_type if image_upload_type in ("input", "output") else "input" - result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag]) + tags = [tag] + if subfolder in {"3d", "pasted", "painter", "threed", "webcam"}: + tags.append(subfolder) + result = register_file_in_place(abs_path=filepath, name=filename, tags=tags) resp["asset"] = { "id": result.ref.id, "name": result.ref.name, diff --git a/tests-unit/app_test/test_migrations.py b/tests-unit/app_test/test_migrations.py index fa10c1727..bea72a83b 100644 --- a/tests-unit/app_test/test_migrations.py +++ b/tests-unit/app_test/test_migrations.py @@ -8,6 +8,7 @@ upgrade/downgrade for 0003+. """ import os +import sqlite3 import pytest from alembic import command @@ -30,6 +31,12 @@ def _make_config(db_path: str) -> Config: return cfg +def _sqlite_path(cfg: Config) -> str: + url = cfg.get_main_option("sqlalchemy.url") + assert url is not None and url.startswith("sqlite:///") + return url.removeprefix("sqlite:///") + + @pytest.fixture def migration_db(tmp_path): """Yield an alembic Config pre-upgraded to the baseline revision.""" @@ -55,3 +62,26 @@ def test_upgrade_downgrade_cycle(migration_db): command.upgrade(migration_db, "head") command.downgrade(migration_db, _BASELINE) command.upgrade(migration_db, "head") + + +def test_case_sensitive_tags_downgrade_normalizes_existing_tags(migration_db): + """Downgrading 0005 folds mixed-case tag vocabulary before restoring CHECK.""" + command.upgrade(migration_db, "0005_allow_case_sensitive_tags") + + db_path = _sqlite_path(migration_db) + with sqlite3.connect(db_path) as conn: + conn.execute("INSERT INTO tags(name) VALUES (?)", ("NewTag",)) + conn.execute("INSERT INTO tags(name) VALUES (?)", ("newtag",)) + conn.execute("INSERT INTO tags(name) VALUES (?)", ("model_type:LLM",)) + + command.downgrade(migration_db, "0004_drop_tag_type") + + with sqlite3.connect(db_path) as conn: + tags = {row[0] for row in conn.execute("SELECT name FROM tags")} + assert "newtag" in tags + assert "model_type:llm" in tags + assert "NewTag" not in tags + assert "model_type:LLM" not in tags + + with pytest.raises(sqlite3.IntegrityError): + conn.execute("INSERT INTO tags(name) VALUES (?)", ("Upper",)) diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py index 4aa20372f..44416e8c5 100644 --- a/tests-unit/assets_test/conftest.py +++ b/tests-unit/assets_test/conftest.py @@ -234,7 +234,7 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas p = getattr(request, "param", {}) or {} tags: Optional[list[str]] = p.get("tags") if tags is None: - tags = ["models", "checkpoints", "unit-tests", "alpha"] + tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"] meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} # Unique content per test so the seed always creates a fresh asset (201). # Delete is now always a soft delete, so content from a prior test survives diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index fe510e342..74dfb8a37 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -133,6 +133,66 @@ class TestListReferencesPage: assert total == 1 assert refs[0].name == "tagged" + def test_include_tags_filter_ands_persisted_model_tags(self, session: Session): + asset = _make_asset(session, "hash-model-tags") + checkpoint = _make_reference(session, asset, name="checkpoint") + lora = _make_reference(session, asset, name="lora") + input_ref = _make_reference(session, asset, name="input") + ensure_tags_exist( + session, + ["models", "model_type:checkpoints", "model_type:loras", "unit-tests"], + ) + add_tags_to_reference( + session, + reference_id=checkpoint.id, + tags=["models", "model_type:checkpoints", "unit-tests"], + origin="automatic", + ) + add_tags_to_reference( + session, + reference_id=lora.id, + tags=["models", "model_type:loras", "unit-tests"], + origin="automatic", + ) + add_tags_to_reference( + session, + reference_id=input_ref.id, + tags=["unit-tests"], + ) + session.commit() + + refs, _, total = list_references_page( + session, + include_tags=["models", "model_type:checkpoints", "unit-tests"], + ) + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_include_tags_filter_preserves_model_type_case(self, session: Session): + asset = _make_asset(session, "hash-model-case") + ref = _make_reference(session, asset, name="llm") + ensure_tags_exist(session, ["models", "model_type:LLM"]) + add_tags_to_reference( + session, + reference_id=ref.id, + tags=["models", "model_type:LLM"], + origin="automatic", + ) + session.commit() + + refs, _, total = list_references_page( + session, include_tags=["models", "model_type:LLM"] + ) + refs_lower, _, total_lower = list_references_page( + session, include_tags=["models", "model_type:llm"] + ) + + assert total == 1 + assert refs[0].id == ref.id + assert total_lower == 0 + assert refs_lower == [] + def test_exclude_tags_filter(self, session: Session): asset = _make_asset(session, "hash1") _make_reference(session, asset, name="keep") diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py index 6222714d1..bc041953a 100644 --- a/tests-unit/assets_test/queries/test_tags.py +++ b/tests-unit/assets_test/queries/test_tags.py @@ -58,7 +58,7 @@ class TestEnsureTagsExist: session.commit() tags = session.query(Tag).all() - assert {t.name for t in tags} == {"alpha", "beta"} + assert {t.name for t in tags} == {"ALPHA", "Beta", "alpha"} def test_empty_list_is_noop(self, session: Session): ensure_tags_exist(session, []) @@ -258,6 +258,16 @@ class TestListTagsWithUsage: tag_names = {name for name, _ in rows} assert tag_names == {"alpha", "alphabet"} + def test_prefix_filter_is_case_sensitive(self, session: Session): + ensure_tags_exist(session, ["model_type:LLM", "model_type:llm"]) + session.commit() + + rows, total = list_tags_with_usage(session, prefix="model_type:L") + + tag_names = {name for name, _ in rows} + assert tag_names == {"model_type:LLM"} + assert total == 1 + def test_order_by_name(self, session: Session): ensure_tags_exist(session, ["zebra", "alpha", "middle"]) session.commit() diff --git a/tests-unit/assets_test/services/test_path_utils.py b/tests-unit/assets_test/services/test_path_utils.py index 3fa905f9a..fe92896a8 100644 --- a/tests-unit/assets_test/services/test_path_utils.py +++ b/tests-unit/assets_test/services/test_path_utils.py @@ -6,7 +6,11 @@ from unittest.mock import patch import pytest -from app.assets.services.path_utils import get_asset_category_and_relative_path +from app.assets.services.path_utils import ( + get_asset_category_and_relative_path, + get_name_and_tags_from_asset_path, + resolve_destination_from_tags, +) @pytest.fixture @@ -76,6 +80,137 @@ class TestGetAssetCategoryAndRelativePath: cat, rel = get_asset_category_and_relative_path(str(f)) assert cat == "models" + def test_model_path_tags_include_registered_model_type_only(self, fake_dirs): + f = fake_dirs["models"] / "subdir" / "model.safetensors" + f.parent.mkdir() + f.touch() + + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:checkpoints" in tags + assert "checkpoints" not in tags + assert "subdir" not in tags + + def test_model_type_preserves_registered_folder_case(self, fake_dirs): + llm_dir = fake_dirs["models"].parent / "LLM" + llm_dir.mkdir() + f = llm_dir / "model.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("LLM", [str(llm_dir)])], + ): + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:LLM" in tags + assert "model_type:llm" not in tags + + def test_path_components_do_not_create_model_type_tags(self, fake_dirs): + f = fake_dirs["models"] / "loras" / "model.safetensors" + f.parent.mkdir() + f.touch() + + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:checkpoints" in tags + assert "loras" not in tags + assert "model_type:loras" not in tags + + def test_shared_root_returns_all_matching_model_type_tags(self, fake_dirs): + shared_root = fake_dirs["models"].parent / "shared" + shared_root.mkdir() + f = shared_root / "foo.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[ + ("checkpoints", [str(shared_root)]), + ("loras", [str(shared_root)]), + ], + ): + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:checkpoints" in tags + assert "model_type:loras" in tags + + def test_output_backed_registered_folder_gets_model_and_output_tags(self, fake_dirs): + output_checkpoints_dir = fake_dirs["output"] / "checkpoints" + output_checkpoints_dir.mkdir() + f = output_checkpoints_dir / "saved.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("checkpoints", [str(output_checkpoints_dir)])], + ): + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "models" in tags + assert "model_type:checkpoints" in tags + assert "output" in tags + + def test_temp_path_tags_include_temp_not_output_or_preview(self, fake_dirs): + f = fake_dirs["temp"] / "preview.png" + f.touch() + + _name, tags = get_name_and_tags_from_asset_path(str(f)) + + assert "temp" in tags + assert "output" not in tags + assert "preview:true" not in tags + def test_unknown_path_raises(self, fake_dirs): with pytest.raises(ValueError, match="not within"): get_asset_category_and_relative_path("/some/random/path.png") + + +class TestResolveDestinationFromTags: + def test_explicit_subfolder_is_path_component(self, fake_dirs): + base_dir, subdirs = resolve_destination_from_tags( + ["input", "unit-tests", "foo"], subfolder="foo/bar" + ) + + assert base_dir == os.path.abspath(fake_dirs["input"]) + assert subdirs == ["foo", "bar"] + + @pytest.mark.parametrize( + "subfolder", + ["../escape", "foo/../bar", "/abs", "foo\\bar", "C:/escape", "C:escape"], + ) + def test_explicit_subfolder_rejects_unsafe_paths(self, fake_dirs, subfolder: str): + with pytest.raises(ValueError, match="invalid subfolder"): + resolve_destination_from_tags(["input", "unit-tests"], subfolder=subfolder) + + def test_model_upload_rejects_non_writable_registered_folders(self): + with tempfile.TemporaryDirectory() as root: + root_path = Path(root) + checkpoints_dir = root_path / "models" / "checkpoints" + configs_dir = root_path / "models" / "configs" + custom_nodes_dir = root_path / "custom_nodes" + for path in (checkpoints_dir, configs_dir, custom_nodes_dir): + path.mkdir(parents=True) + + with patch("app.assets.services.path_utils.folder_paths") as mock_fp: + mock_fp.folder_names_and_paths = { + "checkpoints": ([str(checkpoints_dir)], set()), + "configs": ([str(configs_dir)], set()), + "custom_nodes": ([str(custom_nodes_dir)], set()), + } + + base_dir, subdirs = resolve_destination_from_tags( + ["models", "model_type:checkpoints"] + ) + assert base_dir == os.path.abspath(checkpoints_dir) + assert subdirs == [] + + for folder_name in ("configs", "custom_nodes"): + with pytest.raises(ValueError, match="unknown model category"): + resolve_destination_from_tags( + ["models", f"model_type:{folder_name}"] + ) diff --git a/tests-unit/assets_test/test_assets_missing_sync.py b/tests-unit/assets_test/test_assets_missing_sync.py index 29ec1d09d..205723650 100644 --- a/tests-unit/assets_test/test_assets_missing_sync.py +++ b/tests-unit/assets_test/test_assets_missing_sync.py @@ -19,7 +19,8 @@ def test_seed_asset_removed_when_file_is_deleted( """Asset without hash (seed) whose file disappears: after triggering sync_seed_assets, Asset + AssetInfo disappear. """ - # Create a file directly under input/unit-tests/ so tags include "unit-tests" + # Create a file directly under input/unit-tests/. Backend tags only + # classify the root; nested path components are not exposed as tags. case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed" case_dir.mkdir(parents=True, exist_ok=True) name = f"seed_{uuid.uuid4().hex[:8]}.bin" @@ -32,7 +33,7 @@ def test_seed_asset_removed_when_file_is_deleted( # Verify it is visible via API and carries no hash (seed) r1 = http.get( api_base + "/api/assets", - params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + params={"include_tags": root, "name_contains": name}, timeout=120, ) body1 = r1.json() @@ -54,7 +55,7 @@ def test_seed_asset_removed_when_file_is_deleted( # It should disappear (AssetInfo and seed Asset gone) r2 = http.get( api_base + "/api/assets", - params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + params={"include_tags": root, "name_contains": name}, timeout=120, ) body2 = r2.json() @@ -132,7 +133,7 @@ def test_hashed_asset_two_asset_infos_both_get_missing( second_id = b2["id"] # Remove the single underlying file - p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png") + p = comfy_tmp_base_dir / "input" / get_asset_filename(created["asset_hash"], ".png") assert p.exists() p.unlink() @@ -250,8 +251,7 @@ def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match( a = asset_factory(name, [root, "unit-tests", scope], {}, data) aid = a["id"] - base = comfy_tmp_base_dir / root / "unit-tests" / scope - p = base / get_asset_filename(a["asset_hash"], ".bin") + p = comfy_tmp_base_dir / root / get_asset_filename(a["asset_hash"], ".bin") st0 = p.stat() orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000)) diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py index 36abb60ee..9a965bcdf 100644 --- a/tests-unit/assets_test/test_crud.py +++ b/tests-unit/assets_test/test_crud.py @@ -290,7 +290,7 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash( r1 = http.get( api_base + "/api/assets", - params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + params={"include_tags": root, "name_contains": name}, timeout=120, ) body = r1.json() diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py index 42c64a5fd..b624a4edc 100644 --- a/tests-unit/assets_test/test_downloads.py +++ b/tests-unit/assets_test/test_downloads.py @@ -95,7 +95,7 @@ def test_download_chooses_existing_state_and_updates_access_time( assert t1 > t0 -@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True) +@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "model_type:checkpoints"]}], indirect=True) def test_download_missing_file_returns_404( http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict ): diff --git a/tests-unit/assets_test/test_list_cursor.py b/tests-unit/assets_test/test_list_cursor.py index a37019fd6..8f4cc8251 100644 --- a/tests-unit/assets_test/test_list_cursor.py +++ b/tests-unit/assets_test/test_list_cursor.py @@ -13,7 +13,7 @@ def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]: for n in names: asset_factory( n, - ["models", "checkpoints", "unit-tests", tag], + ["models", "model_type:checkpoints", "unit-tests", tag], {}, make_asset_bytes(n, size=2048), ) @@ -208,7 +208,7 @@ def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api names = [] for i in range(4): n = f"cursor_{sort_field}_{i:02d}.safetensors" - asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i)) + asset_factory(n, ["models", "model_type:checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i)) names.append(n) params = { diff --git a/tests-unit/assets_test/test_list_filter.py b/tests-unit/assets_test/test_list_filter.py index 17bbea5c6..d1cba87b3 100644 --- a/tests-unit/assets_test/test_list_filter.py +++ b/tests-unit/assets_test/test_list_filter.py @@ -11,7 +11,7 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse for n in names: asset_factory( n, - ["models", "checkpoints", "unit-tests", "paging"], + ["models", "model_type:checkpoints", "unit-tests", "paging"], {"epoch": 1}, make_asset_bytes(n, size=2048), ) @@ -45,8 +45,8 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory): - a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024) - b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024) + a = asset_factory("inc_a.safetensors", ["models", "model_type:checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024) + b = asset_factory("inc_b.safetensors", ["models", "model_type:checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024) r = http.get( api_base + "/api/assets", @@ -81,7 +81,7 @@ def test_list_assets_include_exclude_and_name_contains(http: requests.Session, a def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes): - t = ["models", "checkpoints", "unit-tests", "lf-size"] + t = ["models", "model_type:checkpoints", "unit-tests", "lf-size"] n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors" asset_factory(n1, t, {}, make_asset_bytes(n1, 1024)) asset_factory(n2, t, {}, make_asset_bytes(n2, 2048)) @@ -108,7 +108,7 @@ def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, mak def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes): - t = ["models", "checkpoints", "unit-tests", "lf-upd"] + t = ["models", "model_type:checkpoints", "unit-tests", "lf-upd"] a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200)) a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200)) @@ -131,7 +131,7 @@ def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes): - t = ["models", "checkpoints", "unit-tests", "lf-access"] + t = ["models", "model_type:checkpoints", "unit-tests", "lf-access"] asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100)) time.sleep(0.02) a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100)) @@ -154,14 +154,14 @@ def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes): - t = ["models", "checkpoints", "unit-tests", "lf-include"] + t = ["models", "model_type:checkpoints", "unit-tests", "lf-include"] a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva")) asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb")) - # CSV + case-insensitive + # CSV tag filters are whitespace-trimmed and case-sensitive. r1 = http.get( api_base + "/api/assets", - params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"}, + params={"include_tags": "unit-tests,lf-include,alpha"}, timeout=120, ) b1 = r1.json() @@ -196,14 +196,14 @@ def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factor def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes): - t = ["models", "checkpoints", "unit-tests", "lf-exclude"] + t = ["models", "model_type:checkpoints", "unit-tests", "lf-exclude"] a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900)) asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900)) - # Exclude uppercase should work + # Exclude filters are case-sensitive. r1 = http.get( api_base + "/api/assets", - params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"}, + params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "beta"}, timeout=120, ) b1 = r1.json() @@ -225,7 +225,7 @@ def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes): - t = ["models", "checkpoints", "unit-tests", "lf-name"] + t = ["models", "model_type:checkpoints", "unit-tests", "lf-name"] a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800)) a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800)) @@ -261,7 +261,7 @@ def test_list_assets_name_contains_case_and_specials(http, api_base, asset_facto def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes): - t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"] + t = ["models", "model_type:checkpoints", "unit-tests", "lf-pagelimits"] asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600)) asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600)) asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600)) @@ -319,7 +319,7 @@ def test_list_assets_name_contains_literal_underscore( - foobar.safetensors (must NOT match) """ scope = f"lf-underscore-{uuid.uuid4().hex[:6]}" - tags = ["models", "checkpoints", "unit-tests", scope] + tags = ["models", "model_type:checkpoints", "unit-tests", scope] a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700)) b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700)) diff --git a/tests-unit/assets_test/test_metadata_filters.py b/tests-unit/assets_test/test_metadata_filters.py index 20285a3b3..1864b1eef 100644 --- a/tests-unit/assets_test/test_metadata_filters.py +++ b/tests-unit/assets_test/test_metadata_filters.py @@ -5,7 +5,7 @@ def test_meta_and_across_keys_and_types( http, api_base: str, asset_factory, make_asset_bytes ): name = "mf_and_mix.safetensors" - tags = ["models", "checkpoints", "unit-tests", "mf-and"] + tags = ["models", "model_type:checkpoints", "unit-tests", "mf-and"] meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} asset_factory(name, tags, meta, make_asset_bytes(name, 4096)) @@ -41,7 +41,7 @@ def test_meta_and_across_keys_and_types( def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes): name = "mf_types.safetensors" - tags = ["models", "checkpoints", "unit-tests", "mf-types"] + tags = ["models", "model_type:checkpoints", "unit-tests", "mf-types"] meta = {"epoch": 1, "active": True} asset_factory(name, tags, meta, make_asset_bytes(name)) @@ -95,7 +95,7 @@ def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes): name = "mf_list_scalars.safetensors" - tags = ["models", "checkpoints", "unit-tests", "mf-list"] + tags = ["models", "model_type:checkpoints", "unit-tests", "mf-list"] meta = {"flags": ["red", "green"]} asset_factory(name, tags, meta, make_asset_bytes(name, 3000)) @@ -134,7 +134,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none( http, api_base, asset_factory, make_asset_bytes ): # a1: key missing; a2: explicit null; a3: concrete value - t = ["models", "checkpoints", "unit-tests", "mf-none"] + t = ["models", "model_type:checkpoints", "unit-tests", "mf-none"] a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2")) a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3")) @@ -166,7 +166,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none( def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes): name = "mf_nested_json.safetensors" - tags = ["models", "checkpoints", "unit-tests", "mf-nested"] + tags = ["models", "model_type:checkpoints", "unit-tests", "mf-nested"] cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}} asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200)) @@ -197,7 +197,7 @@ def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_as def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes): name = "mf_list_objects.safetensors" - tags = ["models", "checkpoints", "unit-tests", "mf-objlist"] + tags = ["models", "model_type:checkpoints", "unit-tests", "mf-objlist"] transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}] asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048)) @@ -228,7 +228,7 @@ def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_b def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes): name = "mf_keys_unicode.safetensors" - tags = ["models", "checkpoints", "unit-tests", "mf-keys"] + tags = ["models", "model_type:checkpoints", "unit-tests", "mf-keys"] meta = { "weird.key": "v1", "path/like": 7, @@ -259,7 +259,7 @@ def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_ def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes): - t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"] + t = ["models", "model_type:checkpoints", "unit-tests", "mf-zero-bool"] a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025)) a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026)) @@ -286,7 +286,7 @@ def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_as def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes): name = "mf_mixed_list.safetensors" - tags = ["models", "checkpoints", "unit-tests", "mf-mixed"] + tags = ["models", "model_type:checkpoints", "unit-tests", "mf-mixed"] meta = {"mix": ["1", 1, True, None]} asset_factory(name, tags, meta, make_asset_bytes(name, 1999)) @@ -311,7 +311,7 @@ def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, mak def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes): # Use a unique scope tag to avoid interference - t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"] + t = ["models", "model_type:checkpoints", "unit-tests", "mf-unknown-scope"] x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua")) y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub")) @@ -340,13 +340,13 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_ # alpha matches epoch=1; beta has epoch=2 a = asset_factory( "mf_tag_alpha.safetensors", - ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"], + ["models", "model_type:checkpoints", "unit-tests", "mf-tag", "alpha"], {"epoch": 1}, make_asset_bytes("alpha"), ) b = asset_factory( "mf_tag_beta.safetensors", - ["models", "checkpoints", "unit-tests", "mf-tag", "beta"], + ["models", "model_type:checkpoints", "unit-tests", "mf-tag", "beta"], {"epoch": 2}, make_asset_bytes("beta"), ) @@ -367,7 +367,7 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_ def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes): # Three assets in same scope with different sizes and a common filter key - t = ["models", "checkpoints", "unit-tests", "mf-sort"] + t = ["models", "model_type:checkpoints", "unit-tests", "mf-sort"] n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors" asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024)) asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048)) diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py index 1fbd4d4e2..618ec6c8d 100644 --- a/tests-unit/assets_test/test_prune_orphaned_assets.py +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -29,7 +29,7 @@ def create_seed_file(comfy_tmp_base_dir: Path): def find_asset(http: requests.Session, api_base: str): """Query API for assets matching scope and optional name.""" def _find(scope: str, name: str | None = None) -> list[dict]: - params = {"include_tags": f"unit-tests,{scope}"} + params = {"limit": "500"} if name: params["name_contains"] = name r = http.get(f"{api_base}/api/assets", params=params, timeout=120) @@ -91,7 +91,7 @@ def test_hashed_asset_not_pruned_when_file_missing( data = make_asset_bytes("test", 2048) a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data) - path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin") + path = comfy_tmp_base_dir / "input" / get_asset_filename(a["asset_hash"], ".bin") path.unlink() trigger_sync_seed_assets(http, api_base) @@ -108,18 +108,20 @@ def test_prune_across_multiple_roots( ): """Prune correctly handles assets across input and output roots.""" scope = f"multi-{uuid.uuid4().hex[:6]}" - input_fp = create_seed_file("input", scope, "input.bin") - create_seed_file("output", scope, "output.bin") + input_name = f"{scope}-input.bin" + output_name = f"{scope}-output.bin" + input_fp = create_seed_file("input", scope, input_name) + create_seed_file("output", scope, output_name) trigger_sync_seed_assets(http, api_base) - assert len(find_asset(scope)) == 2 + assert find_asset(scope, input_name) + assert find_asset(scope, output_name) input_fp.unlink() trigger_sync_seed_assets(http, api_base) - remaining = find_asset(scope) - assert len(remaining) == 1 - assert remaining[0]["name"] == "output.bin" + assert not find_asset(scope, input_name) + assert find_asset(scope, output_name) @pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"]) diff --git a/tests-unit/assets_test/test_tags_api.py b/tests-unit/assets_test/test_tags_api.py index 9729b7d03..93786696f 100644 --- a/tests-unit/assets_test/test_tags_api.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -10,9 +10,9 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict) body1 = r1.json() assert r1.status_code == 200 names = [t["name"] for t in body1["tags"]] - # A few system tags from migration should exist: + # A few selected contract tags should exist. assert "models" in names - assert "checkpoints" in names + assert "model_type:checkpoints" in names # Only used tags before we add anything new from this test cycle r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120) @@ -21,7 +21,7 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict) # We already seeded one asset via fixture, so used tags must be non-empty used_names = [t["name"] for t in body2["tags"]] assert "models" in used_names - assert "checkpoints" in used_names + assert "model_type:checkpoints" in used_names # Prefix filter should refine the list r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120) @@ -45,7 +45,7 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, body1 = r1.json() assert r1.status_code == 200 names = [t["name"] for t in body1["tags"]] - assert "models" in names and "checkpoints" in names + assert "models" in names and "model_type:checkpoints" in names # Create a short-lived asset under input with a unique custom tag scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}" @@ -89,28 +89,28 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict): aid = seeded_asset["id"] - # Add tags with duplicates and mixed case - payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]} + # Add tags with duplicates while preserving source case. + payload_add = {"tags": ["NewTag", "unit-tests", "NewTag", "BETA"]} r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120) b1 = r1.json() assert r1.status_code == 200, b1 - # normalized, deduplicated; 'unit-tests' was already present from the seed - assert set(b1["added"]) == {"newtag", "beta"} + # stripped, deduplicated; 'unit-tests' was already present from the seed + assert set(b1["added"]) == {"NewTag", "BETA"} assert set(b1["already_present"]) == {"unit-tests"} - assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] + assert "NewTag" in b1["total_tags"] and "BETA" in b1["total_tags"] rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) g = rg.json() assert rg.status_code == 200 tags_now = set(g["tags"]) - assert {"newtag", "beta"}.issubset(tags_now) + assert {"NewTag", "BETA"}.issubset(tags_now) # Remove a tag and a non-existent tag - payload_del = {"tags": ["newtag", "does-not-exist"]} + payload_del = {"tags": ["NewTag", "does-not-exist"]} r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120) b2 = r2.json() assert r2.status_code == 200 - assert set(b2["removed"]) == {"newtag"} + assert set(b2["removed"]) == {"NewTag"} assert set(b2["not_present"]) == {"does-not-exist"} # Verify remaining tags after deletion @@ -118,8 +118,44 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset g2 = rg2.json() assert rg2.status_code == 200 tags_later = set(g2["tags"]) - assert "newtag" not in tags_later - assert "beta" in tags_later # still present + assert "NewTag" not in tags_later + assert "BETA" in tags_later # still present + + +def test_add_system_looking_tags_allowed_as_labels( + http: requests.Session, api_base: str, seeded_asset: dict +): + aid = seeded_asset["id"] + + response = http.post( + f"{api_base}/api/assets/{aid}/tags", + json={ + "tags": [ + "models", + "model_type:manual", + "model:true", + "models:foo", + "input:true", + "output:true", + "uploaded:true", + "temp:true", + "temporary", + ] + }, + timeout=120, + ) + body = response.json() + + assert response.status_code == 200, body + assert "models" in body["total_tags"] + assert "model_type:manual" in body["total_tags"] + assert "model:true" in body["total_tags"] + assert "models:foo" in body["total_tags"] + assert "input:true" in body["total_tags"] + assert "output:true" in body["total_tags"] + assert "uploaded:true" in body["total_tags"] + assert "temp:true" in body["total_tags"] + assert "temporary" in body["total_tags"] def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict): diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py index 427a417cc..cb7f6cd30 100644 --- a/tests-unit/assets_test/test_uploads.py +++ b/tests-unit/assets_test/test_uploads.py @@ -1,11 +1,13 @@ import json import uuid from concurrent.futures import ThreadPoolExecutor +from pathlib import Path import requests import pytest from app.assets.api.schemas_out import Asset, AssetCreated +from helpers import get_asset_filename def test_asset_created_inherits_hash_field(): @@ -22,7 +24,7 @@ def test_asset_created_inherits_hash_field(): def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes): name = "dup_a.safetensors" - tags = ["models", "checkpoints", "unit-tests", "alpha"] + tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"] meta = {"purpose": "dup"} data = make_asset_bytes(name) files = {"file": (name, data, "application/octet-stream")} @@ -58,7 +60,7 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): # Seed a small file first name = "fastpath_seed.safetensors" - tags = ["models", "checkpoints", "unit-tests"] + tags = ["input", "unit-tests"] meta = {} files = {"file": (name, b"B" * 1024, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} @@ -69,9 +71,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_ assert b1["hash"] == h # Now POST /api/assets with only hash and no file + hash_only_tags = ["models", "checkpoints", "unit-tests", "hash-labels"] files = [ ("hash", (None, h)), - ("tags", (None, json.dumps(tags))), + ("tags", (None, json.dumps(hash_only_tags))), ("name", (None, "fastpath_copy.safetensors")), ("user_metadata", (None, json.dumps({"purpose": "copy"}))), ] @@ -81,6 +84,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_ assert b2["created_new"] is False assert b2["asset_hash"] == h assert b2["hash"] == h + assert "models" in b2["tags"] + assert "checkpoints" in b2["tags"] + assert "uploaded" not in b2["tags"] + assert not any(tag.startswith("model_type:") for tag in b2["tags"]) def test_upload_fastpath_with_known_hash_and_file( @@ -88,7 +95,7 @@ def test_upload_fastpath_with_known_hash_and_file( ): # Seed files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")} - form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})} + form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})} r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) b1 = r1.json() assert r1.status_code == 201, b1 @@ -104,11 +111,47 @@ def test_upload_fastpath_with_known_hash_and_file( assert b2["created_new"] is False assert b2["asset_hash"] == h assert b2["hash"] == h + assert "checkpoints" in b2["tags"] + assert "uploaded" not in b2["tags"] + assert not any(tag == "model_type:checkpoints" for tag in b2["tags"]) + + +def test_duplicate_byte_upload_is_reference_only_and_does_not_need_destination( + http: requests.Session, api_base: str +): + data = b"duplicate-reference-only" * 64 + seed_files = {"file": ("duplicate-seed.bin", data, "application/octet-stream")} + seed_form = { + "tags": json.dumps(["input", "unit-tests", "duplicate-seed"]), + "name": "duplicate-seed.bin", + "user_metadata": json.dumps({}), + } + seed_response = http.post(api_base + "/api/assets", data=seed_form, files=seed_files, timeout=120) + seed = seed_response.json() + assert seed_response.status_code == 201, seed + + duplicate_files = {"file": ("duplicate-copy.bin", data, "application/octet-stream")} + duplicate_form = { + "tags": json.dumps(["not-a-destination", "unit-tests", "duplicate-copy"]), + "name": "duplicate-copy.bin", + "user_metadata": json.dumps({}), + } + duplicate_response = http.post( + api_base + "/api/assets", data=duplicate_form, files=duplicate_files, timeout=120 + ) + duplicate = duplicate_response.json() + + assert duplicate_response.status_code == 200, duplicate + assert duplicate["created_new"] is False + assert duplicate["asset_hash"] == seed["asset_hash"] + assert "not-a-destination" in duplicate["tags"] + assert "uploaded" not in duplicate["tags"] + assert "input" not in duplicate["tags"] def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str): data = [ - ("tags", "models,checkpoints"), + ("tags", "models,model_type:checkpoints"), ("tags", json.dumps(["unit-tests", "alpha"])), ("name", "merge.safetensors"), ("user_metadata", json.dumps({"u": 1})), @@ -124,7 +167,7 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base detail = rg.json() assert rg.status_code == 200, detail tags = set(detail["tags"]) - assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) + assert {"models", "model_type:checkpoints", "unit-tests", "alpha"}.issubset(tags) @pytest.mark.parametrize("root", ["input", "output"]) @@ -192,16 +235,55 @@ def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str): assert body["error"]["code"] == "ASSET_NOT_FOUND" +def test_create_from_hash_accepts_arbitrary_system_looking_tags( + http: requests.Session, api_base: str +): + files = {"file": ("hash-seed.bin", b"hash-seed" * 64, "application/octet-stream")} + form = { + "tags": json.dumps(["input", "unit-tests", "hash-seed"]), + "name": "hash-seed.bin", + "user_metadata": json.dumps({}), + } + seed_response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + seed = seed_response.json() + assert seed_response.status_code == 201, seed + + response = http.post( + api_base + "/api/assets/from-hash", + json={ + "hash": seed["asset_hash"], + "name": "hash-copy.bin", + "tags": [ + "models", + "model:true", + "models:foo", + "temporary:true", + "unit-tests", + "hash-copy", + ], + }, + timeout=120, + ) + body = response.json() + + assert response.status_code == 201, body + assert "models" in body["tags"] + assert "model:true" in body["tags"] + assert "models:foo" in body["tags"] + assert "temporary:true" in body["tags"] + assert "uploaded" not in body["tags"] + + def test_upload_zero_byte_rejected(http: requests.Session, api_base: str): files = {"file": ("empty.safetensors", b"", "application/octet-stream")} - form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})} + form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) body = r.json() assert r.status_code == 400 assert body["error"]["code"] == "EMPTY_UPLOAD" -def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str): +def test_upload_rejects_arbitrary_labels_without_required_destination_role(http: requests.Session, api_base: str): files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")} form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) @@ -212,7 +294,7 @@ def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str) def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str): files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")} - form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"} + form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) body = r.json() assert r.status_code == 400 @@ -228,7 +310,7 @@ def test_upload_requires_multipart(http: requests.Session, api_base: str): def test_upload_missing_file_and_hash(http: requests.Session, api_base: str): files = [ - ("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))), + ("tags", (None, json.dumps(["models", "model_type:checkpoints", "unit-tests"]))), ("name", (None, "x.safetensors")), ] r = http.post(api_base + "/api/assets", files=files, timeout=120) @@ -237,17 +319,33 @@ def test_upload_missing_file_and_hash(http: requests.Session, api_base: str): assert body["error"]["code"] == "MISSING_FILE" -def test_upload_models_unknown_category(http: requests.Session, api_base: str): +def test_upload_models_unknown_model_type(http: requests.Session, api_base: str): files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")} - form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"} + form = {"tags": json.dumps(["models", "model_type:no_such_category", "unit-tests"]), "name": "m.safetensors"} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) body = r.json() - assert r.status_code == 400 + assert r.status_code == 400, body assert body["error"]["code"] == "INVALID_BODY" - assert body["error"]["message"].startswith("unknown models category") -def test_upload_models_requires_category(http: requests.Session, api_base: str): +@pytest.mark.parametrize("model_type", ["configs", "custom_nodes"]) +def test_upload_models_rejects_non_model_registered_folder( + model_type: str, http: requests.Session, api_base: str +): + files = {"file": ("not-a-model.py", b"A" * 128, "application/octet-stream")} + form = { + "tags": json.dumps(["models", f"model_type:{model_type}", "unit-tests"]), + "name": "not-a-model.py", + } + + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 400, body + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_models_requires_model_type(http: requests.Session, api_base: str): files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")} form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) @@ -256,13 +354,149 @@ def test_upload_models_requires_category(http: requests.Session, api_base: str): assert body["error"]["code"] == "INVALID_BODY" -def test_upload_tags_traversal_guard(http: requests.Session, api_base: str): +def test_upload_extra_tags_are_labels_not_path_components(http: requests.Session, api_base: str): files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")} - form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"} + form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"} r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) body = r.json() - assert r.status_code == 400 - assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") + assert r.status_code == 201, body + assert ".." in body["tags"] + assert "zzz" in body["tags"] + assert "models" in body["tags"] + assert "model_type:checkpoints" in body["tags"] + + +def test_upload_subfolder_is_explicit_path_component( + http: requests.Session, api_base: str, comfy_tmp_base_dir: Path +): + files = {"file": ("nested.bin", b"nested" * 64, "application/octet-stream")} + form = { + "tags": json.dumps(["input", "unit-tests", "foo"]), + "subfolder": "foo/bar", + "name": "nested.bin", + } + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + + assert r.status_code == 201, body + stored_name = get_asset_filename(body["asset_hash"], ".bin") + assert (comfy_tmp_base_dir / "input" / "foo" / "bar" / stored_name).exists() + assert "foo" in body["tags"] + + +def test_upload_rejects_unsafe_subfolder(http: requests.Session, api_base: str): + files = {"file": ("escape.bin", b"escape" * 64, "application/octet-stream")} + form = { + "tags": json.dumps(["input", "unit-tests"]), + "subfolder": "../escape", + "name": "escape.bin", + } + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + + assert r.status_code == 400, body + assert body["error"]["code"] == "INVALID_BODY" + + +def test_multipart_upload_accepts_system_looking_extra_labels( + http: requests.Session, api_base: str +): + files = {"file": ("relaxed-labels.bin", b"relaxed" * 64, "application/octet-stream")} + form = { + "tags": json.dumps( + [ + "input", + "unit-tests", + "model:true", + "models:foo", + "temporary", + "uploaded:true", + ] + ), + "name": "relaxed-labels.bin", + "user_metadata": json.dumps({}), + } + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 201, body + assert "input" in body["tags"] + assert "model:true" in body["tags"] + assert "models:foo" in body["tags"] + assert "temporary" in body["tags"] + assert "uploaded:true" in body["tags"] + + +def test_multipart_upload_rejects_ambiguous_destination_roles( + http: requests.Session, api_base: str +): + files = {"file": ("ambiguous.bin", b"ambiguous" * 64, "application/octet-stream")} + form = { + "tags": json.dumps(["input", "output", "unit-tests"]), + "name": "ambiguous.bin", + "user_metadata": json.dumps({}), + } + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 400, body + assert body["error"]["code"] == "INVALID_BODY" + + +def test_multipart_upload_rejects_multiple_model_types_for_models_destination( + http: requests.Session, api_base: str +): + files = {"file": ("ambiguous-model.safetensors", b"ambiguous-model" * 64, "application/octet-stream")} + form = { + "tags": json.dumps( + ["models", "model_type:checkpoints", "model_type:loras", "unit-tests"] + ), + "name": "ambiguous-model.safetensors", + "user_metadata": json.dumps({}), + } + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 400, body + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.parametrize( + ("tags", "expected_root", "extension"), + [ + (["input", "unit-tests", "upload-location-input"], "input", ".bin"), + (["output", "unit-tests", "upload-location-output"], "output", ".bin"), + ( + ["models", "model_type:checkpoints", "unit-tests", "upload-location-model"], + "models/checkpoints", + ".safetensors", + ), + ], +) +def test_multipart_upload_role_selects_write_location( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + tags: list[str], + expected_root: str, + extension: str, +): + role = next(tag for tag in tags if tag in {"input", "models", "output"}) + name = f"{role}-role-upload{extension}" + files = {"file": (name, f"{role}-role-bytes".encode() * 64, "application/octet-stream")} + form = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps({}), + } + + response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = response.json() + + assert response.status_code == 201, body + stored_name = get_asset_filename(body["asset_hash"], extension) + expected_disk_path = comfy_tmp_base_dir / expected_root / stored_name + assert expected_disk_path.exists() def test_upload_empty_tags_rejected(http: requests.Session, api_base: str): diff --git a/tests-unit/feature_flags_test.py b/tests-unit/feature_flags_test.py index 8ec52a124..a436ab1ec 100644 --- a/tests-unit/feature_flags_test.py +++ b/tests-unit/feature_flags_test.py @@ -29,6 +29,8 @@ class TestFeatureFlags: features = get_server_features() assert "supports_preview_metadata" in features assert features["supports_preview_metadata"] is True + assert "supports_model_type_tags" in features + assert features["supports_model_type_tags"] is True assert "max_upload_size" in features assert isinstance(features["max_upload_size"], (int, float)) diff --git a/tests-unit/websocket_feature_flags_test.py b/tests-unit/websocket_feature_flags_test.py index e93b2e1dd..4950bd9d0 100644 --- a/tests-unit/websocket_feature_flags_test.py +++ b/tests-unit/websocket_feature_flags_test.py @@ -12,6 +12,8 @@ class TestWebSocketFeatureFlags: # Check expected server features assert "supports_preview_metadata" in features assert features["supports_preview_metadata"] is True + assert "supports_model_type_tags" in features + assert features["supports_model_type_tags"] is True assert "max_upload_size" in features assert isinstance(features["max_upload_size"], (int, float)) @@ -75,3 +77,5 @@ class TestWebSocketFeatureFlags: assert server_message["type"] == "feature_flags" assert "supports_preview_metadata" in server_message["data"] assert server_message["data"]["supports_preview_metadata"] is True + assert "supports_model_type_tags" in server_message["data"] + assert server_message["data"]["supports_model_type_tags"] is True