From d61f60b4dd51b8dea3d9a345964eaac0b95a54b2 Mon Sep 17 00:00:00 2001 From: jarrel Date: Thu, 17 Aug 2023 15:02:06 -0700 Subject: [PATCH] Include only posts (#1157) --- db/gen/coredb/recommend.sql.go | 9 ++++-- db/queries/core/recommend.sql | 1 + publicapi/feed.go | 59 ++++++++++++++++++++++++++-------- 3 files changed, 54 insertions(+), 15 deletions(-) diff --git a/db/gen/coredb/recommend.sql.go b/db/gen/coredb/recommend.sql.go index 023ba2307..b17530692 100644 --- a/db/gen/coredb/recommend.sql.go +++ b/db/gen/coredb/recommend.sql.go @@ -96,14 +96,15 @@ from feed_entity_scores f1 where f1.created_at > $1::timestamptz and ($2::bool or f1.actor_id != $3) and ($4::bool or f1.feed_entity_type != $5) - and not (f1.action = any($6::varchar[])) + and ($6::bool or f1.feed_entity_type != $7) + and not (f1.action = any($8::varchar[])) union select id, created_at, actor_id, action, contract_ids, interactions, feed_entity_type, last_updated from feed_entity_score_view f2 where created_at > (select last_updated from refreshed limit 1) and ($2::bool or f2.actor_id != $3) and ($4::bool or f2.feed_entity_type != $5) - and not (f2.action = any($6::varchar[])) + and not (f2.action = any($8::varchar[])) ` type GetFeedEntityScoresParams struct { @@ -112,6 +113,8 @@ type GetFeedEntityScoresParams struct { ViewerID persist.DBID `json:"viewer_id"` IncludePosts bool `json:"include_posts"` PostEntityType int32 `json:"post_entity_type"` + IncludeEvents bool `json:"include_events"` + FeedEntityType int32 `json:"feed_entity_type"` ExcludedFeedActions []string `json:"excluded_feed_actions"` } @@ -122,6 +125,8 @@ func (q *Queries) GetFeedEntityScores(ctx context.Context, arg GetFeedEntityScor arg.ViewerID, arg.IncludePosts, arg.PostEntityType, + arg.IncludeEvents, + arg.FeedEntityType, arg.ExcludedFeedActions, ) if err != nil { diff --git a/db/queries/core/recommend.sql b/db/queries/core/recommend.sql index f7e292de9..08bdd1190 100644 --- a/db/queries/core/recommend.sql +++ b/db/queries/core/recommend.sql @@ -80,6 +80,7 @@ from feed_entity_scores f1 where f1.created_at > @window_end::timestamptz and (@include_viewer::bool or f1.actor_id != @viewer_id) and (@include_posts::bool or f1.feed_entity_type != @post_entity_type) + and (@include_events::bool or f1.feed_entity_type != @feed_entity_type) and not (f1.action = any(@excluded_feed_actions::varchar[])) union select * diff --git a/publicapi/feed.go b/publicapi/feed.go index ed1954802..1ddab995e 100644 --- a/publicapi/feed.go +++ b/publicapi/feed.go @@ -309,6 +309,41 @@ func (api FeedAPI) GlobalFeed(ctx context.Context, before *string, after *string return paginator.paginate(before, after, first, last) } +type feedParams struct { + ExcludeUserID persist.DBID + IncludePosts bool + IncludeEvents bool + ExcludeActions []persist.Action + FetchFrom time.Duration +} + +func fetchFeedEntities(ctx context.Context, queries *db.Queries, p feedParams) ([]db.FeedEntityScore, error) { + var q db.GetFeedEntityScoresParams + + q.IncludeViewer = true + q.IncludePosts = true + q.IncludeEvents = true + q.WindowEnd = time.Now().Add(-p.FetchFrom) + q.ExcludedFeedActions = util.MapWithoutError(p.ExcludeActions, func(a persist.Action) string { return string(a) }) + + if !p.IncludePosts { + q.IncludePosts = false + q.PostEntityType = int32(persist.PostTypeTag) + } + + if !p.IncludeEvents { + q.IncludeEvents = false + q.FeedEntityType = int32(persist.FeedEventTypeTag) + } + + if p.ExcludeUserID != "" { + q.IncludeViewer = false + q.ViewerID = p.ExcludeUserID + } + + return queries.GetFeedEntityScores(ctx, q) +} + func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *string, first *int, last *int, includePosts bool) ([]any, PageInfo, error) { // Validate if err := validatePaginationParams(api.validator, first, last); err != nil { @@ -334,12 +369,11 @@ func (api FeedAPI) TrendingFeed(ctx context.Context, before *string, after *stri if !hasCursors { calcFunc := func(ctx context.Context) ([]persist.FeedEntityType, []persist.DBID, error) { - trendData, err := api.queries.GetFeedEntityScores(ctx, db.GetFeedEntityScoresParams{ - WindowEnd: time.Now().Add(-time.Duration(3 * 24 * time.Hour)), - PostEntityType: int32(persist.PostTypeTag), - ExcludedFeedActions: []string{string(persist.ActionUserCreated), string(persist.ActionUserFollowedUsers)}, - IncludeViewer: true, - IncludePosts: includePosts, + trendData, err := fetchFeedEntities(ctx, api.queries, feedParams{ + IncludePosts: includePosts, + IncludeEvents: true, + ExcludeActions: []persist.Action{persist.ActionUserCreated, persist.ActionUserFollowedUsers}, + FetchFrom: time.Duration(3 * 24 * time.Hour), }) if err != nil { return nil, nil, err @@ -438,13 +472,12 @@ func (api FeedAPI) CuratedFeed(ctx context.Context, before, after *string, first } if !hasCursors { - trendData, err := api.queries.GetFeedEntityScores(ctx, db.GetFeedEntityScoresParams{ - WindowEnd: time.Now().Add(-time.Duration(7 * 24 * time.Hour)), - PostEntityType: int32(persist.PostTypeTag), - ExcludedFeedActions: []string{string(persist.ActionUserCreated), string(persist.ActionUserFollowedUsers)}, - IncludeViewer: false, - ViewerID: userID, - IncludePosts: includePosts, + trendData, err := fetchFeedEntities(ctx, api.queries, feedParams{ + IncludePosts: includePosts, + IncludeEvents: !includePosts, + ExcludeUserID: userID, + ExcludeActions: []persist.Action{persist.ActionUserCreated, persist.ActionUserFollowedUsers}, + FetchFrom: time.Duration(7 * 24 * time.Hour), }) if err != nil { return nil, PageInfo{}, err