Skip to content

Commit

Permalink
Cursor refactor (#1165)
Browse files Browse the repository at this point in the history
* Refactor cursors

* Refactor test

* Fix typo

* Add packable and unpackable

* Remove even more boilerplate

* Uncomment test
  • Loading branch information
jarrel-b authored Aug 29, 2023
1 parent cd0ec2a commit 3f7db2a
Show file tree
Hide file tree
Showing 7 changed files with 728 additions and 576 deletions.
17 changes: 7 additions & 10 deletions publicapi/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,16 @@ func (api CollectionAPI) GetTopCollectionsForCommunity(ctx context.Context, chai
}

var collectionIDs []persist.DBID
cursor := cursors.NewPositionCursor()
var paginator positionPaginator

// If a cursor is provided, we can skip querying the cache
if before != nil {
if _, collectionIDs, err = paginator.decodeCursor(*before); err != nil {
if err = cursor.Unpack(*before); err != nil {
return nil, pageInfo, err
}
} else if after != nil {
if _, collectionIDs, err = paginator.decodeCursor(*after); err != nil {
if err = cursor.Unpack(*after); err != nil {
return nil, pageInfo, err
}
} else {
Expand All @@ -143,19 +144,15 @@ func (api CollectionAPI) GetTopCollectionsForCommunity(ctx context.Context, chai
}

paginator.QueryFunc = func(params positionPagingParams) ([]any, error) {
cIDs, _ := util.Map(collectionIDs, func(id persist.DBID) (string, error) {
return id.String(), nil
})
cIDs := util.MapWithoutError(collectionIDs, func(id persist.DBID) string { return id.String() })
c, err := api.queries.GetVisibleCollectionsByIDsPaginate(ctx, db.GetVisibleCollectionsByIDsPaginateParams{
CollectionIds: cIDs,
CurBeforePos: params.CursorBeforePos,
CurAfterPos: params.CursorAfterPos,
PagingForward: params.PagingForward,
Limit: params.Limit,
})
a, _ := util.Map(c, func(c db.Collection) (any, error) {
return c, nil
})
a := util.MapWithoutError(c, func(c db.Collection) any { return c })
return a, err
}

Expand All @@ -164,8 +161,8 @@ func (api CollectionAPI) GetTopCollectionsForCommunity(ctx context.Context, chai
posLookup[id] = i + 1 // Postgres uses 1-based indexing
}

paginator.CursorFunc = func(node any) (int, []persist.DBID, error) {
return posLookup[node.(db.Collection).ID], collectionIDs, nil
paginator.CursorFunc = func(node any) (int64, []persist.DBID, error) {
return int64(posLookup[node.(db.Collection).ID]), collectionIDs, nil
}

// The collections are sorted by ascending rank so we need to switch the cursor positions
Expand Down
189 changes: 68 additions & 121 deletions publicapi/feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,22 +354,27 @@ func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *stri

var (
err error
cursor = cursors.NewFeedPositionCursor()
paginator feedPaginator
entityTypes []persist.FeedEntityType
entityIDs []persist.DBID
entityIDToPos = make(map[persist.DBID]int)
)

hasCursors := before != nil || after != nil

now := time.Now()

// Include posts for admins always during the soft launch
if !includePosts {
includePosts = shouldShowPosts(ctx)
}

if !hasCursors {
if before != nil {
if err = cursor.Unpack(*before); err != nil {
return nil, PageInfo{}, err
}
} else if after != nil {
if err = cursor.Unpack(*after); err != nil {
return nil, PageInfo{}, err
}
} else {
calcFunc := func(ctx context.Context) ([]persist.FeedEntityType, []persist.DBID, error) {
trendData, err := fetchFeedEntityScores(ctx, api.queries, feedParams{
IncludePosts: includePosts,
Expand All @@ -384,8 +389,8 @@ func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *stri
return timeFactor(e.CreatedAt, now) * engagementFactor(int(e.Interactions))
})

entityTypes = make([]persist.FeedEntityType, len(scored))
entityIDs = make([]persist.DBID, len(scored))
entityTypes := make([]persist.FeedEntityType, len(scored))
entityIDs := make([]persist.DBID, len(scored))

for i, e := range scored {
idx := len(scored) - i - 1
Expand All @@ -398,45 +403,41 @@ func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *stri

l := newFeedCache(api.cache, includePosts, calcFunc)

entityTypes, entityIDs, err = l.Load(ctx)
cursor.EntityTypes, cursor.EntityIDs, err = l.Load(ctx)
if err != nil {
return nil, PageInfo{}, err
}
}

queryFunc := func(params feedPagingParams) ([]any, error) {
if hasCursors {
entityTypes = params.EntityTypes
entityIDs = params.EntityIDs
}
for i, id := range entityIDs {
for i, id := range cursor.EntityIDs {
entityIDToPos[id] = i
}

// Filter slices in place
if !includePosts {
idx := 0
for i := range entityTypes {
if entityTypes[i] != persist.PostTypeTag {
entityTypes[idx] = entityTypes[i]
entityIDs[idx] = entityIDs[i]
for i := range cursor.EntityTypes {
if cursor.EntityTypes[i] != persist.PostTypeTag {
cursor.EntityTypes[idx] = cursor.EntityTypes[i]
cursor.EntityIDs[idx] = cursor.EntityIDs[i]
idx++
}
}
entityTypes = entityTypes[:idx]
entityIDs = entityIDs[:idx]
cursor.EntityTypes = cursor.EntityTypes[:idx]
cursor.EntityIDs = cursor.EntityIDs[:idx]
}

return loadFeedEntities(ctx, api.loaders, entityTypes, entityIDs)
return loadFeedEntities(ctx, api.loaders, cursor.EntityTypes, cursor.EntityIDs)
}

cursorFunc := func(node any) (int, []persist.FeedEntityType, []persist.DBID, error) {
cursorFunc := func(node any) (int64, []persist.FeedEntityType, []persist.DBID, error) {
_, id, err := feedCursor(node)
pos, ok := entityIDToPos[id]
if !ok {
panic(fmt.Sprintf("could not find position for id=%s", id))
}
return pos, entityTypes, entityIDs, err
return int64(pos), cursor.EntityTypes, cursor.EntityIDs, err
}

paginator.QueryFunc = queryFunc
Expand All @@ -459,21 +460,26 @@ func (api FeedAPI) CuratedFeed(ctx context.Context, before, after *string, first

var (
paginator feedPaginator
entityTypes []persist.FeedEntityType
entityIDs []persist.DBID
cursor = cursors.NewFeedPositionCursor()
entityIDToPos = make(map[persist.DBID]int)
)

hasCursors := before != nil || after != nil

now := time.Now()

// Include posts for admins always during the soft launch
if !includePosts {
includePosts = shouldShowPosts(ctx)
}

if !hasCursors {
if before != nil {
if err := cursor.Unpack(*before); err != nil {
return nil, PageInfo{}, err
}
} else if after != nil {
if err := cursor.Unpack(*after); err != nil {
return nil, PageInfo{}, err
}
} else {
trendData, err := fetchFeedEntityScores(ctx, api.queries, feedParams{
IncludePosts: includePosts,
IncludeEvents: !includePosts,
Expand Down Expand Up @@ -540,49 +546,45 @@ func (api FeedAPI) CuratedFeed(ctx context.Context, before, after *string, first

recommend.Shuffle(interleaved, 8)

entityTypes = make([]persist.FeedEntityType, len(interleaved))
entityIDs = make([]persist.DBID, len(interleaved))
cursor.EntityTypes = make([]persist.FeedEntityType, len(interleaved))
cursor.EntityIDs = make([]persist.DBID, len(interleaved))

for i, e := range interleaved {
idx := len(interleaved) - i - 1
entityTypes[idx] = persist.FeedEntityType(e.FeedEntityType)
entityIDs[idx] = e.ID
cursor.EntityTypes[idx] = persist.FeedEntityType(e.FeedEntityType)
cursor.EntityIDs[idx] = e.ID
}
}

queryFunc := func(params feedPagingParams) ([]any, error) {
if hasCursors {
entityTypes = params.EntityTypes
entityIDs = params.EntityIDs
}
for i, id := range entityIDs {
for i, id := range cursor.EntityIDs {
entityIDToPos[id] = i
}

// Filter slices in place
if !includePosts {
idx := 0
for i := range entityTypes {
if entityTypes[i] != persist.PostTypeTag {
entityTypes[idx] = entityTypes[i]
entityIDs[idx] = entityIDs[i]
for i := range cursor.EntityTypes {
if cursor.EntityTypes[i] != persist.PostTypeTag {
cursor.EntityTypes[idx] = cursor.EntityTypes[i]
cursor.EntityIDs[idx] = cursor.EntityIDs[i]
idx++
}
}
entityTypes = entityTypes[:idx]
entityIDs = entityIDs[:idx]
cursor.EntityTypes = cursor.EntityTypes[:idx]
cursor.EntityIDs = cursor.EntityIDs[:idx]
}

return loadFeedEntities(ctx, api.loaders, entityTypes, entityIDs)
return loadFeedEntities(ctx, api.loaders, cursor.EntityTypes, cursor.EntityIDs)
}

cursorFunc := func(node any) (int, []persist.FeedEntityType, []persist.DBID, error) {
cursorFunc := func(node any) (int64, []persist.FeedEntityType, []persist.DBID, error) {
_, id, err := feedCursor(node)
pos, ok := entityIDToPos[id]
if !ok {
panic(fmt.Sprintf("could not find position for id=%s", id))
}
return pos, entityTypes, entityIDs, err
return int64(pos), cursor.EntityTypes, cursor.EntityIDs, err
}

paginator.QueryFunc = queryFunc
Expand Down Expand Up @@ -835,58 +837,7 @@ type feedPagingParams struct {

type feedPaginator struct {
QueryFunc func(params feedPagingParams) ([]any, error)
CursorFunc func(node any) (pos int, feedEntityType []persist.FeedEntityType, ids []persist.DBID, err error)
}

func (p *feedPaginator) encodeCursor(pos int, typ []persist.FeedEntityType, ids []persist.DBID) (string, error) {
if len(typ) != len(ids) {
panic("type and ids must be the same length")
}
encoder := newCursorEncoder()
encoder.appendInt64(int64(pos))
encoder.appendInt64(int64(len(ids)))
for i := range typ {
encoder.appendInt64(int64(typ[i]))
encoder.appendDBID(ids[i])
}
return encoder.AsBase64(), nil
}

func (p *feedPaginator) decodeCursor(cursor string) (pos int, typs []persist.FeedEntityType, ids []persist.DBID, err error) {
decoder, err := newCursorDecoder(cursor)
if err != nil {
return 0, nil, nil, err
}

curPos, err := decoder.readInt64()
if err != nil {
return 0, nil, nil, err
}

totalItems, err := decoder.readInt64()
if err != nil {
return 0, nil, nil, err
}

typs = make([]persist.FeedEntityType, totalItems)
ids = make([]persist.DBID, totalItems)

for i := 0; i < int(totalItems); i++ {
typ, err := decoder.readInt64()
if err != nil {
return 0, nil, nil, err
}

id, err := decoder.readDBID()
if err != nil {
return 0, nil, nil, err
}

typs[i] = persist.FeedEntityType(typ)
ids[i] = id
}

return int(curPos), typs, ids, nil
CursorFunc func(node any) (pos int64, feedEntityType []persist.FeedEntityType, ids []persist.DBID, err error)
}

func (p *feedPaginator) paginate(before, after *string, first, last *int) ([]any, PageInfo, error) {
Expand All @@ -895,40 +846,33 @@ func (p *feedPaginator) paginate(before, after *string, first, last *int) ([]any
CurAfterPos: defaultCursorAfterPosition,
}

beforeCur := cursors.NewFeedPositionCursor()
afterCur := cursors.NewFeedPositionCursor()

if before != nil {
curBeforePos, typs, ids, err := p.decodeCursor(*before)
if err != nil {
if err := beforeCur.Unpack(*before); err != nil {
return nil, PageInfo{}, err
}
args.CurBeforePos = curBeforePos
args.EntityTypes = typs
args.EntityIDs = ids
args.CurBeforePos = int(beforeCur.CurrentPosition)
args.EntityTypes = beforeCur.EntityTypes
args.EntityIDs = beforeCur.EntityIDs
}

if after != nil {
curAfterPos, typs, ids, err := p.decodeCursor(*after)
if err != nil {
if err := afterCur.Unpack(*after); err != nil {
return nil, PageInfo{}, err
}
args.CurAfterPos = curAfterPos
args.EntityTypes = typs
args.EntityIDs = ids
args.CurAfterPos = int(afterCur.CurrentPosition)
args.EntityTypes = afterCur.EntityTypes
args.EntityIDs = afterCur.EntityIDs
}

results, err := p.QueryFunc(args)
if err != nil {
return nil, PageInfo{}, err
}

cursorFunc := func(node any) (string, error) {
pos, typs, ids, err := p.CursorFunc(node)
if err != nil {
return "", err
}
return p.encodeCursor(pos, typs, ids)
}

return pageFrom(results, nil, cursorFunc, before, after, first, last)
return pageFrom(results, nil, cursorables.NewFeedPositionCursorer(p.CursorFunc), before, after, first, last)
}

type feedCache struct {
Expand All @@ -951,8 +895,11 @@ func newFeedCache(cache *redis.Cache, includePosts bool, f func(context.Context)
if err != nil {
return nil, err
}
var p feedPaginator
b, err := p.encodeCursor(0, types, ids)
cur := cursors.NewFeedPositionCursor()
cur.CurrentPosition = 0
cur.EntityTypes = types
cur.EntityIDs = ids
b, err := cur.Pack()
return []byte(b), err
},
},
Expand All @@ -964,9 +911,9 @@ func (f feedCache) Load(ctx context.Context) ([]persist.FeedEntityType, []persis
if err != nil {
return nil, nil, err
}
var p feedPaginator
_, types, ids, err := p.decodeCursor(string(b))
return types, ids, err
cur := cursors.NewFeedPositionCursor()
err = cur.Unpack(string(b))
return cur.EntityTypes, cur.EntityIDs, err
}

func min(a, b int) int {
Expand Down
Loading

0 comments on commit 3f7db2a

Please sign in to comment.