From accd0a414d599d84ad6f71af279a22bf6ab2150a Mon Sep 17 00:00:00 2001 From: jarrel Date: Thu, 17 Aug 2023 09:57:31 -0700 Subject: [PATCH] Refactor cursors --- publicapi/collection.go | 17 +- publicapi/feed.go | 202 +++----- publicapi/interaction.go | 8 +- publicapi/pagination.go | 859 +++++++++++++++++++---------------- publicapi/pagination_test.go | 67 ++- publicapi/publicapi.go | 10 +- publicapi/user.go | 22 +- 7 files changed, 607 insertions(+), 578 deletions(-) diff --git a/publicapi/collection.go b/publicapi/collection.go index d0eb1d193..b512579fa 100644 --- a/publicapi/collection.go +++ b/publicapi/collection.go @@ -117,15 +117,16 @@ func (api CollectionAPI) GetTopCollectionsForCommunity(ctx context.Context, chai } var collectionIDs []persist.DBID + var cursor positionCursor var paginator positionPaginator // If a cursor is provided, we can skip querying the cache if before != nil { - if _, collectionIDs, err = paginator.decodeCursor(*before); err != nil { + if err = cursor.Unpack(*before); err != nil { return nil, pageInfo, err } } else if after != nil { - if _, collectionIDs, err = paginator.decodeCursor(*after); err != nil { + if err = cursor.Unpack(*after); err != nil { return nil, pageInfo, err } } else { @@ -143,9 +144,7 @@ func (api CollectionAPI) GetTopCollectionsForCommunity(ctx context.Context, chai } paginator.QueryFunc = func(params positionPagingParams) ([]any, error) { - cIDs, _ := util.Map(collectionIDs, func(id persist.DBID) (string, error) { - return id.String(), nil - }) + cIDs := util.MapWithoutError(collectionIDs, func(id persist.DBID) string { return id.String() }) c, err := api.queries.GetVisibleCollectionsByIDsPaginate(ctx, db.GetVisibleCollectionsByIDsPaginateParams{ CollectionIds: cIDs, CurBeforePos: params.CursorBeforePos, @@ -153,9 +152,7 @@ func (api CollectionAPI) GetTopCollectionsForCommunity(ctx context.Context, chai PagingForward: params.PagingForward, Limit: params.Limit, }) - a, _ := util.Map(c, func(c db.Collection) (any, error) { - return c, nil - }) + a := util.MapWithoutError(c, func(c db.Collection) any { return c }) return a, err } @@ -164,8 +161,8 @@ func (api CollectionAPI) GetTopCollectionsForCommunity(ctx context.Context, chai posLookup[id] = i + 1 // Postgres uses 1-based indexing } - paginator.CursorFunc = func(node any) (int, []persist.DBID, error) { - return posLookup[node.(db.Collection).ID], collectionIDs, nil + paginator.CursorFunc = func(node any) (int64, []persist.DBID, error) { + return int64(posLookup[node.(db.Collection).ID]), collectionIDs, nil } // The collections are sorted by ascending rank so we need to switch the cursor positions diff --git a/publicapi/feed.go b/publicapi/feed.go index 1ddab995e..30f0d3549 100644 --- a/publicapi/feed.go +++ b/publicapi/feed.go @@ -352,14 +352,11 @@ func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *stri var ( err error + cursor feedPositionCursor paginator feedPaginator - entityTypes []persist.FeedEntityType - entityIDs []persist.DBID entityIDToPos = make(map[persist.DBID]int) ) - hasCursors := before != nil || after != nil - now := time.Now() // Include posts for admins always during the soft launch @@ -367,7 +364,15 @@ func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *stri includePosts = shouldShowPosts(ctx) } - if !hasCursors { + if before != nil { + if err = cursor.Unpack(*before); err != nil { + return nil, PageInfo{}, err + } + } else if after != nil { + if err = cursor.Unpack(*after); err != nil { + return nil, PageInfo{}, err + } + } else { calcFunc := func(ctx context.Context) ([]persist.FeedEntityType, []persist.DBID, error) { trendData, err := fetchFeedEntities(ctx, api.queries, feedParams{ IncludePosts: includePosts, @@ -382,8 +387,8 @@ func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *stri return timeFactor(e.CreatedAt, now) * engagementFactor(int(e.Interactions)) }) - entityTypes = make([]persist.FeedEntityType, len(scored)) - entityIDs = make([]persist.DBID, len(scored)) + entityTypes := make([]persist.FeedEntityType, len(scored)) + entityIDs := make([]persist.DBID, len(scored)) for i, e := range scored { idx := len(scored) - i - 1 @@ -396,45 +401,41 @@ func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *stri l := newFeedCache(api.cache, includePosts, calcFunc) - entityTypes, entityIDs, err = l.Load(ctx) + cursor.EntityTypes, cursor.EntityIDs, err = l.Load(ctx) if err != nil { return nil, PageInfo{}, err } } queryFunc := func(params feedPagingParams) ([]any, error) { - if hasCursors { - entityTypes = params.EntityTypes - entityIDs = params.EntityIDs - } - for i, id := range entityIDs { + for i, id := range cursor.EntityIDs { entityIDToPos[id] = i } // Filter slices in place if !includePosts { idx := 0 - for i := range entityTypes { - if entityTypes[i] != persist.PostTypeTag { - entityTypes[idx] = entityTypes[i] - entityIDs[idx] = entityIDs[i] + for i := range cursor.EntityTypes { + if cursor.EntityTypes[i] != persist.PostTypeTag { + cursor.EntityTypes[idx] = cursor.EntityTypes[i] + cursor.EntityIDs[idx] = cursor.EntityIDs[i] idx++ } } - entityTypes = entityTypes[:idx] - entityIDs = entityIDs[:idx] + cursor.EntityTypes = cursor.EntityTypes[:idx] + cursor.EntityIDs = cursor.EntityIDs[:idx] } - return loadFeedEntities(ctx, api.loaders, entityTypes, entityIDs) + return loadFeedEntities(ctx, api.loaders, cursor.EntityTypes, cursor.EntityIDs) } - cursorFunc := func(node any) (int, []persist.FeedEntityType, []persist.DBID, error) { + cursorFunc := func(node any) (int64, []persist.FeedEntityType, []persist.DBID, error) { _, id, err := feedCursor(node) pos, ok := entityIDToPos[id] if !ok { panic(fmt.Sprintf("could not find position for id=%s", id)) } - return pos, entityTypes, entityIDs, err + return int64(pos), cursor.EntityTypes, cursor.EntityIDs, err } paginator.QueryFunc = queryFunc @@ -457,13 +458,10 @@ func (api FeedAPI) CuratedFeed(ctx context.Context, before, after *string, first var ( paginator feedPaginator - entityTypes []persist.FeedEntityType - entityIDs []persist.DBID + cursor feedPositionCursor entityIDToPos = make(map[persist.DBID]int) ) - hasCursors := before != nil || after != nil - now := time.Now() // Include posts for admins always during the soft launch @@ -471,13 +469,22 @@ func (api FeedAPI) CuratedFeed(ctx context.Context, before, after *string, first includePosts = shouldShowPosts(ctx) } - if !hasCursors { - trendData, err := fetchFeedEntities(ctx, api.queries, feedParams{ - IncludePosts: includePosts, - IncludeEvents: !includePosts, - ExcludeUserID: userID, - ExcludeActions: []persist.Action{persist.ActionUserCreated, persist.ActionUserFollowedUsers}, - FetchFrom: time.Duration(7 * 24 * time.Hour), + if before != nil { + if err := cursor.Unpack(*before); err != nil { + return nil, PageInfo{}, err + } + } else if after != nil { + if err := cursor.Unpack(*after); err != nil { + return nil, PageInfo{}, err + } + } else { + trendData, err := api.queries.GetFeedEntityScores(ctx, db.GetFeedEntityScoresParams{ + WindowEnd: time.Now().Add(-time.Duration(7 * 24 * time.Hour)), + PostEntityType: int32(persist.PostTypeTag), + ExcludedFeedActions: []string{string(persist.ActionUserCreated), string(persist.ActionUserFollowedUsers)}, + IncludeViewer: false, + ViewerID: userID, + IncludePosts: includePosts, }) if err != nil { return nil, PageInfo{}, err @@ -538,49 +545,45 @@ func (api FeedAPI) CuratedFeed(ctx context.Context, before, after *string, first recommend.Shuffle(interleaved, 8) - entityTypes = make([]persist.FeedEntityType, len(interleaved)) - entityIDs = make([]persist.DBID, len(interleaved)) + cursor.EntityTypes = make([]persist.FeedEntityType, len(interleaved)) + cursor.EntityIDs = make([]persist.DBID, len(interleaved)) for i, e := range interleaved { idx := len(interleaved) - i - 1 - entityTypes[idx] = persist.FeedEntityType(e.FeedEntityType) - entityIDs[idx] = e.ID + cursor.EntityTypes[idx] = persist.FeedEntityType(e.FeedEntityType) + cursor.EntityIDs[idx] = e.ID } } queryFunc := func(params feedPagingParams) ([]any, error) { - if hasCursors { - entityTypes = params.EntityTypes - entityIDs = params.EntityIDs - } - for i, id := range entityIDs { + for i, id := range cursor.EntityIDs { entityIDToPos[id] = i } // Filter slices in place if !includePosts { idx := 0 - for i := range entityTypes { - if entityTypes[i] != persist.PostTypeTag { - entityTypes[idx] = entityTypes[i] - entityIDs[idx] = entityIDs[i] + for i := range cursor.EntityTypes { + if cursor.EntityTypes[i] != persist.PostTypeTag { + cursor.EntityTypes[idx] = cursor.EntityTypes[i] + cursor.EntityIDs[idx] = cursor.EntityIDs[i] idx++ } } - entityTypes = entityTypes[:idx] - entityIDs = entityIDs[:idx] + cursor.EntityTypes = cursor.EntityTypes[:idx] + cursor.EntityIDs = cursor.EntityIDs[:idx] } - return loadFeedEntities(ctx, api.loaders, entityTypes, entityIDs) + return loadFeedEntities(ctx, api.loaders, cursor.EntityTypes, cursor.EntityIDs) } - cursorFunc := func(node any) (int, []persist.FeedEntityType, []persist.DBID, error) { + cursorFunc := func(node any) (int64, []persist.FeedEntityType, []persist.DBID, error) { _, id, err := feedCursor(node) pos, ok := entityIDToPos[id] if !ok { panic(fmt.Sprintf("could not find position for id=%s", id)) } - return pos, entityTypes, entityIDs, err + return int64(pos), cursor.EntityTypes, cursor.EntityIDs, err } paginator.QueryFunc = queryFunc @@ -822,58 +825,7 @@ type feedPagingParams struct { type feedPaginator struct { QueryFunc func(params feedPagingParams) ([]any, error) - CursorFunc func(node any) (pos int, feedEntityType []persist.FeedEntityType, ids []persist.DBID, err error) -} - -func (p *feedPaginator) encodeCursor(pos int, typ []persist.FeedEntityType, ids []persist.DBID) (string, error) { - if len(typ) != len(ids) { - panic("type and ids must be the same length") - } - encoder := newCursorEncoder() - encoder.appendInt64(int64(pos)) - encoder.appendInt64(int64(len(ids))) - for i := range typ { - encoder.appendInt64(int64(typ[i])) - encoder.appendDBID(ids[i]) - } - return encoder.AsBase64(), nil -} - -func (p *feedPaginator) decodeCursor(cursor string) (pos int, typs []persist.FeedEntityType, ids []persist.DBID, err error) { - decoder, err := newCursorDecoder(cursor) - if err != nil { - return 0, nil, nil, err - } - - curPos, err := decoder.readInt64() - if err != nil { - return 0, nil, nil, err - } - - totalItems, err := decoder.readInt64() - if err != nil { - return 0, nil, nil, err - } - - typs = make([]persist.FeedEntityType, totalItems) - ids = make([]persist.DBID, totalItems) - - for i := 0; i < int(totalItems); i++ { - typ, err := decoder.readInt64() - if err != nil { - return 0, nil, nil, err - } - - id, err := decoder.readDBID() - if err != nil { - return 0, nil, nil, err - } - - typs[i] = persist.FeedEntityType(typ) - ids[i] = id - } - - return int(curPos), typs, ids, nil + CursorFunc func(node any) (pos int64, feedEntityType []persist.FeedEntityType, ids []persist.DBID, err error) } func (p *feedPaginator) paginate(before, after *string, first, last *int) ([]any, PageInfo, error) { @@ -881,25 +833,25 @@ func (p *feedPaginator) paginate(before, after *string, first, last *int) ([]any CurBeforePos: defaultCursorBeforePosition, CurAfterPos: defaultCursorAfterPosition, } + beforeCur := feedPositionCursor{} + afterCur := feedPositionCursor{} if before != nil { - curBeforePos, typs, ids, err := p.decodeCursor(*before) - if err != nil { + if err := beforeCur.Unpack(*before); err != nil { return nil, PageInfo{}, err } - args.CurBeforePos = curBeforePos - args.EntityTypes = typs - args.EntityIDs = ids + args.CurBeforePos = int(beforeCur.CurrentPosition) + args.EntityTypes = beforeCur.EntityTypes + args.EntityIDs = beforeCur.EntityIDs } if after != nil { - curAfterPos, typs, ids, err := p.decodeCursor(*after) - if err != nil { + if err := afterCur.Unpack(*after); err != nil { return nil, PageInfo{}, err } - args.CurAfterPos = curAfterPos - args.EntityTypes = typs - args.EntityIDs = ids + args.CurAfterPos = int(afterCur.CurrentPosition) + args.EntityTypes = afterCur.EntityTypes + args.EntityIDs = afterCur.EntityIDs } results, err := p.QueryFunc(args) @@ -907,15 +859,7 @@ func (p *feedPaginator) paginate(before, after *string, first, last *int) ([]any return nil, PageInfo{}, err } - cursorFunc := func(node any) (string, error) { - pos, typs, ids, err := p.CursorFunc(node) - if err != nil { - return "", err - } - return p.encodeCursor(pos, typs, ids) - } - - return pageFrom(results, nil, cursorFunc, before, after, first, last) + return pageFrom(results, nil, cursors.NewFeedPositionCursorer(p.CursorFunc), before, after, first, last) } type feedCache struct { @@ -938,8 +882,12 @@ func newFeedCache(cache *redis.Cache, includePosts bool, f func(context.Context) if err != nil { return nil, err } - var p feedPaginator - b, err := p.encodeCursor(0, types, ids) + cur := feedPositionCursor{ + CurrentPosition: 0, + EntityTypes: types, + EntityIDs: ids, + } + b, err := cur.Pack() return []byte(b), err }, }, @@ -951,9 +899,9 @@ func (f feedCache) Load(ctx context.Context) ([]persist.FeedEntityType, []persis if err != nil { return nil, nil, err } - var p feedPaginator - _, types, ids, err := p.decodeCursor(string(b)) - return types, ids, err + var cur feedPositionCursor + err = cur.Unpack(string(b)) + return cur.EntityTypes, cur.EntityIDs, err } func min(a, b int) int { diff --git a/publicapi/interaction.go b/publicapi/interaction.go index 2a6483b73..5c337b2c7 100644 --- a/publicapi/interaction.go +++ b/publicapi/interaction.go @@ -169,9 +169,9 @@ func (api InteractionAPI) PaginateInteractionsByFeedEventID(ctx context.Context, return total, err } - cursorFunc := func(i interface{}) (int32, time.Time, persist.DBID, error) { + cursorFunc := func(i interface{}) (int64, time.Time, persist.DBID, error) { if row, ok := i.(db.PaginateInteractionsByFeedEventIDBatchRow); ok { - return row.Tag, row.CreatedAt, row.ID, nil + return int64(row.Tag), row.CreatedAt, row.ID, nil } return 0, time.Time{}, "", fmt.Errorf("interface{} is not the correct type") } @@ -263,9 +263,9 @@ func (api InteractionAPI) PaginateInteractionsByPostID(ctx context.Context, post return total, err } - cursorFunc := func(i interface{}) (int32, time.Time, persist.DBID, error) { + cursorFunc := func(i interface{}) (int64, time.Time, persist.DBID, error) { if row, ok := i.(db.PaginateInteractionsByPostIDBatchRow); ok { - return row.Tag, row.CreatedAt, row.ID, nil + return int64(row.Tag), row.CreatedAt, row.ID, nil } return 0, time.Time{}, "", fmt.Errorf("interface{} is not the correct type") } diff --git a/publicapi/pagination.go b/publicapi/pagination.go index 228e90601..348a21de9 100644 --- a/publicapi/pagination.go +++ b/publicapi/pagination.go @@ -61,32 +61,40 @@ func validatePaginationParams(validator *validator.Validate, first *int, last *i return nil } -func pageFrom[T any](allEdges []T, countF func() (int, error), cursorF func(T) (string, error), before, after *string, first, last *int) ([]T, PageInfo, error) { - cursorEdges, err := applyCursors(allEdges, cursorF, before, after) +func pageFrom[T any](allEdges []T, countF func() (int, error), cF cursorable, before, after *string, first, last *int) ([]T, PageInfo, error) { + cursorEdges, err := applyCursors(allEdges, cF, before, after) if err != nil { return nil, PageInfo{}, err } - edgesPaged, err := pageEdgesFrom(cursorEdges, cursorF, before, after, first, last) + edgesPaged, err := pageEdgesFrom(cursorEdges, before, after, first, last) if err != nil { return nil, PageInfo{}, err } - pageInfo, err := pageInfoFrom(cursorEdges, edgesPaged, countF, cursorF, before, after, first, last) + pageInfo, err := pageInfoFrom(cursorEdges, edgesPaged, countF, cF, before, after, first, last) return edgesPaged, pageInfo, err } -func pageInfoFrom[T any](cursorEdges, edgesPaged []T, countF func() (int, error), cursorF func(T) (string, error), before, after *string, first, last *int) (pageInfo PageInfo, err error) { +func packNode(cF cursorable, node any) (string, error) { + cursor, err := cF(node) + if err != nil { + return "", err + } + return cursor.Pack() +} + +func pageInfoFrom[T any](cursorEdges, edgesPaged []T, countF func() (int, error), cF cursorable, before, after *string, first, last *int) (pageInfo PageInfo, err error) { if len(edgesPaged) > 0 { firstNode := edgesPaged[0] lastNode := edgesPaged[len(edgesPaged)-1] - pageInfo.StartCursor, err = cursorF(firstNode) + pageInfo.StartCursor, err = packNode(cF, firstNode) if err != nil { return PageInfo{}, err } - pageInfo.EndCursor, err = cursorF(lastNode) + pageInfo.EndCursor, err = packNode(cF, lastNode) if err != nil { return PageInfo{}, err } @@ -113,7 +121,7 @@ func pageInfoFrom[T any](cursorEdges, edgesPaged []T, countF func() (int, error) return pageInfo, nil } -func pageEdgesFrom[T any](edges []T, cursorF func(T) (string, error), before, after *string, first, last *int) ([]T, error) { +func pageEdgesFrom[T any](edges []T, before, after *string, first, last *int) ([]T, error) { if first != nil && len(edges) > *first { return edges[:*first], nil } @@ -125,12 +133,12 @@ func pageEdgesFrom[T any](edges []T, cursorF func(T) (string, error), before, af return edges, nil } -func applyCursors[T any](allEdges []T, cursorF func(T) (string, error), before, after *string) ([]T, error) { +func applyCursors[T any](allEdges []T, cursorable func(any) (cursorer, error), before, after *string) ([]T, error) { edges := append([]T{}, allEdges...) if after != nil { for i, edge := range edges { - cur, err := cursorF(edge) + cur, err := packNode(cursorable, edge) if err != nil { return nil, err } @@ -143,7 +151,7 @@ func applyCursors[T any](allEdges []T, cursorF func(T) (string, error), before, if before != nil { for i, edge := range edges { - cur, err := cursorF(edge) + cur, err := packNode(cursorable, edge) if err != nil { return nil, err } @@ -164,8 +172,8 @@ type keysetPaginator struct { // QueryFunc returns paginated results for the given paging parameters QueryFunc func(limit int32, pagingForward bool) (nodes []interface{}, err error) - // CursorFunc returns a cursor string for the given node value - CursorFunc func(node interface{}) (cursor string, err error) + // Cursorable produces a cursorer for encoding nodes to cursor strings + Cursorable cursorable // CountFunc returns the total number of items that can be paginated. May be nil, in which // case the resulting PageInfo will omit the total field. @@ -198,7 +206,7 @@ func (p *keysetPaginator) paginate(before *string, after *string, first *int, la } } - return pageFrom(results, p.CountFunc, p.CursorFunc, before, after, first, last) + return pageFrom(results, p.CountFunc, p.Cursorable, before, after, first, last) } // timeIDPaginator paginates results using a cursor with a time.Time and a persist.DBID. @@ -227,80 +235,38 @@ type timeIDPagingParams struct { PagingForward bool } -func (p *timeIDPaginator) encodeCursor(t time.Time, id persist.DBID) (string, error) { - encoder := newCursorEncoder() - if err := encoder.appendTime(t); err != nil { - return "", err - } - encoder.appendDBID(id) - return encoder.AsBase64(), nil -} - -func (p *timeIDPaginator) decodeCursor(cursor string) (time.Time, persist.DBID, error) { - decoder, err := newCursorDecoder(cursor) - if err != nil { - return time.Time{}, "", err - } - - t, err := decoder.readTime() - if err != nil { - return time.Time{}, "", err - } - - id, err := decoder.readDBID() - if err != nil { - return time.Time{}, "", err - } - - return t, id, nil -} - -func (p *timeIDPaginator) paginate(before *string, after *string, first *int, last *int) ([]interface{}, PageInfo, error) { - queryFunc := func(limit int32, pagingForward bool) ([]interface{}, error) { - curBeforeTime := defaultCursorBeforeTime - curBeforeID := persist.DBID("") - curAfterTime := defaultCursorAfterTime - curAfterID := persist.DBID("") +func (p *timeIDPaginator) paginate(before *string, after *string, first *int, last *int) ([]any, PageInfo, error) { + queryFunc := func(limit int32, pagingForward bool) ([]any, error) { + beforeCur := timeIDCursor{Time: defaultCursorBeforeTime, ID: defaultCursorBeforeID} + afterCur := timeIDCursor{Time: defaultCursorAfterTime, ID: defaultCursorAfterID} - var err error if before != nil { - curBeforeTime, curBeforeID, err = p.decodeCursor(*before) - if err != nil { + if err := beforeCur.Unpack(*before); err != nil { return nil, err } } if after != nil { - curAfterTime, curAfterID, err = p.decodeCursor(*after) - if err != nil { + if err := afterCur.Unpack(*after); err != nil { return nil, err } } queryParams := timeIDPagingParams{ Limit: limit, - CursorBeforeTime: curBeforeTime, - CursorBeforeID: curBeforeID, - CursorAfterTime: curAfterTime, - CursorAfterID: curAfterID, + CursorBeforeTime: beforeCur.Time, + CursorBeforeID: beforeCur.ID, + CursorAfterTime: afterCur.Time, + CursorAfterID: afterCur.ID, PagingForward: pagingForward, } return p.QueryFunc(queryParams) } - cursorFunc := func(node interface{}) (string, error) { - nodeTime, nodeID, err := p.CursorFunc(node) - if err != nil { - return "", err - } - - return p.encodeCursor(nodeTime, nodeID) - } - paginator := keysetPaginator{ QueryFunc: queryFunc, - CursorFunc: cursorFunc, + Cursorable: cursors.NewTimeIDCursorer(p.CursorFunc), CountFunc: p.CountFunc, } @@ -313,52 +279,36 @@ func (p *sharedFollowersPaginator) paginate(before *string, after *string, first queryFunc := func(limit int32, pagingForward bool) ([]interface{}, error) { // The shared followers query orders results in descending order when // paging forward (vs. ascending order which is more typical). - curBeforeTime := time.Date(1970, 1, 1, 1, 1, 1, 1, time.UTC) - curBeforeID := persist.DBID("") - curAfterTime := time.Date(3000, 1, 1, 1, 1, 1, 1, time.UTC) - curAfterID := persist.DBID("") + beforeCur := timeIDCursor{Time: time.Date(1970, 1, 1, 1, 1, 1, 1, time.UTC), ID: defaultCursorBeforeID} + afterCur := timeIDCursor{Time: time.Date(3000, 1, 1, 1, 1, 1, 1, time.UTC), ID: defaultCursorAfterID} - var err error if before != nil { - curBeforeTime, curBeforeID, err = p.decodeCursor(*before) - if err != nil { + if err := beforeCur.Unpack(*before); err != nil { return nil, err } } if after != nil { - loc, _ := time.LoadLocation("UTC") - curAfterTime, curAfterID, err = p.decodeCursor(*after) - curAfterTime = curAfterTime.In(loc) - if err != nil { + if err := afterCur.Unpack(*after); err != nil { return nil, err } } queryParams := timeIDPagingParams{ Limit: limit, - CursorBeforeTime: curBeforeTime, - CursorBeforeID: curBeforeID, - CursorAfterTime: curAfterTime, - CursorAfterID: curAfterID, + CursorBeforeTime: beforeCur.Time, + CursorBeforeID: beforeCur.ID, + CursorAfterTime: afterCur.Time, + CursorAfterID: afterCur.ID, PagingForward: pagingForward, } return p.QueryFunc(queryParams) } - cursorFunc := func(node interface{}) (string, error) { - nodeTime, nodeID, err := p.CursorFunc(node) - if err != nil { - return "", err - } - - return p.encodeCursor(nodeTime, nodeID) - } - paginator := keysetPaginator{ QueryFunc: queryFunc, - CursorFunc: cursorFunc, + Cursorable: cursors.NewTimeIDCursorer(p.CursorFunc), CountFunc: p.CountFunc, } @@ -387,105 +337,49 @@ type sharedContractsPaginator struct { // * A bool indicating that userB displays the contract on their gallery // * An int indicating how many tokens userA owns for a contract // * A DBID indicating the ID of the contract - CursorFunc func(node interface{}) (bool, bool, int, persist.DBID, error) + CursorFunc func(node interface{}) (bool, bool, int64, persist.DBID, error) // CountFunc returns the total number of items that can be paginated. May be nil, in which // case the resulting PageInfo will omit the total field. CountFunc func() (count int, err error) } -func (p *sharedContractsPaginator) encodeCursor(displayedA, displayedB bool, i int, contractID persist.DBID) (string, error) { - encoder := newCursorEncoder() - encoder.appendBool(displayedA) - encoder.appendBool(displayedB) - encoder.appendInt64(int64(i)) - encoder.appendDBID(contractID) - return encoder.AsBase64(), nil -} - -func (p *sharedContractsPaginator) decodeCursor(cursor string) (bool, bool, int, persist.DBID, error) { - decoder, err := newCursorDecoder(cursor) - if err != nil { - return false, false, 0, "", err - } - - displayedA, err := decoder.readBool() - if err != nil { - return false, false, 0, "", err - } - - displayedB, err := decoder.readBool() - if err != nil { - return false, false, 0, "", err - } - - ownedCount, err := decoder.readInt64() - if err != nil { - return false, false, 0, "", err - } - - contractID, err := decoder.readDBID() - if err != nil { - return false, false, 0, "", err - } - - return displayedA, displayedB, int(ownedCount), contractID, nil -} - func (p *sharedContractsPaginator) paginate(before *string, after *string, first *int, last *int) ([]interface{}, PageInfo, error) { queryFunc := func(limit int32, pagingForward bool) ([]interface{}, error) { - cursorBeforeDisplayedByUserA := false - cursorBeforeDisplayedByUserB := false - cursorBeforeOwnedCount := -1 - cursorBeforeContractID := defaultCursorBeforeID - cursorAfterDisplayedByUserA := true - cursorAfterDisplayedByUserB := true - cursorAfterOwnedCount := math.MaxInt32 - cursorAfterContractID := defaultCursorAfterID - - var err error + beforeCur := boolBootIntIDCursor{Bool1: false, Bool2: false, Int: -1, ID: defaultCursorBeforeID} + afterCur := boolBootIntIDCursor{Bool1: true, Bool2: true, Int: math.MaxInt32, ID: defaultCursorAfterID} + if before != nil { - cursorBeforeDisplayedByUserA, cursorBeforeDisplayedByUserB, cursorBeforeOwnedCount, cursorBeforeContractID, err = p.decodeCursor(*before) - if err != nil { + if err := beforeCur.Unpack(*before); err != nil { return nil, err } } if after != nil { - cursorAfterDisplayedByUserA, cursorAfterDisplayedByUserB, cursorAfterOwnedCount, cursorAfterContractID, err = p.decodeCursor(*after) - if err != nil { + if err := afterCur.Unpack(*after); err != nil { return nil, err } } queryParams := sharedContractsPaginatorParams{ Limit: limit, - CursorBeforeDisplayedByUserA: cursorBeforeDisplayedByUserA, - CursorBeforeDisplayedByUserB: cursorBeforeDisplayedByUserB, - CursorBeforeOwnedCount: cursorBeforeOwnedCount, - CursorBeforeContractID: cursorBeforeContractID, - CursorAfterDisplayedByUserA: cursorAfterDisplayedByUserA, - CursorAfterDisplayedByUserB: cursorAfterDisplayedByUserB, - CursorAfterOwnedCount: cursorAfterOwnedCount, - CursorAfterContractID: cursorAfterContractID, + CursorBeforeDisplayedByUserA: beforeCur.Bool1, + CursorBeforeDisplayedByUserB: beforeCur.Bool2, + CursorBeforeOwnedCount: int(beforeCur.Int), + CursorBeforeContractID: beforeCur.ID, + CursorAfterDisplayedByUserA: afterCur.Bool1, + CursorAfterDisplayedByUserB: afterCur.Bool2, + CursorAfterOwnedCount: int(afterCur.Int), + CursorAfterContractID: afterCur.ID, PagingForward: pagingForward, } return p.QueryFunc(queryParams) } - cursorFunc := func(node interface{}) (string, error) { - displayedUserA, displayedUserB, ownedCount, contractID, err := p.CursorFunc(node) - if err != nil { - return "", err - } - - return p.encodeCursor(displayedUserA, displayedUserB, ownedCount, contractID) - } - paginator := keysetPaginator{ QueryFunc: queryFunc, - CursorFunc: cursorFunc, + Cursorable: cursors.NewBoolBoolIntIDCursorer(p.CursorFunc), CountFunc: p.CountFunc, } @@ -515,90 +409,40 @@ type boolTimeIDPaginator struct { CountFunc func() (count int, err error) } -func (p *boolTimeIDPaginator) encodeCursor(b bool, t time.Time, id persist.DBID) (string, error) { - encoder := newCursorEncoder() - encoder.appendBool(b) - if err := encoder.appendTime(t); err != nil { - return "", err - } - encoder.appendDBID(id) - return encoder.AsBase64(), nil -} - -func (p *boolTimeIDPaginator) decodeCursor(cursor string) (bool, time.Time, persist.DBID, error) { - decoder, err := newCursorDecoder(cursor) - if err != nil { - return false, time.Time{}, "", err - } - - b, err := decoder.readBool() - if err != nil { - return false, time.Time{}, "", err - } - - t, err := decoder.readTime() - if err != nil { - return false, time.Time{}, "", err - } - - id, err := decoder.readDBID() - if err != nil { - return false, time.Time{}, "", err - } - - return b, t, id, nil -} - func (p *boolTimeIDPaginator) paginate(before *string, after *string, first *int, last *int) ([]interface{}, PageInfo, error) { queryFunc := func(limit int32, pagingForward bool) ([]interface{}, error) { - curBeforeTime := defaultCursorBeforeTime - curBeforeID := defaultCursorBeforeID - curAfterTime := defaultCursorAfterTime - curAfterID := defaultCursorAfterID - curBeforeBool := true - curAfterBool := false - - var err error + beforeCur := boolTimeIDCursor{Bool: true, Time: defaultCursorBeforeTime, ID: defaultCursorBeforeID} + afterCur := boolTimeIDCursor{Bool: false, Time: defaultCursorAfterTime, ID: defaultCursorAfterID} + if before != nil { - curBeforeBool, curBeforeTime, curBeforeID, err = p.decodeCursor(*before) - if err != nil { + if err := beforeCur.Unpack(*before); err != nil { return nil, err } } if after != nil { - curAfterBool, curAfterTime, curAfterID, err = p.decodeCursor(*after) - if err != nil { + if err := afterCur.Unpack(*after); err != nil { return nil, err } } queryParams := boolTimeIDPagingParams{ Limit: limit, - CursorBeforeBool: curBeforeBool, - CursorBeforeTime: curBeforeTime, - CursorBeforeID: curBeforeID, - CursorAfterBool: curAfterBool, - CursorAfterTime: curAfterTime, - CursorAfterID: curAfterID, + CursorBeforeBool: beforeCur.Bool, + CursorBeforeTime: beforeCur.Time, + CursorBeforeID: beforeCur.ID, + CursorAfterBool: afterCur.Bool, + CursorAfterTime: afterCur.Time, + CursorAfterID: afterCur.ID, PagingForward: pagingForward, } return p.QueryFunc(queryParams) } - cursorFunc := func(node interface{}) (string, error) { - nodeBool, nodeTime, nodeID, err := p.CursorFunc(node) - if err != nil { - return "", err - } - - return p.encodeCursor(nodeBool, nodeTime, nodeID) - } - paginator := keysetPaginator{ QueryFunc: queryFunc, - CursorFunc: cursorFunc, + Cursorable: cursors.NewBoolTimeIDCursorer(p.CursorFunc), CountFunc: p.CountFunc, } @@ -626,78 +470,44 @@ type lexicalPagingParams struct { PagingForward bool } -func (p *lexicalPaginator) encodeCursor(sortKey string, id persist.DBID) (string, error) { - encoder := newCursorEncoder() - encoder.appendString(sortKey) - encoder.appendDBID(id) - return encoder.AsBase64(), nil -} - -func (p *lexicalPaginator) decodeCursor(cursor string) (string, persist.DBID, error) { - decoder, err := newCursorDecoder(cursor) - if err != nil { - return "", "", err - } - - sortKey, err := decoder.readString() - if err != nil { - return "", "", err - } - - id, err := decoder.readDBID() - if err != nil { - return "", "", err - } - - return sortKey, id, nil -} - func (p *lexicalPaginator) paginate(before *string, after *string, first *int, last *int) ([]interface{}, PageInfo, error) { queryFunc := func(limit int32, pagingForward bool) ([]interface{}, error) { - curBeforeKey := defaultCursorBeforeKey - curBeforeID := defaultCursorBeforeID - curAfterKey := defaultCursorAfterKey - curAfterID := defaultCursorAfterID + beforeCur := stringIDCursor{ + String: defaultCursorBeforeKey, + ID: defaultCursorBeforeID, + } + afterCur := stringIDCursor{ + String: defaultCursorAfterKey, + ID: defaultCursorAfterID, + } - var err error if before != nil { - curBeforeKey, curBeforeID, err = p.decodeCursor(*before) - if err != nil { + if err := beforeCur.Unpack(*before); err != nil { return nil, err } } if after != nil { - curAfterKey, curAfterID, err = p.decodeCursor(*after) - if err != nil { + if err := afterCur.Unpack(*after); err != nil { return nil, err } } queryParams := lexicalPagingParams{ Limit: limit, - CursorBeforeKey: curBeforeKey, - CursorBeforeID: curBeforeID, - CursorAfterKey: curAfterKey, - CursorAfterID: curAfterID, + CursorBeforeKey: beforeCur.String, + CursorBeforeID: beforeCur.ID, + CursorAfterKey: afterCur.String, + CursorAfterID: afterCur.ID, PagingForward: pagingForward, } return p.QueryFunc(queryParams) } - cursorFunc := func(node interface{}) (string, error) { - nodeKey, nodeID, err := p.CursorFunc(node) - if err != nil { - return "", err - } - - return p.encodeCursor(nodeKey, nodeID) - } - paginator := keysetPaginator{ QueryFunc: queryFunc, - CursorFunc: cursorFunc, + Cursorable: cursors.NewStringIDCursorer(p.CursorFunc), CountFunc: p.CountFunc, } @@ -710,7 +520,7 @@ type positionPaginator struct { QueryFunc func(params positionPagingParams) ([]any, error) // CursorFunc returns the current position and a fixed slice of DBIDs that will be encoded into a cursor string - CursorFunc func(node interface{}) (int, []persist.DBID, error) + CursorFunc func(node interface{}) (int64, []persist.DBID, error) // CountFunc returns the total number of items that can be paginated. May be nil, in which // case the resulting PageInfo will omit the total field. @@ -742,73 +552,36 @@ type positionPagingParams struct { IDs []persist.DBID } -func (p *positionPaginator) encodeCursor(position int, ids []persist.DBID) (string, error) { - encoder := newCursorEncoder() - encoder.appendInt64(int64(position)) - encoder.appendInt64(int64(len(ids))) - for _, id := range ids { - encoder.appendDBID(id) - } - return encoder.AsBase64(), nil -} - -func (p *positionPaginator) decodeCursor(cursor string) (int, []persist.DBID, error) { - decoder, err := newCursorDecoder(cursor) - if err != nil { - return 0, nil, err - } - - position, err := decoder.readInt64() - if err != nil { - return 0, nil, err - } - - totalItems, err := decoder.readInt64() - if err != nil { - return 0, nil, err - } - - ids := make([]persist.DBID, totalItems) - for i := int64(0); i < totalItems; i++ { - id, err := decoder.readDBID() - if err != nil { - return 0, nil, err - } - ids[i] = id - } - - return int(position), ids, nil -} - func (p *positionPaginator) paginate(before *string, after *string, first *int, last *int, opts ...func(*positionPaginatorArgs)) ([]interface{}, PageInfo, error) { queryFunc := func(limit int32, pagingForward bool) ([]interface{}, error) { - var err error var ids []persist.DBID - var curBeforePos int - var curAfterPos int - args := positionPaginatorArgs{} - args.CurBeforePos = defaultCursorBeforePosition - args.CurAfterPos = defaultCursorAfterPosition + args := positionPaginatorArgs{ + CurBeforePos: defaultCursorBeforePosition, + CurAfterPos: defaultCursorAfterPosition, + } + + beforeCur := positionCursor{} + afterCur := positionCursor{} for _, opt := range opts { opt(&args) } if before != nil { - curBeforePos, ids, err = p.decodeCursor(*before) - if err != nil { + if err := beforeCur.Unpack(*before); err != nil { return nil, err } - args.CurBeforePos = curBeforePos + args.CurBeforePos = int(beforeCur.CurrentPosition) + ids = beforeCur.IDs } if after != nil { - curAfterPos, ids, err = p.decodeCursor(*after) - if err != nil { + if err := afterCur.Unpack(*after); err != nil { return nil, err } - args.CurAfterPos = curAfterPos + args.CurAfterPos = int(afterCur.CurrentPosition) + ids = afterCur.IDs } queryParams := positionPagingParams{ @@ -822,17 +595,9 @@ func (p *positionPaginator) paginate(before *string, after *string, first *int, return p.QueryFunc(queryParams) } - cursorFunc := func(node any) (string, error) { - pos, nodeIDs, err := p.CursorFunc(node) - if err != nil { - return "", err - } - return p.encodeCursor(pos, nodeIDs) - } - paginator := keysetPaginator{ QueryFunc: queryFunc, - CursorFunc: cursorFunc, + Cursorable: cursors.NewPositionCursorer(p.CursorFunc), CountFunc: p.CountFunc, } @@ -842,7 +607,7 @@ func (p *positionPaginator) paginate(before *string, after *string, first *int, type intTimeIDPaginator struct { QueryFunc func(params intTimeIDPagingParams) ([]interface{}, error) - CursorFunc func(node interface{}) (int32, time.Time, persist.DBID, error) + CursorFunc func(node interface{}) (int64, time.Time, persist.DBID, error) // CountFunc returns the total number of items that can be paginated. May be nil, in which // case the resulting PageInfo will omit the total field. @@ -861,96 +626,48 @@ type intTimeIDPagingParams struct { PagingForward bool } -func (p *intTimeIDPaginator) encodeCursor(i int32, t time.Time, id persist.DBID) (string, error) { - encoder := newCursorEncoder() - encoder.appendInt64(int64(i)) - if err := encoder.appendTime(t); err != nil { - return "", err - } - encoder.appendDBID(id) - return encoder.AsBase64(), nil -} - -func (p *intTimeIDPaginator) decodeCursor(cursor string) (int32, time.Time, persist.DBID, error) { - decoder, err := newCursorDecoder(cursor) - if err != nil { - return 0, time.Time{}, "", err - } - - i, err := decoder.readInt64() - if err != nil { - return 0, time.Time{}, "", err - } - - t, err := decoder.readTime() - if err != nil { - return 0, time.Time{}, "", err - } - - id, err := decoder.readDBID() - if err != nil { - return 0, time.Time{}, "", err - } - - return int32(i), t, id, nil -} - func (p *intTimeIDPaginator) paginate(before *string, after *string, first *int, last *int) ([]interface{}, PageInfo, error) { queryFunc := func(limit int32, pagingForward bool) ([]interface{}, error) { - curBeforeInt := int32(math.MaxInt32) - curBeforeTime := defaultCursorBeforeTime - curBeforeID := persist.DBID("") - curAfterInt := int32(0) - curAfterTime := defaultCursorAfterTime - curAfterID := persist.DBID("") - - var err error + beforeCur := intTimeIDCursor{Int: math.MaxInt32, Time: defaultCursorBeforeTime, ID: defaultCursorBeforeID} + afterCur := intTimeIDCursor{Int: 0, Time: defaultCursorAfterTime, ID: defaultCursorAfterID} + if before != nil { - curBeforeInt, curBeforeTime, curBeforeID, err = p.decodeCursor(*before) - if err != nil { + if err := beforeCur.Unpack(*before); err != nil { return nil, err } } if after != nil { - curAfterInt, curAfterTime, curAfterID, err = p.decodeCursor(*after) - if err != nil { + if err := afterCur.Unpack(*after); err != nil { return nil, err } } queryParams := intTimeIDPagingParams{ Limit: limit, - CursorBeforeInt: curBeforeInt, - CursorBeforeTime: curBeforeTime, - CursorBeforeID: curBeforeID, - CursorAfterInt: curAfterInt, - CursorAfterTime: curAfterTime, - CursorAfterID: curAfterID, + CursorBeforeInt: int32(beforeCur.Int), + CursorBeforeTime: beforeCur.Time, + CursorBeforeID: beforeCur.ID, + CursorAfterInt: int32(afterCur.Int), + CursorAfterTime: afterCur.Time, + CursorAfterID: afterCur.ID, PagingForward: pagingForward, } return p.QueryFunc(queryParams) } - cursorFunc := func(node interface{}) (string, error) { - nodeInt, nodeTime, nodeID, err := p.CursorFunc(node) - if err != nil { - return "", err - } - - return p.encodeCursor(nodeInt, nodeTime, nodeID) - } - paginator := keysetPaginator{ QueryFunc: queryFunc, - CursorFunc: cursorFunc, + Cursorable: cursors.NewIntTimeIDCursorer(p.CursorFunc), CountFunc: p.CountFunc, } return paginator.paginate(before, after, first, last) } +//------------------------------------------------------------------------------ + type cursorEncoder struct { buffer []byte } @@ -1021,6 +738,10 @@ func (e *cursorEncoder) appendInt64(i int64) { e.buffer = append(e.buffer, buf[:bytesWritten]...) } +func (e *cursorEncoder) appendFeedEntityType(i persist.FeedEntityType) { + e.appendInt64(int64(i)) +} + type cursorDecoder struct { reader *bytes.Reader } @@ -1115,3 +836,347 @@ func (d *cursorDecoder) readUInt64() (uint64, error) { func (d *cursorDecoder) readInt64() (int64, error) { return binary.ReadVarint(d.reader) } + +// readFeedEntityType reads FeedEntityType from the underlying reader and advances the stream +func (d *cursorDecoder) readFeedEntityType() (persist.FeedEntityType, error) { + i, err := binary.ReadVarint(d.reader) + if err != nil { + return 0, err + } + return persist.FeedEntityType(i), nil +} + +//------------------------------------------------------------------------------ + +type cursorer interface { + Pack() (string, error) + Unpack(string) error +} +type cursorable func(any) (cursorer, error) +type curs struct{} + +var cursors curs + +func (curs) NewTimeIDCursorer(f func(any) (time.Time, persist.DBID, error)) cursorable { + return func(node any) (c cursorer, err error) { + var cur timeIDCursor + cur.Time, cur.ID, err = f(node) + return &cur, err + } +} + +func (curs) NewBoolBoolIntIDCursorer(f func(any) (bool, bool, int64, persist.DBID, error)) cursorable { + return func(node any) (c cursorer, err error) { + var cur boolBootIntIDCursor + cur.Bool1, cur.Bool2, cur.Int, cur.ID, err = f(node) + return &cur, err + } +} + +func (curs) NewBoolTimeIDCursorer(f func(any) (bool, time.Time, persist.DBID, error)) cursorable { + return func(node any) (c cursorer, err error) { + var cur boolTimeIDCursor + cur.Bool, cur.Time, cur.ID, err = f(node) + return &cur, err + } +} + +func (curs) NewStringIDCursorer(f func(any) (string, persist.DBID, error)) cursorable { + return func(node any) (c cursorer, err error) { + var cur stringIDCursor + cur.String, cur.ID, err = f(node) + return &cur, err + } +} + +func (curs) NewIntTimeIDCursorer(f func(any) (int64, time.Time, persist.DBID, error)) cursorable { + return func(node any) (c cursorer, err error) { + var cur intTimeIDCursor + cur.Int, cur.Time, cur.ID, err = f(node) + return &cur, err + } +} + +func (curs) NewPositionCursorer(f func(any) (int64, []persist.DBID, error)) cursorable { + return func(node any) (c cursorer, err error) { + var cur positionCursor + cur.CurrentPosition, cur.IDs, err = f(node) + return &cur, err + } +} + +func (curs) NewFeedPositionCursorer(f func(any) (int64, []persist.FeedEntityType, []persist.DBID, error)) cursorable { + return func(node any) (c cursorer, err error) { + var cur feedPositionCursor + cur.CurrentPosition, cur.EntityTypes, cur.EntityIDs, err = f(node) + return &cur, err + } +} + +//------------------------------------------------------------------------------ + +type timeIDCursor struct { + Time time.Time + ID persist.DBID +} + +func (c timeIDCursor) Pack() (string, error) { return pack(c.Time, c.ID) } +func (c *timeIDCursor) Unpack(s string) error { + d, err := newCursorDecoder(s) + if err != nil { + return err + } + c.Time, err = d.readTime() + if err != nil { + return err + } + c.ID, err = d.readDBID() + return err +} + +//------------------------------------------------------------------------------ + +type boolBootIntIDCursor struct { + Bool1 bool + Bool2 bool + Int int64 + ID persist.DBID +} + +func (c boolBootIntIDCursor) Pack() (string, error) { return pack(c.Bool1, c.Bool2, c.Int, c.ID) } +func (c *boolBootIntIDCursor) Unpack(s string) error { + d, err := newCursorDecoder(s) + if err != nil { + return err + } + c.Bool1, err = d.readBool() + if err != nil { + return err + } + c.Bool2, err = d.readBool() + if err != nil { + return err + } + c.Int, err = d.readInt64() + if err != nil { + return err + } + c.ID, err = d.readDBID() + return err +} + +//------------------------------------------------------------------------------ + +type boolTimeIDCursor struct { + Bool bool + Time time.Time + ID persist.DBID +} + +func (c boolTimeIDCursor) Pack() (string, error) { return pack(c.Bool, c.Time, c.ID) } +func (c *boolTimeIDCursor) Unpack(s string) error { + d, err := newCursorDecoder(s) + if err != nil { + return err + } + c.Bool, err = d.readBool() + if err != nil { + return err + } + c.Time, err = d.readTime() + if err != nil { + return err + } + c.ID, err = d.readDBID() + return err +} + +//------------------------------------------------------------------------------ + +type stringIDCursor struct { + String string + ID persist.DBID +} + +func (c stringIDCursor) Pack() (string, error) { return pack(c.String, c.ID) } +func (c *stringIDCursor) Unpack(s string) error { + d, err := newCursorDecoder(s) + if err != nil { + return err + } + c.String, err = d.readString() + if err != nil { + return err + } + c.ID, err = d.readDBID() + return err +} + +//------------------------------------------------------------------------------ + +type intTimeIDCursor struct { + Int int64 + Time time.Time + ID persist.DBID +} + +func (c intTimeIDCursor) Pack() (string, error) { return pack(c.Int, c.Time, c.ID) } +func (c *intTimeIDCursor) Unpack(s string) error { + d, err := newCursorDecoder(s) + if err != nil { + return err + } + c.Int, err = d.readInt64() + if err != nil { + return err + } + c.Time, err = d.readTime() + if err != nil { + return err + } + c.ID, err = d.readDBID() + return err +} + +//------------------------------------------------------------------------------ + +type feedPositionCursor struct { + CurrentPosition int64 + EntityTypes []persist.FeedEntityType + EntityIDs []persist.DBID +} + +func (c feedPositionCursor) Pack() (string, error) { + return pack(c.CurrentPosition, c.EntityTypes, c.EntityIDs) +} + +func (c *feedPositionCursor) Unpack(s string) (err error) { + d, err := newCursorDecoder(s) + if err != nil { + return err + } + + c.CurrentPosition, err = d.readInt64() + if err != nil { + return err + } + + c.EntityTypes, err = unpackSlice[persist.FeedEntityType](&d, func(d *cursorDecoder) (persist.FeedEntityType, error) { + return d.readFeedEntityType() + }) + if err != nil { + return err + } + + c.EntityIDs, err = unpackSlice[persist.DBID](&d, func(d *cursorDecoder) (persist.DBID, error) { + return d.readDBID() + }) + + return err +} + +//------------------------------------------------------------------------------ + +type positionCursor struct { + CurrentPosition int64 + IDs []persist.DBID +} + +func (c positionCursor) Pack() (string, error) { return pack(c.CurrentPosition, c.IDs) } +func (c *positionCursor) Unpack(s string) (err error) { + d, err := newCursorDecoder(s) + if err != nil { + return err + } + + c.CurrentPosition, err = d.readInt64() + if err != nil { + return err + } + + c.IDs, err = unpackSlice[persist.DBID](&d, func(d *cursorDecoder) (persist.DBID, error) { + return d.readDBID() + }) + + return err +} + +//------------------------------------------------------------------------------ + +func pack(vals ...any) (string, error) { + e := newCursorEncoder() + + if err := packVals(&e, vals...); err != nil { + return "", err + } + + return e.AsBase64(), nil +} + +func packVal(e *cursorEncoder, val any) error { + switch v := val.(type) { + case bool: + e.appendBool(v) + case string: + e.appendString(v) + case persist.DBID: + e.appendDBID(v) + case uint64: + e.appendUInt64(v) + case int64: + e.appendInt64(v) + case int: + e.appendInt64(int64(v)) + case time.Time: + if err := e.appendTime(v); err != nil { + return err + } + case persist.FeedEntityType: + e.appendFeedEntityType(v) + case []persist.DBID: + if err := packSlice(e, v); err != nil { + return err + } + case []persist.FeedEntityType: + if err := packSlice(e, v); err != nil { + return err + } + default: + panic(fmt.Sprintf("unknown cursor type: %T", v)) + } + return nil +} + +func packVals[T any](e *cursorEncoder, vals ...T) error { + for _, val := range vals { + if err := packVal(e, val); err != nil { + return err + } + } + return nil +} + +// Encode the length of the slice as an int64, then encode each val +func packSlice[T any](e *cursorEncoder, s []T) error { + e.appendInt64(int64(len(s))) + return packVals(e, s...) +} + +func unpackSlice[T any](d *cursorDecoder, f func(d *cursorDecoder) (T, error)) ([]T, error) { + l, err := d.readInt64() + if err != nil { + return nil, err + } + + items := make([]T, l) + + for i := int64(0); i < l; i++ { + id, err := f(d) + if err != nil { + return nil, err + } + items[i] = id + } + + return items, nil +} diff --git a/publicapi/pagination_test.go b/publicapi/pagination_test.go index 892b2bee7..ea509617e 100644 --- a/publicapi/pagination_test.go +++ b/publicapi/pagination_test.go @@ -2,19 +2,37 @@ package publicapi import ( "testing" + "time" "github.com/stretchr/testify/assert" + + "github.com/mikeydub/go-gallery/service/persist" ) func TestMain(t *testing.T) { t.Run("test cursor pagination", func(t *testing.T) { + t.Run("cursor encodes expected types", func(t *testing.T) { + _, err := pack( + time.Now(), + true, + 1, + int64(1), + uint64(1), + "id", + persist.DBID("id"), + []persist.DBID{"id0", "id1"}, + []persist.FeedEntityType{0, 1}, + ) + assert.NoError(t, err) + }) + t.Run("cursor pagination returns expected edges", func(t *testing.T) { t.Run("should return no edges if no edges", func(t *testing.T) { edges := []string{} first := 10 var last int - actual, _, err := pageFrom(edges, nil, identityCursor, nil, nil, &first, &last) + actual, _, err := pageFrom(edges, nil, stubbedCursor, nil, nil, &first, &last) assert.NoError(t, err) assert.Equal(t, 0, len(actual)) }) @@ -23,7 +41,7 @@ func TestMain(t *testing.T) { edges := []string{"a", "b", "c", "d", "e"} expected := []string{"a", "b", "c", "d", "e"} - actual, _, err := pageFrom(edges, nil, identityCursor, nil, nil, nil, nil) + actual, _, err := pageFrom(edges, nil, stubbedCursor, nil, nil, nil, nil) assert.NoError(t, err) assert.Equal(t, expected, actual) }) @@ -32,7 +50,7 @@ func TestMain(t *testing.T) { edges := []string{"a", "b", "c", "d", "e"} first := 0 - actual, _, err := pageFrom(edges, nil, identityCursor, nil, nil, &first, nil) + actual, _, err := pageFrom(edges, nil, stubbedCursor, nil, nil, &first, nil) assert.NoError(t, err) assert.Equal(t, 0, len(actual)) }) @@ -41,7 +59,7 @@ func TestMain(t *testing.T) { edges := []string{"a", "b", "c", "d", "e"} last := 0 - actual, _, err := pageFrom(edges, nil, identityCursor, nil, nil, nil, &last) + actual, _, err := pageFrom(edges, nil, stubbedCursor, nil, nil, nil, &last) assert.NoError(t, err) assert.Equal(t, 0, len(actual)) }) @@ -51,7 +69,7 @@ func TestMain(t *testing.T) { first := 2 expected := []string{"a", "b"} - actual, _, err := pageFrom(edges, nil, identityCursor, nil, nil, &first, nil) + actual, _, err := pageFrom(edges, nil, stubbedCursor, nil, nil, &first, nil) assert.NoError(t, err) assert.Equal(t, expected, actual) }) @@ -61,7 +79,7 @@ func TestMain(t *testing.T) { last := 2 expected := []string{"d", "e"} - actual, _, err := pageFrom(edges, nil, identityCursor, nil, nil, nil, &last) + actual, _, err := pageFrom(edges, nil, stubbedCursor, nil, nil, nil, &last) assert.NoError(t, err) assert.Equal(t, expected, actual) }) @@ -72,7 +90,7 @@ func TestMain(t *testing.T) { after := "b" expected := []string{"c", "d"} - actual, _, err := pageFrom(edges, nil, identityCursor, nil, &after, &first, nil) + actual, _, err := pageFrom(edges, nil, stubbedCursor, nil, &after, &first, nil) assert.NoError(t, err) assert.Equal(t, expected, actual) }) @@ -83,7 +101,7 @@ func TestMain(t *testing.T) { before := "d" expected := []string{"b", "c"} - actual, _, err := pageFrom(edges, nil, identityCursor, &before, nil, nil, &last) + actual, _, err := pageFrom(edges, nil, stubbedCursor, &before, nil, nil, &last) assert.NoError(t, err) assert.Equal(t, expected, actual) }) @@ -95,7 +113,7 @@ func TestMain(t *testing.T) { after := "a" expected := []string{"c", "d"} - actual, _, err := pageFrom(edges, nil, identityCursor, &before, &after, nil, &last) + actual, _, err := pageFrom(edges, nil, stubbedCursor, &before, &after, nil, &last) assert.NoError(t, err) assert.Equal(t, expected, actual) }) @@ -106,7 +124,7 @@ func TestMain(t *testing.T) { edges := []string{} first := 10 - _, pageInfo, err := pageFrom(edges, nil, identityCursor, nil, nil, &first, nil) + _, pageInfo, err := pageFrom(edges, nil, stubbedCursor, nil, nil, &first, nil) assert.NoError(t, err) assert.Equal(t, 0, pageInfo.Size) assert.Equal(t, false, pageInfo.HasPreviousPage) @@ -119,7 +137,7 @@ func TestMain(t *testing.T) { edges := []string{"a", "b", "c", "d", "e"} first := 10 - _, pageInfo, err := pageFrom(edges, nil, identityCursor, nil, nil, &first, nil) + _, pageInfo, err := pageFrom(edges, nil, stubbedCursor, nil, nil, &first, nil) assert.NoError(t, err) assert.Equal(t, 5, pageInfo.Size) assert.Equal(t, false, pageInfo.HasPreviousPage) @@ -132,7 +150,7 @@ func TestMain(t *testing.T) { edges := []string{"a", "b", "c", "d", "e"} first := 2 - _, pageInfo, err := pageFrom(edges, nil, identityCursor, nil, nil, &first, nil) + _, pageInfo, err := pageFrom(edges, nil, stubbedCursor, nil, nil, &first, nil) assert.NoError(t, err) assert.Equal(t, 2, pageInfo.Size) assert.Equal(t, false, pageInfo.HasPreviousPage) @@ -145,7 +163,7 @@ func TestMain(t *testing.T) { edges := []string{"a", "b", "c", "d", "e"} last := 2 - _, pageInfo, err := pageFrom(edges, nil, identityCursor, nil, nil, nil, &last) + _, pageInfo, err := pageFrom(edges, nil, stubbedCursor, nil, nil, nil, &last) assert.NoError(t, err) assert.Equal(t, 2, pageInfo.Size) assert.Equal(t, true, pageInfo.HasPreviousPage) @@ -158,7 +176,7 @@ func TestMain(t *testing.T) { t.Run("test keyset pagination", func(t *testing.T) { t.Run("should exclude extra edges", func(t *testing.T) { - p := newStubKeysetPaginator([]any{"a", "b", "c", "d", "e", "extra"}) + p := newStubPaginator([]any{"a", "b", "c", "d", "e", "extra"}) expected := []string{"a", "b", "c", "d", "e"} first := 5 @@ -174,7 +192,7 @@ func TestMain(t *testing.T) { }) t.Run("should return expected page info when paging forward", func(t *testing.T) { - p := newStubKeysetPaginator([]any{"a", "b", "c", "d", "e", "extra"}) + p := newStubPaginator([]any{"a", "b", "c", "d", "e", "extra"}) first := 5 _, pageInfo, err := p.paginate(nil, nil, &first, nil) @@ -187,7 +205,7 @@ func TestMain(t *testing.T) { }) t.Run("should return expected edge order when paging backwards", func(t *testing.T) { - p := newStubKeysetPaginator([]any{"e", "d", "c", "b", "a", "extra"}) + p := newStubPaginator([]any{"e", "d", "c", "b", "a", "extra"}) expected := []string{"a", "b", "c", "d", "e"} last := 5 @@ -203,7 +221,7 @@ func TestMain(t *testing.T) { }) t.Run("should return expected page info when paging backwards", func(t *testing.T) { - p := newStubKeysetPaginator([]any{"e", "d", "c", "b", "a", "extra"}) + p := newStubPaginator([]any{"e", "d", "c", "b", "a", "extra"}) last := 5 _, pageInfo, err := p.paginate(nil, nil, nil, &last) @@ -217,20 +235,21 @@ func TestMain(t *testing.T) { }) } -func identityCursor(s string) (string, error) { - return s, nil -} +type stubCursor struct{ ID string } + +func (p stubCursor) Pack() (string, error) { return p.ID, nil } +func (p stubCursor) Unpack(s string) error { panic("not implemented") } -func newStubKeysetPaginator(ret []any) keysetPaginator { +var stubbedCursor = func(node any) (c cursorer, err error) { return stubCursor{ID: node.(string)}, nil } + +func newStubPaginator(ret []any) keysetPaginator { var p keysetPaginator p.QueryFunc = func(int32, bool) ([]any, error) { return ret, nil } - p.CursorFunc = func(a any) (string, error) { - return a.(string), nil - } + p.Cursorable = stubbedCursor return p } diff --git a/publicapi/publicapi.go b/publicapi/publicapi.go index 8eb101259..02a7d3676 100644 --- a/publicapi/publicapi.go +++ b/publicapi/publicapi.go @@ -150,8 +150,8 @@ func newDBIDCache(cache *redis.Cache, key string, ttl time.Duration, f func(cont return nil, err } - var p positionPaginator - b, err := p.encodeCursor(0, ids) + cur := positionCursor{CurrentPosition: 0, IDs: ids} + b, err := cur.Pack() return []byte(b), err }, }, @@ -163,7 +163,7 @@ func (d dbidCache) Load(ctx context.Context) ([]persist.DBID, error) { if err != nil { return nil, err } - var p positionPaginator - _, ids, err := p.decodeCursor(string(b)) - return ids, err + var cur positionCursor + err = cur.Unpack(string(b)) + return cur.IDs, err } diff --git a/publicapi/user.go b/publicapi/user.go index 9dfb656bc..1cf757c5e 100644 --- a/publicapi/user.go +++ b/publicapi/user.go @@ -808,9 +808,9 @@ func (api UserAPI) SharedCommunities(ctx context.Context, userID persist.DBID, b return int(total), err } - cursorFunc := func(i any) (bool, bool, int, persist.DBID, error) { + cursorFunc := func(i any) (bool, bool, int64, persist.DBID, error) { if row, ok := i.(db.GetSharedContractsBatchPaginateRow); ok { - return row.DisplayedByUserA, row.DisplayedByUserB, int(row.OwnedCount), row.ID, nil + return row.DisplayedByUserA, row.DisplayedByUserB, int64(row.OwnedCount), row.ID, nil } return false, false, 0, "", fmt.Errorf("node is not a db.GetSharedContractsBatchPaginateRow") } @@ -1227,16 +1227,16 @@ func (api UserAPI) RecommendUsers(ctx context.Context, before, after *string, fi return nil, PageInfo{}, err } - paginator := positionPaginator{} - var userIDs []persist.DBID + var cursor positionCursor + var paginator positionPaginator // If we have a cursor, we can page through the original set of recommended users if before != nil { - if _, userIDs, err = paginator.decodeCursor(*before); err != nil { + if err = cursor.Unpack(*before); err != nil { return nil, PageInfo{}, err } } else if after != nil { - if _, userIDs, err = paginator.decodeCursor(*after); err != nil { + if err = cursor.Unpack(*after); err != nil { return nil, PageInfo{}, err } } else { @@ -1246,16 +1246,16 @@ func (api UserAPI) RecommendUsers(ctx context.Context, before, after *string, fi return nil, PageInfo{}, err } - userIDs, err = recommend.For(ctx).RecommendFromFollowingShuffled(ctx, curUserID, follows) + cursor.IDs, err = recommend.For(ctx).RecommendFromFollowingShuffled(ctx, curUserID, follows) if err != nil { return nil, PageInfo{}, err } } positionLookup := map[persist.DBID]int{} - idsAsString := make([]string, len(userIDs)) + idsAsString := make([]string, len(cursor.IDs)) - for i, id := range userIDs { + for i, id := range cursor.IDs { // Postgres uses 1-based indexing positionLookup[id] = i + 1 idsAsString[i] = id.String() @@ -1281,9 +1281,9 @@ func (api UserAPI) RecommendUsers(ctx context.Context, before, after *string, fi return results, nil } - paginator.CursorFunc = func(node any) (int, []persist.DBID, error) { + paginator.CursorFunc = func(node any) (int64, []persist.DBID, error) { if user, ok := node.(db.User); ok { - return positionLookup[user.ID], userIDs, nil + return int64(positionLookup[user.ID]), cursor.IDs, nil } return 0, nil, fmt.Errorf("node is not a db.User") }