diff --git a/app/assets/services/cursor.py b/app/assets/services/cursor.py new file mode 100644 index 000000000..51faa05e5 --- /dev/null +++ b/app/assets/services/cursor.py @@ -0,0 +1,150 @@ +"""Opaque keyset-pagination cursor for /api/assets. + +Wire format mirrors the cloud Go implementation in +`common/pagination/cursor.go` so both runtimes produce byte-identical +cursors for the same `(sort_field, value, id)` triple and the frontend +sees one contract. + +Payload JSON uses short keys to keep the encoded length small: + + {"s": , "v": , "id": } + +Encoding is base64url with no padding. Time values are serialized as Unix +microseconds (UTC) — microsecond precision matches PostgreSQL's +`timestamp` type, so a cursor minted from a stored timestamp compares +back exactly without rounding rows in the same millisecond bucket. +""" +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Iterable, Optional + + +class InvalidCursorError(ValueError): + """Raised on a malformed, oversized, or unsupported-sort-field cursor. + + Map to a 400 response with code ``INVALID_CURSOR`` at the handler. + """ + + +# Wire-format length caps. Cursors are user-controlled, so caps protect the +# decode path from oversized allocations and downstream SQL predicates from +# unbounded strings. Same numbers as cloud/common/pagination/cursor.go. +MAX_ENCODED_CURSOR_LENGTH = 1024 +MAX_CURSOR_VALUE_LENGTH = 256 +MAX_CURSOR_ID_LENGTH = 128 + + +@dataclass(frozen=True) +class CursorPayload: + sort_field: str + value: str + id: str + + +def encode_cursor(sort_field: str, value: str, id: str) -> str: + """Encode a cursor payload as a base64url (no-padding) string.""" + payload = {"s": sort_field, "v": value, "id": id} + raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False).encode("utf-8") + return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + + +def encode_cursor_from_time(sort_field: str, t: datetime, id: str) -> str: + """Encode a time-typed cursor at Unix microsecond precision. + + Accepts an aware datetime (any timezone) and normalizes to UTC. Naive + datetimes are rejected so callers can't accidentally encode the local + wall-clock value of a UTC-stored timestamp. + """ + if t.tzinfo is None: + raise ValueError("encode_cursor_from_time requires an aware datetime") + micros = _datetime_to_unix_micros(t.astimezone(timezone.utc)) + return encode_cursor(sort_field, str(micros), id) + + +def decode_cursor(cursor: str, allowed_sort_fields: Iterable[str]) -> CursorPayload: + """Parse an opaque cursor. + + ``allowed_sort_fields`` is the endpoint's accepted sort-field list — a + cursor carrying a field outside this set is rejected so a cursor minted + for one column can't be replayed against another (e.g. a ``created_at`` + timestamp string compared against a ``name`` column). + + Passing no allowed fields rejects every cursor. + """ + if len(cursor) > MAX_ENCODED_CURSOR_LENGTH: + raise InvalidCursorError("cursor exceeds maximum length") + + try: + # urlsafe_b64decode requires correct padding; we strip on encode, so + # restore the trailing '=' pad here. + padding = "=" * (-len(cursor) % 4) + raw = base64.urlsafe_b64decode(cursor + padding) + except (ValueError, base64.binascii.Error) as e: + raise InvalidCursorError(f"encoding: {e}") from e + + try: + decoded = json.loads(raw) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise InvalidCursorError(f"payload: {e}") from e + + if not isinstance(decoded, dict): + raise InvalidCursorError("payload: expected object") + + sort_field = decoded.get("s") + value = decoded.get("v") + id = decoded.get("id") + + if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str): + raise InvalidCursorError("payload: missing or non-string s/v/id") + + if id == "": + raise InvalidCursorError("missing id") + if len(id) > MAX_CURSOR_ID_LENGTH: + raise InvalidCursorError("id exceeds maximum length") + if len(value) > MAX_CURSOR_VALUE_LENGTH: + raise InvalidCursorError("value exceeds maximum length") + + if sort_field not in allowed_sort_fields: + raise InvalidCursorError(f"unsupported sort field {sort_field!r}") + + return CursorPayload(sort_field=sort_field, value=value, id=id) + + +def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime: + """Parse a time-typed cursor value as Unix microseconds, returning UTC.""" + if payload is None: + raise InvalidCursorError("nil cursor payload") + try: + micros = int(payload.value) + except ValueError as e: + raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e + return _unix_micros_to_datetime(micros) + + +def decode_cursor_int(payload: Optional[CursorPayload]) -> int: + """Parse a cursor value as a base-10 integer.""" + if payload is None: + raise InvalidCursorError("nil cursor payload") + try: + return int(payload.value) + except ValueError as e: + raise InvalidCursorError(f"value is not a valid integer: {e}") from e + + +_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc) + + +def _datetime_to_unix_micros(t: datetime) -> int: + """Convert an aware UTC datetime to Unix microseconds (integer math).""" + delta = t - _EPOCH + return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds + + +def _unix_micros_to_datetime(micros: int) -> datetime: + """Convert Unix microseconds to a UTC datetime, preserving precision.""" + seconds, micro_remainder = divmod(micros, 1_000_000) + return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder) diff --git a/tests-unit/assets_test/services/test_cursor.py b/tests-unit/assets_test/services/test_cursor.py new file mode 100644 index 000000000..dcf41aaf7 --- /dev/null +++ b/tests-unit/assets_test/services/test_cursor.py @@ -0,0 +1,194 @@ +"""Tests for app.assets.services.cursor. + +The wire format must stay byte-identical with the cloud Go implementation +(common/pagination/cursor.go in Comfy-Org/cloud) so the frontend sees one +contract across runtimes. The byte-identity fixture below mirrors the Go +test cases — any drift here means cloud and OSS minted different cursors +for the same triple, which would break FE pagination across backends. +""" +from __future__ import annotations + +import base64 +from datetime import datetime, timedelta, timezone + +import pytest + +from app.assets.services.cursor import ( + MAX_CURSOR_ID_LENGTH, + MAX_CURSOR_VALUE_LENGTH, + MAX_ENCODED_CURSOR_LENGTH, + CursorPayload, + InvalidCursorError, + decode_cursor, + decode_cursor_int, + decode_cursor_time, + encode_cursor, + encode_cursor_from_time, +) + + +ALLOWED = ("created_at", "updated_at", "name", "size") + + +class TestRoundTrip: + @pytest.mark.parametrize( + "sort_field, value, id", + [ + ("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"), + ("size", "1024", "asset-123"), + ("name", "my-asset.png", "asset-abc"), + ("name", "résumé.txt", "asset-uni"), + ], + ) + def test_encode_decode(self, sort_field, value, id): + encoded = encode_cursor(sort_field, value, id) + assert encoded != "" + payload = decode_cursor(encoded, ALLOWED) + assert payload.sort_field == sort_field + assert payload.value == value + assert payload.id == id + + +class TestTimeCursor: + def test_microsecond_precision_preserved(self): + # Pick a time with non-zero microseconds — encoding at ms would lose the µs. + ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc) + encoded = encode_cursor_from_time("created_at", ts, "id-1") + payload = decode_cursor(encoded, ALLOWED) + # Value must be a microsecond integer string, not a millisecond one. + assert payload.value == "1716209600123456" + decoded = decode_cursor_time(payload) + assert decoded == ts + + def test_decode_returns_utc(self): + payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1") + decoded = decode_cursor_time(payload) + assert decoded.tzinfo == timezone.utc + + def test_naive_datetime_rejected_on_encode(self): + naive = datetime(2024, 5, 20, 12, 0, 0) + with pytest.raises(ValueError): + encode_cursor_from_time("created_at", naive, "id-1") + + def test_non_integer_value_rejected_on_decode(self): + with pytest.raises(InvalidCursorError): + decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1")) + + def test_none_payload_rejected(self): + with pytest.raises(InvalidCursorError): + decode_cursor_time(None) + + def test_non_utc_aware_normalized(self): + # Same instant, different timezone — must encode to the same micros. + utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc) + offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5))) + assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time( + "created_at", offset_ts, "x" + ) + + +class TestIntCursor: + def test_decode_int(self): + assert decode_cursor_int(CursorPayload("size", "1024", "id-1")) == 1024 + + def test_decode_int_rejects_non_int(self): + with pytest.raises(InvalidCursorError): + decode_cursor_int(CursorPayload("size", "abc", "id-1")) + + def test_decode_int_rejects_none(self): + with pytest.raises(InvalidCursorError): + decode_cursor_int(None) + + +class TestInvalidInputs: + def test_oversized_cursor(self): + oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1) + with pytest.raises(InvalidCursorError, match="maximum length"): + decode_cursor(oversized, ALLOWED) + + def test_not_base64(self): + with pytest.raises(InvalidCursorError): + decode_cursor("not base64!!!", ALLOWED) + + def test_not_json(self): + encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError): + decode_cursor(encoded, ALLOWED) + + def test_empty_id(self): + encoded = encode_cursor("created_at", "1", "") + with pytest.raises(InvalidCursorError, match="missing id"): + decode_cursor(encoded, ALLOWED) + + def test_oversized_id(self): + encoded = encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1)) + with pytest.raises(InvalidCursorError, match="id exceeds maximum length"): + decode_cursor(encoded, ALLOWED) + + def test_oversized_value(self): + encoded = encode_cursor("created_at", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1") + with pytest.raises(InvalidCursorError, match="value exceeds maximum length"): + decode_cursor(encoded, ALLOWED) + + def test_unsupported_sort_field(self): + encoded = encode_cursor("execution_time", "1", "id-1") + with pytest.raises(InvalidCursorError, match="unsupported sort field"): + decode_cursor(encoded, ALLOWED) + + def test_no_allowed_fields_rejects_everything(self): + encoded = encode_cursor("created_at", "1", "id-1") + with pytest.raises(InvalidCursorError): + decode_cursor(encoded, ()) + + def test_non_dict_payload_rejected(self): + encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="expected object"): + decode_cursor(encoded, ALLOWED) + + +class TestEncodeAtCapsFits: + def test_max_field_lengths_fit_wire_cap(self): + # Worst-case payload: value and id at their per-field caps, with a long + # sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH + # so the wire cap cannot reject a cursor the encoder mints at the per-field caps. + value = "v" * MAX_CURSOR_VALUE_LENGTH + id = "i" * MAX_CURSOR_ID_LENGTH + sort_field = "very_long_sort_field_name" + + encoded = encode_cursor(sort_field, value, id) + assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH + payload = decode_cursor(encoded, (sort_field,)) + assert payload.value == value + assert payload.id == id + + +class TestByteIdentityWithCloud: + """Lock the wire format against drift from cloud's Go implementation. + + Drop these fixtures from common/pagination/cursor_test.go in cloud — they + encode to specific base64url strings, and any drift on either side breaks + cross-runtime FE pagination. + + To regenerate, run cloud's test harness with these inputs and capture the + output of EncodeCursor, then paste below. + """ + + @pytest.mark.parametrize( + "sort_field, value, id, expected_encoded", + [ + # Generated from cloud encode_cursor: json.Marshal yields keys in + # insertion order for our struct (s, v, id), then RawURLEncoding base64. + ("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7", + "eyJzIjoiY3JlYXRlZF9hdCIsInYiOiIxNzE2MjAwMDAwMDAwMDAwIiwiaWQiOiJhMWIyYzNkNC1lNWY2LTdhODktYjBjMS1kMmUzZjRhNWI2YzcifQ"), + ("size", "1024", "asset-123", + "eyJzIjoic2l6ZSIsInYiOiIxMDI0IiwiaWQiOiJhc3NldC0xMjMifQ"), + ("name", "my-asset.png", "asset-abc", + "eyJzIjoibmFtZSIsInYiOiJteS1hc3NldC5wbmciLCJpZCI6ImFzc2V0LWFiYyJ9"), + ], + ) + def test_python_matches_cloud_wire_bytes(self, sort_field, value, id, expected_encoded): + actual = encode_cursor(sort_field, value, id) + assert actual == expected_encoded, ( + f"Python cursor diverged from cloud Go wire format. " + f"Got: {actual!r}, expected: {expected_encoded!r}" + )