diff --git a/alembic_db/versions/0005_allow_case_sensitive_tags.py b/alembic_db/versions/0005_allow_case_sensitive_tags.py index 8646b350a..bd5f864db 100644 --- a/alembic_db/versions/0005_allow_case_sensitive_tags.py +++ b/alembic_db/versions/0005_allow_case_sensitive_tags.py @@ -6,6 +6,7 @@ Revises: 0004_drop_tag_type Create Date: 2026-06-16 """ +import sqlalchemy as sa from alembic import op revision = "0005_allow_case_sensitive_tags" @@ -21,16 +22,18 @@ def upgrade() -> None: # vocabulary table without the lowercase constraint while preserving # existing tag names. op.execute("PRAGMA foreign_keys=OFF") - op.execute( - "CREATE TABLE tags_new (" - "name VARCHAR(512) NOT NULL, " - "CONSTRAINT pk_tags PRIMARY KEY (name)" - ")" - ) - op.execute("INSERT INTO tags_new(name) SELECT name FROM tags") - op.execute("DROP TABLE tags") - op.execute("ALTER TABLE tags_new RENAME TO tags") - op.execute("PRAGMA foreign_keys=ON") + try: + op.execute( + "CREATE TABLE tags_new (" + "name VARCHAR(512) NOT NULL, " + "CONSTRAINT pk_tags PRIMARY KEY (name)" + ")" + ) + op.execute("INSERT INTO tags_new(name) SELECT name FROM tags") + op.execute("DROP TABLE tags") + op.execute("ALTER TABLE tags_new RENAME TO tags") + finally: + op.execute("PRAGMA foreign_keys=ON") return op.drop_constraint("ck_tags_ck_tags_lowercase", "tags", type_="check") @@ -39,31 +42,64 @@ def upgrade() -> None: def downgrade() -> None: # Existing mixed-case tags cannot satisfy the old constraint. Lowercase them # before restoring it, merging duplicate vocabulary/link rows that collide. - op.execute("INSERT OR IGNORE INTO tags(name) SELECT lower(name) FROM tags") - op.execute( - "DELETE FROM asset_reference_tags " - "WHERE rowid NOT IN (" - " SELECT MIN(rowid) FROM asset_reference_tags " - " GROUP BY asset_reference_id, lower(tag_name)" - ")" - ) - op.execute("UPDATE asset_reference_tags SET tag_name = lower(tag_name)") + bind = op.get_bind() + + tag_names = [row[0] for row in bind.execute(sa.text("SELECT name FROM tags"))] + existing_names = set(tag_names) + lowercase_names = sorted({name.lower() for name in tag_names}) + missing_lowercase_rows = [ + {"name": name} for name in lowercase_names if name not in existing_names + ] + if missing_lowercase_rows: + bind.execute(sa.text("INSERT INTO tags(name) VALUES (:name)"), missing_lowercase_rows) + + link_rows = bind.execute( + sa.text( + "SELECT asset_reference_id, tag_name, origin, added_at " + "FROM asset_reference_tags " + "ORDER BY asset_reference_id, tag_name" + ) + ).mappings() + deduped_links = {} + for row in link_rows: + key = (row["asset_reference_id"], row["tag_name"].lower()) + deduped_links.setdefault( + key, + { + "asset_reference_id": row["asset_reference_id"], + "tag_name": row["tag_name"].lower(), + "origin": row["origin"], + "added_at": row["added_at"], + }, + ) + + op.execute("DELETE FROM asset_reference_tags") + if deduped_links: + bind.execute( + sa.text( + "INSERT INTO asset_reference_tags " + "(asset_reference_id, tag_name, origin, added_at) " + "VALUES (:asset_reference_id, :tag_name, :origin, :added_at)" + ), + list(deduped_links.values()), + ) op.execute("DELETE FROM tags WHERE name != lower(name)") - bind = op.get_bind() if bind.dialect.name == "sqlite": op.execute("PRAGMA foreign_keys=OFF") - op.execute( - "CREATE TABLE tags_new (" - "name VARCHAR(512) NOT NULL, " - "CONSTRAINT pk_tags PRIMARY KEY (name), " - "CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))" - ")" - ) - op.execute("INSERT INTO tags_new(name) SELECT name FROM tags") - op.execute("DROP TABLE tags") - op.execute("ALTER TABLE tags_new RENAME TO tags") - op.execute("PRAGMA foreign_keys=ON") + try: + op.execute( + "CREATE TABLE tags_new (" + "name VARCHAR(512) NOT NULL, " + "CONSTRAINT pk_tags PRIMARY KEY (name), " + "CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))" + ")" + ) + op.execute("INSERT INTO tags_new(name) SELECT name FROM tags") + op.execute("DROP TABLE tags") + op.execute("ALTER TABLE tags_new RENAME TO tags") + finally: + op.execute("PRAGMA foreign_keys=ON") return op.create_check_constraint( diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index 6e041d637..148f34801 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -265,6 +265,8 @@ def list_tags_with_usage( order: str = "count_desc", owner_id: str = "", ) -> tuple[list[tuple[str, str, int]], int]: + prefix_filter = prefix.strip() if prefix else "" + counts_sq = ( select( AssetReferenceTag.tag_name.label("tag_name"), @@ -293,9 +295,8 @@ def list_tags_with_usage( .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) ) - if prefix: - escaped, esc = escape_sql_like_string(prefix.strip()) - q = q.where(Tag.name.like(escaped + "%", escape=esc)) + if prefix_filter: + q = q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter) if not include_zero: q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) @@ -306,9 +307,8 @@ def list_tags_with_usage( q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) total_q = select(func.count()).select_from(Tag) - if prefix: - escaped, esc = escape_sql_like_string(prefix.strip()) - total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if prefix_filter: + total_q = total_q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter) if not include_zero: visible_tags_sq = ( select(AssetReferenceTag.tag_name) diff --git a/tests-unit/app_test/test_migrations.py b/tests-unit/app_test/test_migrations.py index fa10c1727..bea72a83b 100644 --- a/tests-unit/app_test/test_migrations.py +++ b/tests-unit/app_test/test_migrations.py @@ -8,6 +8,7 @@ upgrade/downgrade for 0003+. """ import os +import sqlite3 import pytest from alembic import command @@ -30,6 +31,12 @@ def _make_config(db_path: str) -> Config: return cfg +def _sqlite_path(cfg: Config) -> str: + url = cfg.get_main_option("sqlalchemy.url") + assert url is not None and url.startswith("sqlite:///") + return url.removeprefix("sqlite:///") + + @pytest.fixture def migration_db(tmp_path): """Yield an alembic Config pre-upgraded to the baseline revision.""" @@ -55,3 +62,26 @@ def test_upgrade_downgrade_cycle(migration_db): command.upgrade(migration_db, "head") command.downgrade(migration_db, _BASELINE) command.upgrade(migration_db, "head") + + +def test_case_sensitive_tags_downgrade_normalizes_existing_tags(migration_db): + """Downgrading 0005 folds mixed-case tag vocabulary before restoring CHECK.""" + command.upgrade(migration_db, "0005_allow_case_sensitive_tags") + + db_path = _sqlite_path(migration_db) + with sqlite3.connect(db_path) as conn: + conn.execute("INSERT INTO tags(name) VALUES (?)", ("NewTag",)) + conn.execute("INSERT INTO tags(name) VALUES (?)", ("newtag",)) + conn.execute("INSERT INTO tags(name) VALUES (?)", ("model_type:LLM",)) + + command.downgrade(migration_db, "0004_drop_tag_type") + + with sqlite3.connect(db_path) as conn: + tags = {row[0] for row in conn.execute("SELECT name FROM tags")} + assert "newtag" in tags + assert "model_type:llm" in tags + assert "NewTag" not in tags + assert "model_type:LLM" not in tags + + with pytest.raises(sqlite3.IntegrityError): + conn.execute("INSERT INTO tags(name) VALUES (?)", ("Upper",)) diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py index d3634b51d..bc041953a 100644 --- a/tests-unit/assets_test/queries/test_tags.py +++ b/tests-unit/assets_test/queries/test_tags.py @@ -258,6 +258,16 @@ class TestListTagsWithUsage: tag_names = {name for name, _ in rows} assert tag_names == {"alpha", "alphabet"} + def test_prefix_filter_is_case_sensitive(self, session: Session): + ensure_tags_exist(session, ["model_type:LLM", "model_type:llm"]) + session.commit() + + rows, total = list_tags_with_usage(session, prefix="model_type:L") + + tag_names = {name for name, _ in rows} + assert tag_names == {"model_type:LLM"} + assert total == 1 + def test_order_by_name(self, session: Session): ensure_tags_exist(session, ["zebra", "alpha", "middle"]) session.commit() diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py index 7ab74ce2d..618ec6c8d 100644 --- a/tests-unit/assets_test/test_prune_orphaned_assets.py +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -108,18 +108,20 @@ def test_prune_across_multiple_roots( ): """Prune correctly handles assets across input and output roots.""" scope = f"multi-{uuid.uuid4().hex[:6]}" - input_fp = create_seed_file("input", scope, "input.bin") - create_seed_file("output", scope, "output.bin") + input_name = f"{scope}-input.bin" + output_name = f"{scope}-output.bin" + input_fp = create_seed_file("input", scope, input_name) + create_seed_file("output", scope, output_name) trigger_sync_seed_assets(http, api_base) - assert find_asset(scope, input_fp.name) - assert find_asset(scope, "output.bin") + assert find_asset(scope, input_name) + assert find_asset(scope, output_name) input_fp.unlink() trigger_sync_seed_assets(http, api_base) - assert not find_asset(scope, input_fp.name) - assert find_asset(scope, "output.bin") + assert not find_asset(scope, input_name) + assert find_asset(scope, output_name) @pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])