mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-31 03:17:23 +08:00
- Mint next_cursor on every cursor-supported sort, not only when 'after' was supplied. A first request (no 'after') previously returned next_cursor=None, leaving cursor mode unreachable from a clean start. - Over-fetch limit+1 so an exactly-full terminal page doesn't mint a spurious cursor pointing at a phantom next page. - Map crafted out-of-range microsecond cursors (OverflowError / OSError in datetime construction) to 400 INVALID_CURSOR instead of leaking 500. - Bump MAX_CURSOR_VALUE_LENGTH 256 -> 512 to match the AssetReference name column max; without this, a long-named asset minted a cursor the same server then refused on the next request. Cross-runtime byte identity with cloud is unaffected because no cloud cursor ever carries a value > 256 (cloud schema doesn't permit it). - Return None from _encode_next_cursor when the boundary row carries a NULL sort value (e.g. an Asset without size_bytes backfilled), instead of silently encoding 0 and mis-positioning the keyset. - Fix schemas_in.py comment so it matches actual handler behavior (last_access_time + 'after' raises 400, does not fall back). - Add AssetsApiError schema + 400 response to GET /api/assets in openapi.yaml so generated clients know the INVALID_CURSOR envelope. - Extend integration coverage: first-page mint, exact-multiple terminal page, cursor walks for created_at/updated_at/size sorts, datetime overflow surfaces as 400 not 500. - Add unit coverage for datetime overflow and 512-char round-trip.
229 lines
9.0 KiB
Python
229 lines
9.0 KiB
Python
"""Tests for app.assets.services.cursor.
|
|
|
|
The wire format must stay byte-identical with the cloud Go implementation
|
|
(common/pagination/cursor.go in Comfy-Org/cloud) so the frontend sees one
|
|
contract across runtimes. The byte-identity fixture below mirrors the Go
|
|
test cases — any drift here means cloud and OSS minted different cursors
|
|
for the same triple, which would break FE pagination across backends.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
import pytest
|
|
|
|
from app.assets.services.cursor import (
|
|
MAX_CURSOR_ID_LENGTH,
|
|
MAX_CURSOR_VALUE_LENGTH,
|
|
MAX_ENCODED_CURSOR_LENGTH,
|
|
CursorPayload,
|
|
InvalidCursorError,
|
|
decode_cursor,
|
|
decode_cursor_int,
|
|
decode_cursor_time,
|
|
encode_cursor,
|
|
encode_cursor_from_time,
|
|
)
|
|
|
|
|
|
ALLOWED = ("created_at", "updated_at", "name", "size")
|
|
|
|
|
|
class TestRoundTrip:
|
|
@pytest.mark.parametrize(
|
|
"sort_field, value, id",
|
|
[
|
|
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"),
|
|
("size", "1024", "asset-123"),
|
|
("name", "my-asset.png", "asset-abc"),
|
|
("name", "résumé.txt", "asset-uni"),
|
|
],
|
|
)
|
|
def test_encode_decode(self, sort_field, value, id):
|
|
encoded = encode_cursor(sort_field, value, id)
|
|
assert encoded != ""
|
|
payload = decode_cursor(encoded, ALLOWED)
|
|
assert payload.sort_field == sort_field
|
|
assert payload.value == value
|
|
assert payload.id == id
|
|
|
|
|
|
class TestTimeCursor:
|
|
def test_microsecond_precision_preserved(self):
|
|
# Pick a time with non-zero microseconds — encoding at ms would lose the µs.
|
|
ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc)
|
|
encoded = encode_cursor_from_time("created_at", ts, "id-1")
|
|
payload = decode_cursor(encoded, ALLOWED)
|
|
# Value must be a microsecond integer string, not a millisecond one.
|
|
assert payload.value == "1716209600123456"
|
|
decoded = decode_cursor_time(payload)
|
|
assert decoded == ts
|
|
|
|
def test_decode_returns_utc(self):
|
|
payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1")
|
|
decoded = decode_cursor_time(payload)
|
|
assert decoded.tzinfo == timezone.utc
|
|
|
|
def test_naive_datetime_rejected_on_encode(self):
|
|
naive = datetime(2024, 5, 20, 12, 0, 0)
|
|
with pytest.raises(ValueError):
|
|
encode_cursor_from_time("created_at", naive, "id-1")
|
|
|
|
def test_non_integer_value_rejected_on_decode(self):
|
|
with pytest.raises(InvalidCursorError):
|
|
decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1"))
|
|
|
|
def test_none_payload_rejected(self):
|
|
with pytest.raises(InvalidCursorError):
|
|
decode_cursor_time(None)
|
|
|
|
def test_non_utc_aware_normalized(self):
|
|
# Same instant, different timezone — must encode to the same micros.
|
|
utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc)
|
|
offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5)))
|
|
assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time(
|
|
"created_at", offset_ts, "x"
|
|
)
|
|
|
|
|
|
class TestIntCursor:
|
|
def test_decode_int(self):
|
|
assert decode_cursor_int(CursorPayload("size", "1024", "id-1")) == 1024
|
|
|
|
def test_decode_int_rejects_non_int(self):
|
|
with pytest.raises(InvalidCursorError):
|
|
decode_cursor_int(CursorPayload("size", "abc", "id-1"))
|
|
|
|
def test_decode_int_rejects_none(self):
|
|
with pytest.raises(InvalidCursorError):
|
|
decode_cursor_int(None)
|
|
|
|
|
|
class TestInvalidInputs:
|
|
def test_oversized_cursor(self):
|
|
oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1)
|
|
with pytest.raises(InvalidCursorError, match="maximum length"):
|
|
decode_cursor(oversized, ALLOWED)
|
|
|
|
def test_not_base64(self):
|
|
with pytest.raises(InvalidCursorError):
|
|
decode_cursor("not base64!!!", ALLOWED)
|
|
|
|
def test_not_json(self):
|
|
encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii")
|
|
with pytest.raises(InvalidCursorError):
|
|
decode_cursor(encoded, ALLOWED)
|
|
|
|
def test_empty_id(self):
|
|
encoded = encode_cursor("created_at", "1", "")
|
|
with pytest.raises(InvalidCursorError, match="missing id"):
|
|
decode_cursor(encoded, ALLOWED)
|
|
|
|
def test_oversized_id(self):
|
|
encoded = encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1))
|
|
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
|
decode_cursor(encoded, ALLOWED)
|
|
|
|
def test_oversized_value(self):
|
|
encoded = encode_cursor("created_at", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1")
|
|
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
|
decode_cursor(encoded, ALLOWED)
|
|
|
|
def test_unsupported_sort_field(self):
|
|
encoded = encode_cursor("execution_time", "1", "id-1")
|
|
with pytest.raises(InvalidCursorError, match="unsupported sort field"):
|
|
decode_cursor(encoded, ALLOWED)
|
|
|
|
def test_no_allowed_fields_rejects_everything(self):
|
|
encoded = encode_cursor("created_at", "1", "id-1")
|
|
with pytest.raises(InvalidCursorError):
|
|
decode_cursor(encoded, ())
|
|
|
|
def test_non_dict_payload_rejected(self):
|
|
encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii")
|
|
with pytest.raises(InvalidCursorError, match="expected object"):
|
|
decode_cursor(encoded, ALLOWED)
|
|
|
|
|
|
class TestEncodeAtCapsFits:
|
|
def test_max_field_lengths_fit_wire_cap(self):
|
|
# Worst-case payload: value and id at their per-field caps, with a long
|
|
# sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH
|
|
# so the wire cap cannot reject a cursor the encoder mints at the per-field caps.
|
|
value = "v" * MAX_CURSOR_VALUE_LENGTH
|
|
id = "i" * MAX_CURSOR_ID_LENGTH
|
|
sort_field = "very_long_sort_field_name"
|
|
|
|
encoded = encode_cursor(sort_field, value, id)
|
|
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
|
payload = decode_cursor(encoded, (sort_field,))
|
|
assert payload.value == value
|
|
assert payload.id == id
|
|
|
|
|
|
class TestDatetimeOverflow:
|
|
"""Crafted cursors with extreme micros must map to InvalidCursorError,
|
|
not OverflowError/OSError leaking as 500.
|
|
"""
|
|
|
|
@pytest.mark.parametrize(
|
|
"micros_str",
|
|
[
|
|
"999999999999999999999", # 10^21 µs — past datetime.MAX_YEAR by ~14 orders
|
|
"-999999999999999999999", # symmetric negative — pre-epoch overflow
|
|
],
|
|
)
|
|
def test_out_of_range_micros_rejected(self, micros_str):
|
|
encoded = encode_cursor("created_at", micros_str, "asset-x")
|
|
payload = decode_cursor(encoded, ALLOWED)
|
|
with pytest.raises(InvalidCursorError):
|
|
decode_cursor_time(payload)
|
|
|
|
|
|
class TestEncoderDecoderSymmetry:
|
|
"""The encoder must reject inputs the decoder rejects, or the same server
|
|
will mint a cursor it then 400s on the next request.
|
|
"""
|
|
|
|
def test_long_name_within_cap_round_trips(self):
|
|
"""OSS assets allow names up to 512 chars (`String(512)`); cursor must
|
|
handle that. Cloud's lower cap is acceptable on its side because the
|
|
cloud schema doesn't permit names that long."""
|
|
long_name = "n" * MAX_CURSOR_VALUE_LENGTH
|
|
encoded = encode_cursor("name", long_name, "asset-x")
|
|
payload = decode_cursor(encoded, ALLOWED)
|
|
assert payload.value == long_name
|
|
|
|
|
|
class TestByteIdentityWithCloud:
|
|
"""Lock the wire format against drift from cloud's Go implementation.
|
|
|
|
Drop these fixtures from common/pagination/cursor_test.go in cloud — they
|
|
encode to specific base64url strings, and any drift on either side breaks
|
|
cross-runtime FE pagination.
|
|
|
|
To regenerate, run cloud's test harness with these inputs and capture the
|
|
output of EncodeCursor, then paste below.
|
|
"""
|
|
|
|
@pytest.mark.parametrize(
|
|
"sort_field, value, id, expected_encoded",
|
|
[
|
|
# Generated from cloud encode_cursor: json.Marshal yields keys in
|
|
# insertion order for our struct (s, v, id), then RawURLEncoding base64.
|
|
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7",
|
|
"eyJzIjoiY3JlYXRlZF9hdCIsInYiOiIxNzE2MjAwMDAwMDAwMDAwIiwiaWQiOiJhMWIyYzNkNC1lNWY2LTdhODktYjBjMS1kMmUzZjRhNWI2YzcifQ"),
|
|
("size", "1024", "asset-123",
|
|
"eyJzIjoic2l6ZSIsInYiOiIxMDI0IiwiaWQiOiJhc3NldC0xMjMifQ"),
|
|
("name", "my-asset.png", "asset-abc",
|
|
"eyJzIjoibmFtZSIsInYiOiJteS1hc3NldC5wbmciLCJpZCI6ImFzc2V0LWFiYyJ9"),
|
|
],
|
|
)
|
|
def test_python_matches_cloud_wire_bytes(self, sort_field, value, id, expected_encoded):
|
|
actual = encode_cursor(sort_field, value, id)
|
|
assert actual == expected_encoded, (
|
|
f"Python cursor diverged from cloud Go wire format. "
|
|
f"Got: {actual!r}, expected: {expected_encoded!r}"
|
|
)
|