mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-20 23:09:27 +08:00
Compare commits
12 Commits
7ab346fc7b
...
00940fb24e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00940fb24e | ||
|
|
7ff001d7c8 | ||
|
|
19ba85bb2e | ||
|
|
3ffc49aa0e | ||
|
|
72e3f6081c | ||
|
|
36f9a6fdef | ||
|
|
a0d1238829 | ||
|
|
7ec7b6ffe9 | ||
|
|
6887165a9d | ||
|
|
1688a5e262 | ||
|
|
cc4d711eb1 | ||
|
|
626b082838 |
@ -327,7 +327,12 @@ def list_references_page(
|
||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
# Preserve insertion order so the structural first tag (the root
|
||||
# category like "models") stays in position 0 and the path-derived
|
||||
# sub-path tag stays in position 1, matching cloud's behavior.
|
||||
# tag_name is a deterministic tiebreaker when multiple tags share
|
||||
# an added_at (same-batch insert via set_reference_tags).
|
||||
.order_by(AssetReferenceTag.added_at.asc(), AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
for ref_id, tag_name in rows.all():
|
||||
tag_map[ref_id].append(tag_name)
|
||||
@ -355,7 +360,8 @@ def fetch_reference_asset_and_tags(
|
||||
build_visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetReference.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
# See list_references_page for the rationale behind ordering by added_at.
|
||||
.order_by(AssetReferenceTag.added_at.asc(), Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -77,7 +78,13 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
# Match the response-path ordering used by
|
||||
# list_references_page / fetch_reference_asset_and_tags so
|
||||
# upload responses and subsequent GETs agree on tag order.
|
||||
.order_by(
|
||||
AssetReferenceTag.added_at.asc(),
|
||||
AssetReferenceTag.tag_name.asc(),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
@ -98,15 +105,21 @@ def set_reference_tags(
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
# Stagger added_at by microsecond per tag so the retrieval ORDER BY
|
||||
# added_at preserves input order. Per-tag get_utc_now() calls can
|
||||
# collide at microsecond resolution on fast machines, dropping the
|
||||
# query to the tag_name alphabetical tiebreaker — same fix as in
|
||||
# batch_insert_seed_assets.
|
||||
base_ts = get_utc_now()
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
added_at=base_ts + timedelta(microseconds=i),
|
||||
)
|
||||
for t in to_add
|
||||
for i, t in enumerate(to_add)
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
@ -146,10 +159,16 @@ def add_tags_to_reference(
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
# Preserve the caller's insertion order rather than alphabetizing —
|
||||
# the retrieval ORDER BY added_at + microsecond stagger only meaningfully
|
||||
# preserves insertion order if "the order we insert in" actually matches
|
||||
# the caller's intent.
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
to_add = [t for t in norm if t not in current]
|
||||
|
||||
if to_add:
|
||||
# See set_reference_tags for the rationale behind the per-tag stagger.
|
||||
base_ts = get_utc_now()
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
@ -158,9 +177,9 @@ def add_tags_to_reference(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
added_at=base_ts + timedelta(microseconds=i),
|
||||
)
|
||||
for t in to_add
|
||||
for i, t in enumerate(to_add)
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@ -233,13 +233,19 @@ def batch_insert_seed_assets(
|
||||
if ref_id not in inserted_ref_ids:
|
||||
continue
|
||||
|
||||
for tag in ref_data["tags"]:
|
||||
# Stagger added_at by microsecond per tag within a reference so
|
||||
# the retrieval ORDER BY added_at preserves the input list order
|
||||
# (the path-derived root category stays at position 0). Without
|
||||
# this, every tag in a bulk-insert batch shares current_time and
|
||||
# the tag_name tiebreaker sorts them alphabetically — putting the
|
||||
# subpath tag ahead of "models" since "c"/"d"/"l" < "m".
|
||||
for tag_idx, tag in enumerate(ref_data["tags"]):
|
||||
tag_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"tag_name": tag,
|
||||
"origin": "automatic",
|
||||
"added_at": current_time,
|
||||
"added_at": current_time + timedelta(microseconds=tag_idx),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -26,27 +26,51 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs).
|
||||
|
||||
Accepts both the legacy one-tag-per-directory shape
|
||||
(``["models", "diffusers", "Kolors", "text_encoder"]``) and the
|
||||
slash-joined shape emitted by :func:`get_name_and_tags_from_asset_path`
|
||||
(``["models", "diffusers/Kolors/text_encoder"]``). Hybrid shapes that
|
||||
mix the two within a single call (e.g.
|
||||
``["models", "diffusers", "Kolors/text_encoder"]``) are also
|
||||
accepted: each entry after ``tags[0]`` is split on ``/`` and
|
||||
concatenated, so the two shapes — and any mix of them — resolve to
|
||||
the same destination. The same safety checks are applied to each
|
||||
component after expansion.
|
||||
"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
|
||||
# Expand any slash-joined entries into individual path components so
|
||||
# the rest of the function can treat both tag shapes uniformly. Each
|
||||
# component is also stripped, so " a / b " behaves like ["a", "b"].
|
||||
expanded: list[str] = []
|
||||
for t in tags[1:]:
|
||||
for part in str(t).split("/"):
|
||||
part = part.strip()
|
||||
if part:
|
||||
expanded.append(part)
|
||||
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
if not expanded:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
category = expanded[0]
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
bases = folder_paths.folder_names_and_paths[category][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
raise ValueError(f"unknown model category '{category}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
raise ValueError(f"no base path configured for category '{category}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
raw_subdirs = expanded[1:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
raw_subdirs = expanded
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
raw_subdirs = expanded
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
@ -166,11 +190,14 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
consumers can use ``tags[1]`` as a stable category identifier that
|
||||
survives nested directory layouts (e.g. diffusers components).
|
||||
|
||||
Case is preserved on the subpath so that consumers can look up
|
||||
providers keyed on the original-case path (e.g.
|
||||
``"diffusers/Kolors/text_encoder"``). The root category is always
|
||||
lowercase by construction in
|
||||
:func:`get_asset_category_and_relative_path`.
|
||||
The subpath is lowercased to match the canonicalization applied by
|
||||
:func:`ensure_tags_exist`; without that, the
|
||||
``asset_reference_tags.tag_name`` FK to the lowercased ``tags.name``
|
||||
would fail for any path containing uppercase letters. The root
|
||||
category is lowercase by construction in
|
||||
:func:`get_asset_category_and_relative_path`, so no separate cast
|
||||
is applied here. Consumers that need to look up providers keyed on
|
||||
original-case paths should normalize their lookup key to lowercase.
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
@ -182,5 +209,5 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
]
|
||||
tags = [root_category]
|
||||
if parent_parts:
|
||||
tags.append("/".join(parent_parts))
|
||||
tags.append("/".join(parent_parts).lower())
|
||||
return p.name, list(dict.fromkeys(t.strip() for t in tags if t.strip()))
|
||||
|
||||
@ -260,7 +260,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||
# NOTE: offloadable=False is a legacy mode and if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
if input is not None:
|
||||
|
||||
@ -77,7 +77,7 @@ class EmptyLTXVLatentVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent})
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 32})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
@ -1,10 +1,41 @@
|
||||
import re
|
||||
import json
|
||||
import string
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class StringFormat(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
autogrow = io.Autogrow.TemplateNames(
|
||||
input=io.AnyType.Input("value"),
|
||||
names=list(string.ascii_lowercase),
|
||||
min=0,
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="StringFormat",
|
||||
display_name="Format Text",
|
||||
category="text",
|
||||
search_aliases=["string", "format"],
|
||||
description="Same as Python's string format method. Supports all of Python's format options and features.",
|
||||
inputs=[
|
||||
io.Autogrow.Input("values", template=autogrow),
|
||||
io.String.Input("f_string", default="{a}", multiline=True),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls, values: io.Autogrow.Type, f_string: str
|
||||
) -> io.NodeOutput:
|
||||
return io.NodeOutput(f_string.format(**values))
|
||||
|
||||
|
||||
class StringConcatenate(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -413,6 +444,7 @@ class StringExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
StringFormat,
|
||||
StringConcatenate,
|
||||
StringSubstring,
|
||||
StringLength,
|
||||
|
||||
14
openapi.yaml
14
openapi.yaml
@ -4160,6 +4160,10 @@ paths:
|
||||
name:
|
||||
type: string
|
||||
description: Display name for the API key
|
||||
description:
|
||||
type: string
|
||||
description: User-provided description of the key's purpose
|
||||
maxLength: 5000
|
||||
responses:
|
||||
"201":
|
||||
description: API key created
|
||||
@ -7677,11 +7681,16 @@ components:
|
||||
required:
|
||||
- id
|
||||
- name
|
||||
- description
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
description:
|
||||
type: string
|
||||
maxLength: 5000
|
||||
description: User-provided description of the key's purpose. Always present in responses; empty string when no description was supplied on create.
|
||||
prefix:
|
||||
type: string
|
||||
description: First few characters of the key for identification
|
||||
@ -7702,12 +7711,17 @@ components:
|
||||
required:
|
||||
- id
|
||||
- name
|
||||
- description
|
||||
- key
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
description:
|
||||
type: string
|
||||
maxLength: 5000
|
||||
description: User-provided description of the key's purpose. Always present in responses; empty string when no description was supplied on create.
|
||||
key:
|
||||
type: string
|
||||
description: Full API key value (only returned on creation)
|
||||
|
||||
@ -21,6 +21,7 @@ from app.assets.database.queries import (
|
||||
get_reference_ids_by_ids,
|
||||
ensure_tags_exist,
|
||||
add_tags_to_reference,
|
||||
set_reference_tags,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
@ -159,6 +160,105 @@ class TestListReferencesPage:
|
||||
assert refs[0].name == "large"
|
||||
|
||||
|
||||
class TestTagRetrievalOrder:
|
||||
"""End-to-end check: tags written through the public write paths come
|
||||
back from the public read paths in insertion order rather than the
|
||||
composite-PK alphabetical order SQLite would otherwise impose.
|
||||
|
||||
Each test deliberately picks tag names that would sort differently
|
||||
under alphabetical vs insertion order, so an alphabetical regression
|
||||
fails loudly.
|
||||
"""
|
||||
|
||||
def _make_ref(self, session: Session) -> AssetReference:
|
||||
asset = _make_asset(session, "h1")
|
||||
return _make_reference(session, asset, name="x.bin")
|
||||
|
||||
def test_set_reference_tags_preserves_input_order_in_list(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
# "checkpoints" < "models" alphabetically; if added_at stagger
|
||||
# works, list_references_page returns insertion order.
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints"]
|
||||
|
||||
def test_set_reference_tags_preserves_input_order_in_fetch(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
# Subpath tag sorts before "models" alphabetically.
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "diffusers/kolors/text_encoder"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
result = fetch_reference_asset_and_tags(session, ref.id)
|
||||
assert result is not None
|
||||
_, _, tags = result
|
||||
assert tags == ["models", "diffusers/kolors/text_encoder"]
|
||||
|
||||
def test_add_tags_to_reference_lands_after_path_tags(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# "aaa-..." sorts before both path tags alphabetically. If added_at
|
||||
# stagger is missing, alphabetic tiebreak would hoist it to tags[0].
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
|
||||
|
||||
def test_multi_tag_batch_lands_after_path_tags(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# Three user tags inserted in non-alphabetical input order. Per-tag
|
||||
# microsecond stagger should preserve at least the "user batch is
|
||||
# after path tags" property; within the user batch insertion order
|
||||
# is also preserved.
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["zzz-z", "favorite", "experiment-q4"],
|
||||
origin="manual",
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
tags = tag_map[ref.id]
|
||||
assert tags[0:2] == ["models", "checkpoints"]
|
||||
assert set(tags[2:]) == {"zzz-z", "favorite", "experiment-q4"}
|
||||
|
||||
def test_remove_then_add_does_not_disrupt_path_tag_positions(
|
||||
self, session: Session
|
||||
):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "loras/my/custom/path"],
|
||||
)
|
||||
session.commit()
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["temp-tag"])
|
||||
session.commit()
|
||||
from app.assets.database.queries import remove_tags_from_reference
|
||||
|
||||
remove_tags_from_reference(session, reference_id=ref.id, tags=["temp-tag"])
|
||||
session.commit()
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["second-tag"])
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "loras/my/custom/path", "second-tag"]
|
||||
|
||||
|
||||
class TestFetchReferenceAssetAndTags:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
result = fetch_reference_asset_and_tags(session, "nonexistent")
|
||||
|
||||
@ -9,6 +9,7 @@ import pytest
|
||||
from app.assets.services.path_utils import (
|
||||
get_asset_category_and_relative_path,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
|
||||
|
||||
@ -159,7 +160,13 @@ class TestGetNameAndTagsFromAssetPath:
|
||||
def test_diffusers_nested_subpath_slash_joined(self, fake_dirs_multi_bucket):
|
||||
"""Diffusers components live in nested directories — the full subpath
|
||||
must collapse into one tag so consumers can look up the model category
|
||||
via tags[1] regardless of nesting depth."""
|
||||
via tags[1] regardless of nesting depth.
|
||||
|
||||
The subpath is lowercased to match the canonicalization
|
||||
:func:`ensure_tags_exist` applies on the write side; without that,
|
||||
the asset_reference_tags.tag_name FK to tags.name would fail for
|
||||
any path containing uppercase letters.
|
||||
"""
|
||||
nested = (
|
||||
fake_dirs_multi_bucket["diffusers"]
|
||||
/ "Kolors"
|
||||
@ -170,7 +177,7 @@ class TestGetNameAndTagsFromAssetPath:
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "model.safetensors"
|
||||
assert tags == ["models", "diffusers/Kolors/text_encoder"]
|
||||
assert tags == ["models", "diffusers/kolors/text_encoder"]
|
||||
|
||||
def test_deep_lora_user_subpath_slash_joined(self, fake_dirs_multi_bucket):
|
||||
"""User-created subdirectories under a model bucket also collapse to a
|
||||
@ -187,3 +194,94 @@ class TestGetNameAndTagsFromAssetPath:
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "v0001.safetensors"
|
||||
assert tags == ["models", "loras/my/custom/path"]
|
||||
|
||||
|
||||
class TestResolveDestinationFromTags:
|
||||
"""resolve_destination_from_tags must accept both the legacy
|
||||
one-tag-per-directory shape and the new slash-joined shape so that an
|
||||
upload using the tags it just read back from /api/assets round-trips
|
||||
to the right on-disk destination.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def resolve_dirs(self):
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
diffusers_dir = root_path / "models" / "diffusers"
|
||||
loras_dir = root_path / "models" / "loras"
|
||||
for d in (input_dir, output_dir, checkpoints_dir, diffusers_dir, loras_dir):
|
||||
d.mkdir(parents=True)
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.folder_names_and_paths = {
|
||||
"checkpoints": ([str(checkpoints_dir)], None),
|
||||
"diffusers": ([str(diffusers_dir)], None),
|
||||
"loras": ([str(loras_dir)], None),
|
||||
}
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"checkpoints": checkpoints_dir,
|
||||
"diffusers": diffusers_dir,
|
||||
"loras": loras_dir,
|
||||
}
|
||||
|
||||
def test_models_flat_category(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["models", "checkpoints"])
|
||||
assert base == str(resolve_dirs["checkpoints"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_models_slash_joined_new_shape(self, resolve_dirs):
|
||||
# The shape get_name_and_tags_from_asset_path now emits.
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "diffusers/kolors/text_encoder"]
|
||||
)
|
||||
assert base == str(resolve_dirs["diffusers"])
|
||||
assert subdirs == ["kolors", "text_encoder"]
|
||||
|
||||
def test_models_legacy_one_tag_per_dir(self, resolve_dirs):
|
||||
# The legacy shape must still resolve identically.
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "diffusers", "kolors", "text_encoder"]
|
||||
)
|
||||
assert base == str(resolve_dirs["diffusers"])
|
||||
assert subdirs == ["kolors", "text_encoder"]
|
||||
|
||||
def test_models_loras_slash_joined(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "loras/my/custom/path"]
|
||||
)
|
||||
assert base == str(resolve_dirs["loras"])
|
||||
assert subdirs == ["my", "custom", "path"]
|
||||
|
||||
def test_input_no_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["input"])
|
||||
assert base == str(resolve_dirs["input"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_input_slash_joined_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["input", "portraits/2026"])
|
||||
assert base == str(resolve_dirs["input"])
|
||||
assert subdirs == ["portraits", "2026"]
|
||||
|
||||
def test_output_slash_joined_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["output", "runs/abc"])
|
||||
assert base == str(resolve_dirs["output"])
|
||||
assert subdirs == ["runs", "abc"]
|
||||
|
||||
def test_unknown_category_rejected(self, resolve_dirs):
|
||||
with pytest.raises(ValueError, match="unknown model category"):
|
||||
resolve_destination_from_tags(["models", "not_a_real_category"])
|
||||
|
||||
def test_unknown_category_via_slash_joined(self, resolve_dirs):
|
||||
# First segment of a slash-joined tag must still match a registered category.
|
||||
with pytest.raises(ValueError, match="unknown model category 'bogus'"):
|
||||
resolve_destination_from_tags(["models", "bogus/sub/path"])
|
||||
|
||||
def test_traversal_in_subdir_rejected(self, resolve_dirs):
|
||||
with pytest.raises(ValueError, match="invalid path component"):
|
||||
resolve_destination_from_tags(["models", "checkpoints/..", "evil"])
|
||||
|
||||
89
tests-unit/assets_test/test_user_tag_http_smoke.py
Normal file
89
tests-unit/assets_test/test_user_tag_http_smoke.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""HTTP-layer smoke test: user-added tags via POST /api/assets/{id}/tags
|
||||
land after path tags when read back via GET /api/assets.
|
||||
|
||||
Exercises the full route handler -> service -> query path that the unit
|
||||
tests at tests-unit/assets_test/queries/test_asset_info.py only cover at
|
||||
the service layer.
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def smoke_asset(http: requests.Session, api_base: str):
|
||||
"""Upload a single asset into models/checkpoints/unit-tests/smoke
|
||||
and delete it on teardown."""
|
||||
name = "smoke_user_tag.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "smoke"]
|
||||
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
yield body
|
||||
http.delete(
|
||||
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
|
||||
)
|
||||
|
||||
|
||||
def _fetch_asset_tags(http, api_base, ref_id):
|
||||
r = http.get(f"{api_base}/api/assets/{ref_id}", timeout=30)
|
||||
assert r.status_code == 200, r.text
|
||||
return r.json()["tags"]
|
||||
|
||||
|
||||
def test_user_tag_lands_after_path_tags_via_http(
|
||||
http: requests.Session, api_base: str, smoke_asset: dict
|
||||
):
|
||||
ref_id = smoke_asset["id"]
|
||||
|
||||
initial_tags = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# Path tags should already be at the front in upload order.
|
||||
assert initial_tags[:2] == ["models", "checkpoints"]
|
||||
|
||||
# Add a user tag that would jump to position 0 under alphabetical sort.
|
||||
r = http.post(
|
||||
f"{api_base}/api/assets/{ref_id}/tags",
|
||||
json={"tags": ["aaa-user-tag"]},
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
tags_after = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# Path tags must still be at the front; user tag goes to the end.
|
||||
assert tags_after[0] == "models"
|
||||
assert tags_after[1] == "checkpoints"
|
||||
assert "aaa-user-tag" in tags_after
|
||||
assert tags_after[-1] == "aaa-user-tag"
|
||||
|
||||
|
||||
def test_user_tag_batch_lands_after_path_tags_via_http(
|
||||
http: requests.Session, api_base: str, smoke_asset: dict
|
||||
):
|
||||
ref_id = smoke_asset["id"]
|
||||
|
||||
# Add three user tags in a single request, in non-alphabetical input
|
||||
# order. They should all land after the path tags (microsecond stagger
|
||||
# in set_reference_tags / add_tags_to_reference is what makes this
|
||||
# work — without it, "aaa" would jump to position 0).
|
||||
r = http.post(
|
||||
f"{api_base}/api/assets/{ref_id}/tags",
|
||||
json={"tags": ["zzz-z", "favorite", "aaa-experiment"]},
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
tags_after = _fetch_asset_tags(http, api_base, ref_id)
|
||||
assert tags_after[0] == "models"
|
||||
assert tags_after[1] == "checkpoints"
|
||||
user_tail = tags_after[len({"models", "checkpoints", "unit-tests", "smoke"}):]
|
||||
assert set(user_tail) >= {"zzz-z", "favorite", "aaa-experiment"}
|
||||
# Critically: alphabetical sort would put 'aaa-experiment' at position 0.
|
||||
assert tags_after.index("aaa-experiment") > tags_after.index("models")
|
||||
assert tags_after.index("aaa-experiment") > tags_after.index("checkpoints")
|
||||
Loading…
Reference in New Issue
Block a user