mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-21 23:39:35 +08:00
Add namespaced model type asset tags
Amp-Thread-ID: https://ampcode.com/threads/T-019ecf39-2e6f-747d-ae80-addba6b8e4f5 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
e00b55631a
commit
e163d59508
71
alembic_db/versions/0005_allow_case_sensitive_tags.py
Normal file
71
alembic_db/versions/0005_allow_case_sensitive_tags.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""
|
||||
Allow case-sensitive tag names.
|
||||
|
||||
Revision ID: 0005_allow_case_sensitive_tags
|
||||
Revises: 0004_drop_tag_type
|
||||
Create Date: 2026-06-16
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "0005_allow_case_sensitive_tags"
|
||||
down_revision = "0004_drop_tag_type"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
if bind.dialect.name == "sqlite":
|
||||
# SQLite cannot ALTER/DROP CHECK constraints. Recreate the small tag
|
||||
# vocabulary table without the lowercase constraint while preserving
|
||||
# existing tag names.
|
||||
op.execute("PRAGMA foreign_keys=OFF")
|
||||
op.execute(
|
||||
"CREATE TABLE tags_new ("
|
||||
"name VARCHAR(512) NOT NULL, "
|
||||
"CONSTRAINT pk_tags PRIMARY KEY (name)"
|
||||
")"
|
||||
)
|
||||
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
|
||||
op.execute("DROP TABLE tags")
|
||||
op.execute("ALTER TABLE tags_new RENAME TO tags")
|
||||
op.execute("PRAGMA foreign_keys=ON")
|
||||
return
|
||||
|
||||
op.drop_constraint("ck_tags_ck_tags_lowercase", "tags", type_="check")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Existing mixed-case tags cannot satisfy the old constraint. Lowercase them
|
||||
# before restoring it, merging duplicate vocabulary/link rows that collide.
|
||||
op.execute("INSERT OR IGNORE INTO tags(name) SELECT lower(name) FROM tags")
|
||||
op.execute(
|
||||
"DELETE FROM asset_reference_tags "
|
||||
"WHERE rowid NOT IN ("
|
||||
" SELECT MIN(rowid) FROM asset_reference_tags "
|
||||
" GROUP BY asset_reference_id, lower(tag_name)"
|
||||
")"
|
||||
)
|
||||
op.execute("UPDATE asset_reference_tags SET tag_name = lower(tag_name)")
|
||||
op.execute("DELETE FROM tags WHERE name != lower(name)")
|
||||
|
||||
bind = op.get_bind()
|
||||
if bind.dialect.name == "sqlite":
|
||||
op.execute("PRAGMA foreign_keys=OFF")
|
||||
op.execute(
|
||||
"CREATE TABLE tags_new ("
|
||||
"name VARCHAR(512) NOT NULL, "
|
||||
"CONSTRAINT pk_tags PRIMARY KEY (name), "
|
||||
"CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))"
|
||||
")"
|
||||
)
|
||||
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
|
||||
op.execute("DROP TABLE tags")
|
||||
op.execute("ALTER TABLE tags_new RENAME TO tags")
|
||||
op.execute("PRAGMA foreign_keys=ON")
|
||||
return
|
||||
|
||||
op.create_check_constraint(
|
||||
"ck_tags_ck_tags_lowercase", "tags", "name = lower(name)"
|
||||
)
|
||||
@ -10,7 +10,6 @@ from typing import Any
|
||||
from aiohttp import web
|
||||
from pydantic import ValidationError
|
||||
|
||||
import folder_paths
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in, schemas_out
|
||||
from app.assets.services import schemas
|
||||
@ -40,6 +39,7 @@ from app.assets.services import (
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.cursor import InvalidCursorError
|
||||
from app.assets.services.path_utils import compute_api_file_path
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@ -169,6 +169,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
asset_hash=asset_content_hash,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
file_path=compute_api_file_path(result.ref.file_path),
|
||||
tags=result.tags,
|
||||
preview_url=preview_url,
|
||||
preview_id=result.ref.preview_id,
|
||||
@ -416,17 +417,6 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
400, "INVALID_BODY", f"Validation failed: {ve.json()}"
|
||||
)
|
||||
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if (
|
||||
len(spec.tags) < 2
|
||||
or spec.tags[1] not in folder_paths.folder_names_and_paths
|
||||
):
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
category = spec.tags[1] if len(spec.tags) >= 2 else ""
|
||||
return _build_error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{category}'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Fast path: hash exists, create AssetReference without writing anything
|
||||
if spec.hash and parsed.provided_hash_exists is True:
|
||||
@ -470,7 +460,7 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
return _build_error_response(400, e.code, str(e))
|
||||
except ValueError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(400, "BAD_REQUEST", str(e))
|
||||
return _build_error_response(400, "INVALID_BODY", str(e))
|
||||
except HashMismatchError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(400, "HASH_MISMATCH", str(e))
|
||||
|
||||
@ -140,7 +140,7 @@ class CreateFromHashBody(BaseModel):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
out = [str(t).strip() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
@ -149,7 +149,7 @@ class CreateFromHashBody(BaseModel):
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return list(dict.fromkeys(t.strip() for t in v.split(",") if t.strip()))
|
||||
return []
|
||||
|
||||
|
||||
@ -206,7 +206,7 @@ class TagsListQuery(BaseModel):
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v.lower() or None
|
||||
return v or None
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
@ -220,7 +220,7 @@ class TagsAdd(BaseModel):
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
tnorm = t.strip()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
@ -309,7 +309,7 @@ class UploadAssetSpec(BaseModel):
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
tnorm = str(t).strip()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
@ -335,14 +335,4 @@ class UploadAssetSpec(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("at least one tag is required for uploads")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError(
|
||||
"models uploads require a category tag as the second tag"
|
||||
)
|
||||
return self
|
||||
|
||||
@ -14,6 +14,7 @@ class Asset(BaseModel):
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
file_path: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@ -294,7 +294,7 @@ def list_tags_with_usage(
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
escaped, esc = escape_sql_like_string(prefix.strip())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
@ -307,7 +307,7 @@ def list_tags_with_usage(
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
escaped, esc = escape_sql_like_string(prefix.strip())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
visible_tags_sq = (
|
||||
|
||||
@ -41,10 +41,10 @@ def get_utc_now() -> datetime:
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
Normalize a list of tags by:
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
- Stripping whitespace.
|
||||
- Removing exact duplicates while preserving order and case.
|
||||
"""
|
||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
||||
return list(dict.fromkeys(t.strip() for t in (tags or []) if (t or "").strip()))
|
||||
|
||||
|
||||
def validate_blake3_hash(s: str) -> str:
|
||||
|
||||
@ -34,6 +34,7 @@ from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_backend_system_tags_from_path,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
@ -101,7 +102,11 @@ def _ingest_file_from_path(
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
norm = normalize_tags(list(tags))
|
||||
try:
|
||||
backend_tags = get_backend_system_tags_from_path(locator)
|
||||
except ValueError:
|
||||
backend_tags = []
|
||||
norm = normalize_tags([*list(tags), *backend_tags])
|
||||
if norm:
|
||||
if require_existing_tags:
|
||||
validate_tags_exist(session, norm)
|
||||
@ -474,6 +479,10 @@ def upload_from_temp_path(
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
if existing is not None:
|
||||
# Once content is already known, duplicate byte uploads are treated as
|
||||
# reference-only creation. Request tags are labels only here: do not
|
||||
# require upload destination tags, do not move bytes, and do not
|
||||
# synthesize path-derived classification or uploaded provenance.
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
@ -535,7 +544,7 @@ def upload_from_temp_path(
|
||||
owner_id=owner_id,
|
||||
preview_id=preview_id,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags,
|
||||
tags=[*(tags or []), "uploaded"],
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
@ -569,15 +578,19 @@ def register_file_in_place(
|
||||
) -> UploadResult:
|
||||
"""Register an already-saved file in the asset database without moving it.
|
||||
|
||||
Tags are derived from the filesystem path (root category + subfolder names),
|
||||
merged with any caller-provided tags, matching the behavior of the scanner.
|
||||
This helper is used by upload paths that have already written bytes before
|
||||
registering the file, so it records the same ``uploaded`` tag as the
|
||||
multipart byte-upload path.
|
||||
|
||||
Tags are derived from trusted filesystem classification and merged with any
|
||||
caller-provided tags, matching the behavior of the scanner.
|
||||
If the path is not under a known root, only the caller-provided tags are used.
|
||||
"""
|
||||
try:
|
||||
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
except ValueError:
|
||||
path_tags = []
|
||||
merged_tags = normalize_tags([*path_tags, *tags])
|
||||
merged_tags = normalize_tags([*path_tags, *tags, "uploaded"])
|
||||
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(abs_path)
|
||||
|
||||
@ -3,10 +3,9 @@ from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import folder_paths
|
||||
from app.assets.helpers import normalize_tags
|
||||
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"})
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
@ -14,7 +13,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
|
||||
Includes every category registered in folder_names_and_paths,
|
||||
regardless of whether its paths are under the main models_dir,
|
||||
but excludes non-model entries like custom_nodes.
|
||||
but excludes non-model entries like configs and custom_nodes.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
@ -27,35 +26,37 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
"""Validates and maps upload routing tags -> (base_dir, subdirs_for_fs).
|
||||
|
||||
The request tags are only used to choose the write destination. Extra tags
|
||||
remain labels; they do not become path components or trusted classification.
|
||||
"""
|
||||
destination_roles = [t for t in tags if t in {"input", "models", "output"}]
|
||||
if len(destination_roles) != 1:
|
||||
raise ValueError("uploads require exactly one destination role: input, models, or output")
|
||||
|
||||
root = destination_roles[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
model_type_tags = [t for t in tags if t.startswith("model_type:")]
|
||||
if len(model_type_tags) != 1:
|
||||
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
|
||||
folder_name = model_type_tags[0].split(":", 1)[1]
|
||||
if not folder_name:
|
||||
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
|
||||
model_folder_paths = dict(get_comfy_models_folders())
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
bases = model_folder_paths[folder_name]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
raise ValueError(f"unknown model category '{folder_name}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
raise ValueError(f"no base path configured for category '{folder_name}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
for i in raw_subdirs:
|
||||
if i in (".", "..") or _sep_chars & set(i):
|
||||
raise ValueError("invalid path component in tags")
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
return base_dir, []
|
||||
|
||||
|
||||
def validate_path_within_base(candidate: str, base: str) -> None:
|
||||
@ -91,6 +92,25 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def compute_api_file_path(file_path: str | None) -> str | None:
|
||||
"""Return a stable API-visible path relative to a known asset root.
|
||||
|
||||
Examples:
|
||||
/.../input/foo.png -> "input/foo.png"
|
||||
/.../models/checkpoints/foo.safetensors -> "models/checkpoints/foo.safetensors"
|
||||
|
||||
Returns None for references without a filesystem path or paths outside
|
||||
known asset roots.
|
||||
"""
|
||||
if not file_path:
|
||||
return None
|
||||
try:
|
||||
root_category, rel_path = get_asset_category_and_relative_path(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
return "/".join([root_category, *Path(rel_path).parts])
|
||||
|
||||
|
||||
def get_asset_category_and_relative_path(
|
||||
file_path: str,
|
||||
) -> tuple[Literal["input", "output", "temp", "models"], str]:
|
||||
@ -156,18 +176,55 @@ def get_asset_category_and_relative_path(
|
||||
)
|
||||
|
||||
|
||||
def get_backend_system_tags_from_path(path: str) -> list[str]:
|
||||
"""Return trusted backend tags derived from current filesystem facts.
|
||||
|
||||
The returned tags are only the backend-generated system tags: ``models``,
|
||||
``model_type:<folder_name>``, ``input``, ``output``, and ``temp``. Model
|
||||
type tags are based on registered folder names, not path components.
|
||||
"""
|
||||
fp_abs = os.path.abspath(path)
|
||||
fp_path = Path(fp_abs)
|
||||
tags: list[str] = []
|
||||
|
||||
def _add(tag: str) -> None:
|
||||
if tag not in tags:
|
||||
tags.append(tag)
|
||||
|
||||
for role, base in (
|
||||
("input", folder_paths.get_input_directory()),
|
||||
("output", folder_paths.get_output_directory()),
|
||||
("temp", folder_paths.get_temp_directory()),
|
||||
):
|
||||
if fp_path.is_relative_to(os.path.abspath(base)):
|
||||
_add(role)
|
||||
|
||||
model_types: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
for base in bases:
|
||||
if fp_path.is_relative_to(os.path.abspath(base)):
|
||||
model_types.append(folder_name)
|
||||
break
|
||||
|
||||
if model_types:
|
||||
_add("models")
|
||||
for folder_name in model_types:
|
||||
_add(f"model_type:{folder_name}")
|
||||
|
||||
if not tags:
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, temp, or configured model bases: {path}"
|
||||
)
|
||||
return tags
|
||||
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return (name, tags) derived from a filesystem path.
|
||||
|
||||
- name: base filename with extension
|
||||
- tags: [root_category] + parent folder names in order
|
||||
- tags: trusted backend classification tags derived from the path
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
root_category, some_path = get_asset_category_and_relative_path(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
return Path(file_path).name, get_backend_system_tags_from_path(file_path)
|
||||
|
||||
@ -234,7 +234,7 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
|
||||
p = getattr(request, "param", {}) or {}
|
||||
tags: Optional[list[str]] = p.get("tags")
|
||||
if tags is None:
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||
# Unique content per test so the seed always creates a fresh asset (201).
|
||||
# Delete is now always a soft delete, so content from a prior test survives
|
||||
|
||||
@ -133,6 +133,66 @@ class TestListReferencesPage:
|
||||
assert total == 1
|
||||
assert refs[0].name == "tagged"
|
||||
|
||||
def test_include_tags_filter_ands_persisted_model_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash-model-tags")
|
||||
checkpoint = _make_reference(session, asset, name="checkpoint")
|
||||
lora = _make_reference(session, asset, name="lora")
|
||||
input_ref = _make_reference(session, asset, name="input")
|
||||
ensure_tags_exist(
|
||||
session,
|
||||
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"],
|
||||
)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=checkpoint.id,
|
||||
tags=["models", "model_type:checkpoints", "unit-tests"],
|
||||
origin="automatic",
|
||||
)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=lora.id,
|
||||
tags=["models", "model_type:loras", "unit-tests"],
|
||||
origin="automatic",
|
||||
)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=input_ref.id,
|
||||
tags=["unit-tests"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session,
|
||||
include_tags=["models", "model_type:checkpoints", "unit-tests"],
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert refs[0].id == checkpoint.id
|
||||
|
||||
def test_include_tags_filter_preserves_model_type_case(self, session: Session):
|
||||
asset = _make_asset(session, "hash-model-case")
|
||||
ref = _make_reference(session, asset, name="llm")
|
||||
ensure_tags_exist(session, ["models", "model_type:LLM"])
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "model_type:LLM"],
|
||||
origin="automatic",
|
||||
)
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, include_tags=["models", "model_type:LLM"]
|
||||
)
|
||||
refs_lower, _, total_lower = list_references_page(
|
||||
session, include_tags=["models", "model_type:llm"]
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert refs[0].id == ref.id
|
||||
assert total_lower == 0
|
||||
assert refs_lower == []
|
||||
|
||||
def test_exclude_tags_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, name="keep")
|
||||
|
||||
@ -6,7 +6,11 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||
from app.assets.services.path_utils import (
|
||||
get_asset_category_and_relative_path,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -76,6 +80,121 @@ class TestGetAssetCategoryAndRelativePath:
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "models"
|
||||
|
||||
def test_model_path_tags_include_registered_model_type_only(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "subdir" / "model.safetensors"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "checkpoints" not in tags
|
||||
assert "subdir" not in tags
|
||||
|
||||
def test_model_type_preserves_registered_folder_case(self, fake_dirs):
|
||||
llm_dir = fake_dirs["models"].parent / "LLM"
|
||||
llm_dir.mkdir()
|
||||
f = llm_dir / "model.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("LLM", [str(llm_dir)])],
|
||||
):
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:LLM" in tags
|
||||
assert "model_type:llm" not in tags
|
||||
|
||||
def test_path_components_do_not_create_model_type_tags(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "loras" / "model.safetensors"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "loras" not in tags
|
||||
assert "model_type:loras" not in tags
|
||||
|
||||
def test_shared_root_returns_all_matching_model_type_tags(self, fake_dirs):
|
||||
shared_root = fake_dirs["models"].parent / "shared"
|
||||
shared_root.mkdir()
|
||||
f = shared_root / "foo.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("checkpoints", [str(shared_root)]),
|
||||
("loras", [str(shared_root)]),
|
||||
],
|
||||
):
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "model_type:loras" in tags
|
||||
|
||||
def test_output_backed_registered_folder_gets_model_and_output_tags(self, fake_dirs):
|
||||
output_checkpoints_dir = fake_dirs["output"] / "checkpoints"
|
||||
output_checkpoints_dir.mkdir()
|
||||
f = output_checkpoints_dir / "saved.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(output_checkpoints_dir)])],
|
||||
):
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "output" in tags
|
||||
|
||||
def test_temp_path_tags_include_temp_not_output_or_preview(self, fake_dirs):
|
||||
f = fake_dirs["temp"] / "preview.png"
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "temp" in tags
|
||||
assert "output" not in tags
|
||||
assert "preview:true" not in tags
|
||||
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
|
||||
|
||||
class TestResolveDestinationFromTags:
|
||||
def test_model_upload_rejects_non_writable_registered_folders(self):
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
configs_dir = root_path / "models" / "configs"
|
||||
custom_nodes_dir = root_path / "custom_nodes"
|
||||
for path in (checkpoints_dir, configs_dir, custom_nodes_dir):
|
||||
path.mkdir(parents=True)
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.folder_names_and_paths = {
|
||||
"checkpoints": ([str(checkpoints_dir)], set()),
|
||||
"configs": ([str(configs_dir)], set()),
|
||||
"custom_nodes": ([str(custom_nodes_dir)], set()),
|
||||
}
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(
|
||||
["models", "model_type:checkpoints"]
|
||||
)
|
||||
assert base_dir == os.path.abspath(checkpoints_dir)
|
||||
assert subdirs == []
|
||||
|
||||
for folder_name in ("configs", "custom_nodes"):
|
||||
with pytest.raises(ValueError, match="unknown model category"):
|
||||
resolve_destination_from_tags(
|
||||
["models", f"model_type:{folder_name}"]
|
||||
)
|
||||
|
||||
@ -10,9 +10,9 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
|
||||
body1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
# A few system tags from migration should exist:
|
||||
# A few selected contract tags should exist.
|
||||
assert "models" in names
|
||||
assert "checkpoints" in names
|
||||
assert "model_type:checkpoints" in names
|
||||
|
||||
# Only used tags before we add anything new from this test cycle
|
||||
r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
|
||||
@ -21,7 +21,7 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
|
||||
# We already seeded one asset via fixture, so used tags must be non-empty
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert "models" in used_names
|
||||
assert "checkpoints" in used_names
|
||||
assert "model_type:checkpoints" in used_names
|
||||
|
||||
# Prefix filter should refine the list
|
||||
r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
|
||||
@ -45,7 +45,7 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
|
||||
body1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
assert "models" in names and "checkpoints" in names
|
||||
assert "models" in names and "model_type:checkpoints" in names
|
||||
|
||||
# Create a short-lived asset under input with a unique custom tag
|
||||
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
|
||||
@ -89,28 +89,28 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
|
||||
def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# Add tags with duplicates and mixed case
|
||||
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
|
||||
# Add tags with duplicates while preserving source case.
|
||||
payload_add = {"tags": ["NewTag", "unit-tests", "NewTag", "BETA"]}
|
||||
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200, b1
|
||||
# normalized, deduplicated; 'unit-tests' was already present from the seed
|
||||
assert set(b1["added"]) == {"newtag", "beta"}
|
||||
# stripped, deduplicated; 'unit-tests' was already present from the seed
|
||||
assert set(b1["added"]) == {"NewTag", "BETA"}
|
||||
assert set(b1["already_present"]) == {"unit-tests"}
|
||||
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
|
||||
assert "NewTag" in b1["total_tags"] and "BETA" in b1["total_tags"]
|
||||
|
||||
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
g = rg.json()
|
||||
assert rg.status_code == 200
|
||||
tags_now = set(g["tags"])
|
||||
assert {"newtag", "beta"}.issubset(tags_now)
|
||||
assert {"NewTag", "BETA"}.issubset(tags_now)
|
||||
|
||||
# Remove a tag and a non-existent tag
|
||||
payload_del = {"tags": ["newtag", "does-not-exist"]}
|
||||
payload_del = {"tags": ["NewTag", "does-not-exist"]}
|
||||
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
assert set(b2["removed"]) == {"newtag"}
|
||||
assert set(b2["removed"]) == {"NewTag"}
|
||||
assert set(b2["not_present"]) == {"does-not-exist"}
|
||||
|
||||
# Verify remaining tags after deletion
|
||||
@ -118,8 +118,44 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset
|
||||
g2 = rg2.json()
|
||||
assert rg2.status_code == 200
|
||||
tags_later = set(g2["tags"])
|
||||
assert "newtag" not in tags_later
|
||||
assert "beta" in tags_later # still present
|
||||
assert "NewTag" not in tags_later
|
||||
assert "BETA" in tags_later # still present
|
||||
|
||||
|
||||
def test_add_system_looking_tags_allowed_as_labels(
|
||||
http: requests.Session, api_base: str, seeded_asset: dict
|
||||
):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
response = http.post(
|
||||
f"{api_base}/api/assets/{aid}/tags",
|
||||
json={
|
||||
"tags": [
|
||||
"models",
|
||||
"model_type:manual",
|
||||
"model:true",
|
||||
"models:foo",
|
||||
"input:true",
|
||||
"output:true",
|
||||
"uploaded:true",
|
||||
"temp:true",
|
||||
"temporary",
|
||||
]
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 200, body
|
||||
assert "models" in body["total_tags"]
|
||||
assert "model_type:manual" in body["total_tags"]
|
||||
assert "model:true" in body["total_tags"]
|
||||
assert "models:foo" in body["total_tags"]
|
||||
assert "input:true" in body["total_tags"]
|
||||
assert "output:true" in body["total_tags"]
|
||||
assert "uploaded:true" in body["total_tags"]
|
||||
assert "temp:true" in body["total_tags"]
|
||||
assert "temporary" in body["total_tags"]
|
||||
|
||||
|
||||
def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import json
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import pytest
|
||||
|
||||
from app.assets.api.schemas_out import Asset, AssetCreated
|
||||
from helpers import get_asset_filename
|
||||
|
||||
|
||||
def test_asset_created_inherits_hash_field():
|
||||
@ -22,7 +24,7 @@ def test_asset_created_inherits_hash_field():
|
||||
|
||||
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
|
||||
name = "dup_a.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "dup"}
|
||||
data = make_asset_bytes(name)
|
||||
files = {"file": (name, data, "application/octet-stream")}
|
||||
@ -58,7 +60,7 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
|
||||
def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
|
||||
# Seed a small file first
|
||||
name = "fastpath_seed.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests"]
|
||||
tags = ["input", "unit-tests"]
|
||||
meta = {}
|
||||
files = {"file": (name, b"B" * 1024, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
|
||||
@ -69,9 +71,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
|
||||
assert b1["hash"] == h
|
||||
|
||||
# Now POST /api/assets with only hash and no file
|
||||
hash_only_tags = ["models", "checkpoints", "unit-tests", "hash-labels"]
|
||||
files = [
|
||||
("hash", (None, h)),
|
||||
("tags", (None, json.dumps(tags))),
|
||||
("tags", (None, json.dumps(hash_only_tags))),
|
||||
("name", (None, "fastpath_copy.safetensors")),
|
||||
("user_metadata", (None, json.dumps({"purpose": "copy"}))),
|
||||
]
|
||||
@ -81,6 +84,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
assert b2["hash"] == h
|
||||
assert "models" in b2["tags"]
|
||||
assert "checkpoints" in b2["tags"]
|
||||
assert "uploaded" not in b2["tags"]
|
||||
assert not any(tag.startswith("model_type:") for tag in b2["tags"])
|
||||
|
||||
|
||||
def test_upload_fastpath_with_known_hash_and_file(
|
||||
@ -88,7 +95,7 @@ def test_upload_fastpath_with_known_hash_and_file(
|
||||
):
|
||||
# Seed
|
||||
files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
|
||||
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
|
||||
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 201, b1
|
||||
@ -104,11 +111,47 @@ def test_upload_fastpath_with_known_hash_and_file(
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
assert b2["hash"] == h
|
||||
assert "checkpoints" in b2["tags"]
|
||||
assert "uploaded" not in b2["tags"]
|
||||
assert not any(tag == "model_type:checkpoints" for tag in b2["tags"])
|
||||
|
||||
|
||||
def test_duplicate_byte_upload_is_reference_only_and_does_not_need_destination(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
data = b"duplicate-reference-only" * 64
|
||||
seed_files = {"file": ("duplicate-seed.bin", data, "application/octet-stream")}
|
||||
seed_form = {
|
||||
"tags": json.dumps(["input", "unit-tests", "duplicate-seed"]),
|
||||
"name": "duplicate-seed.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
seed_response = http.post(api_base + "/api/assets", data=seed_form, files=seed_files, timeout=120)
|
||||
seed = seed_response.json()
|
||||
assert seed_response.status_code == 201, seed
|
||||
|
||||
duplicate_files = {"file": ("duplicate-copy.bin", data, "application/octet-stream")}
|
||||
duplicate_form = {
|
||||
"tags": json.dumps(["not-a-destination", "unit-tests", "duplicate-copy"]),
|
||||
"name": "duplicate-copy.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
duplicate_response = http.post(
|
||||
api_base + "/api/assets", data=duplicate_form, files=duplicate_files, timeout=120
|
||||
)
|
||||
duplicate = duplicate_response.json()
|
||||
|
||||
assert duplicate_response.status_code == 200, duplicate
|
||||
assert duplicate["created_new"] is False
|
||||
assert duplicate["asset_hash"] == seed["asset_hash"]
|
||||
assert "not-a-destination" in duplicate["tags"]
|
||||
assert "uploaded" not in duplicate["tags"]
|
||||
assert "input" not in duplicate["tags"]
|
||||
|
||||
|
||||
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
|
||||
data = [
|
||||
("tags", "models,checkpoints"),
|
||||
("tags", "models,model_type:checkpoints"),
|
||||
("tags", json.dumps(["unit-tests", "alpha"])),
|
||||
("name", "merge.safetensors"),
|
||||
("user_metadata", json.dumps({"u": 1})),
|
||||
@ -124,7 +167,7 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base
|
||||
detail = rg.json()
|
||||
assert rg.status_code == 200, detail
|
||||
tags = set(detail["tags"])
|
||||
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||
assert {"models", "model_type:checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
@ -192,16 +235,55 @@ def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
def test_create_from_hash_accepts_arbitrary_system_looking_tags(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("hash-seed.bin", b"hash-seed" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(["input", "unit-tests", "hash-seed"]),
|
||||
"name": "hash-seed.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
seed_response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
seed = seed_response.json()
|
||||
assert seed_response.status_code == 201, seed
|
||||
|
||||
response = http.post(
|
||||
api_base + "/api/assets/from-hash",
|
||||
json={
|
||||
"hash": seed["asset_hash"],
|
||||
"name": "hash-copy.bin",
|
||||
"tags": [
|
||||
"models",
|
||||
"model:true",
|
||||
"models:foo",
|
||||
"temporary:true",
|
||||
"unit-tests",
|
||||
"hash-copy",
|
||||
],
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 201, body
|
||||
assert "models" in body["tags"]
|
||||
assert "model:true" in body["tags"]
|
||||
assert "models:foo" in body["tags"]
|
||||
assert "temporary:true" in body["tags"]
|
||||
assert "uploaded" not in body["tags"]
|
||||
|
||||
|
||||
def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
|
||||
files = {"file": ("empty.safetensors", b"", "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
|
||||
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "EMPTY_UPLOAD"
|
||||
|
||||
|
||||
def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str):
|
||||
def test_upload_rejects_arbitrary_labels_without_required_destination_role(http: requests.Session, api_base: str):
|
||||
files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
@ -212,7 +294,7 @@ def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str)
|
||||
|
||||
def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
|
||||
files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
|
||||
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
@ -228,7 +310,7 @@ def test_upload_requires_multipart(http: requests.Session, api_base: str):
|
||||
|
||||
def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
|
||||
files = [
|
||||
("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
|
||||
("tags", (None, json.dumps(["models", "model_type:checkpoints", "unit-tests"]))),
|
||||
("name", (None, "x.safetensors")),
|
||||
]
|
||||
r = http.post(api_base + "/api/assets", files=files, timeout=120)
|
||||
@ -237,17 +319,33 @@ def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
|
||||
assert body["error"]["code"] == "MISSING_FILE"
|
||||
|
||||
|
||||
def test_upload_models_unknown_category(http: requests.Session, api_base: str):
|
||||
def test_upload_models_unknown_model_type(http: requests.Session, api_base: str):
|
||||
files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"}
|
||||
form = {"tags": json.dumps(["models", "model_type:no_such_category", "unit-tests"]), "name": "m.safetensors"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert r.status_code == 400, body
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
assert body["error"]["message"].startswith("unknown models category")
|
||||
|
||||
|
||||
def test_upload_models_requires_category(http: requests.Session, api_base: str):
|
||||
@pytest.mark.parametrize("model_type", ["configs", "custom_nodes"])
|
||||
def test_upload_models_rejects_non_model_registered_folder(
|
||||
model_type: str, http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("not-a-model.py", b"A" * 128, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(["models", f"model_type:{model_type}", "unit-tests"]),
|
||||
"name": "not-a-model.py",
|
||||
}
|
||||
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 400, body
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_upload_models_requires_model_type(http: requests.Session, api_base: str):
|
||||
files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
@ -256,13 +354,118 @@ def test_upload_models_requires_category(http: requests.Session, api_base: str):
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
|
||||
def test_upload_extra_tags_are_labels_not_path_components(http: requests.Session, api_base: str):
|
||||
files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
|
||||
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||
assert r.status_code == 201, body
|
||||
assert ".." in body["tags"]
|
||||
assert "zzz" in body["tags"]
|
||||
assert "models" in body["tags"]
|
||||
assert "model_type:checkpoints" in body["tags"]
|
||||
|
||||
|
||||
def test_multipart_upload_accepts_system_looking_extra_labels(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("relaxed-labels.bin", b"relaxed" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(
|
||||
[
|
||||
"input",
|
||||
"unit-tests",
|
||||
"model:true",
|
||||
"models:foo",
|
||||
"temporary",
|
||||
"uploaded:true",
|
||||
]
|
||||
),
|
||||
"name": "relaxed-labels.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 201, body
|
||||
assert "input" in body["tags"]
|
||||
assert "model:true" in body["tags"]
|
||||
assert "models:foo" in body["tags"]
|
||||
assert "temporary" in body["tags"]
|
||||
assert "uploaded:true" in body["tags"]
|
||||
|
||||
|
||||
def test_multipart_upload_rejects_ambiguous_destination_roles(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("ambiguous.bin", b"ambiguous" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(["input", "output", "unit-tests"]),
|
||||
"name": "ambiguous.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 400, body
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_multipart_upload_rejects_multiple_model_types_for_models_destination(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("ambiguous-model.safetensors", b"ambiguous-model" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(
|
||||
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"]
|
||||
),
|
||||
"name": "ambiguous-model.safetensors",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 400, body
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tags", "expected_root", "extension"),
|
||||
[
|
||||
(["input", "unit-tests", "upload-location-input"], "input", ".bin"),
|
||||
(["output", "unit-tests", "upload-location-output"], "output", ".bin"),
|
||||
(
|
||||
["models", "model_type:checkpoints", "unit-tests", "upload-location-model"],
|
||||
"models/checkpoints",
|
||||
".safetensors",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multipart_upload_role_selects_write_location(
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
tags: list[str],
|
||||
expected_root: str,
|
||||
extension: str,
|
||||
):
|
||||
role = next(tag for tag in tags if tag in {"input", "models", "output"})
|
||||
name = f"{role}-role-upload{extension}"
|
||||
files = {"file": (name, f"{role}-role-bytes".encode() * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 201, body
|
||||
stored_name = get_asset_filename(body["asset_hash"], extension)
|
||||
expected_disk_path = comfy_tmp_base_dir / expected_root / stored_name
|
||||
assert expected_disk_path.exists()
|
||||
assert body["file_path"] == f"{expected_root}/{stored_name}"
|
||||
|
||||
|
||||
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user