diff --git a/controller/stories/stories.go b/controller/stories/stories.go index aed024e..0d33768 100644 --- a/controller/stories/stories.go +++ b/controller/stories/stories.go @@ -54,6 +54,146 @@ func HandleList(w http.ResponseWriter, r *http.Request) error { return nil } +func HandleListPublished(w http.ResponseWriter, r *http.Request) error { + err := auth.CheckPermissions(r, storypermissiongroups.List()) + if err != nil { + logrus.Error(err) + return apierrors.ClientForbiddenError{ + Message: fmt.Sprintf("Error listing published stories: %v", err), + } + } + + // Get DB instance + db, err := database.GetDBFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + // Get group id from context + groupID, err := usergroups.GetGroupIDFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + stories, err := model.GetAllStoriesByStatus(db, groupID, model.Published) + if err != nil { + logrus.Error(err) + return err + } + + controller.EncodeJSONResponse(w, storyviews.ListFrom(stories)) + return nil +} + +func HandleListPending(w http.ResponseWriter, r *http.Request) error { + err := auth.CheckPermissions(r, storypermissiongroups.Moderate()) + if err != nil { + logrus.Error(err) + return apierrors.ClientForbiddenError{ + Message: fmt.Sprintf("Error listing stories to be reviewed: %v", err), + } + } + + // Get DB instance + db, err := database.GetDBFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + // Get group id from context + groupID, err := usergroups.GetGroupIDFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + stories, err := model.GetAllStoriesByStatus(db, groupID, model.Pending) + if err != nil { + logrus.Error(err) + return err + } + + controller.EncodeJSONResponse(w, storyviews.ListFrom(stories)) + return nil +} + +func HandleListDraft(w http.ResponseWriter, r *http.Request) error { + err := auth.CheckPermissions(r, storypermissiongroups.List()) + if err != nil { + logrus.Error(err) + return apierrors.ClientForbiddenError{ + Message: fmt.Sprintf("Error listing drafts: %v", err), + } + } + userID, err := auth.GetUserIDFrom(r) + if err != nil { + logrus.Error(err) + return err + } + // Get DB instance + db, err := database.GetDBFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + // Get group id from context + groupID, err := usergroups.GetGroupIDFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + stories, err := model.GetAllAuthorStoriesByStatus(db, groupID, userID, model.Draft) + if err != nil { + logrus.Error(err) + return err + } + + controller.EncodeJSONResponse(w, storyviews.ListFrom(stories)) + return nil +} + +func HandleListRejected(w http.ResponseWriter, r *http.Request) error { + err := auth.CheckPermissions(r, storypermissiongroups.List()) + if err != nil { + logrus.Error(err) + return apierrors.ClientForbiddenError{ + Message: fmt.Sprintf("Error listing rejected stories: %v", err), + } + } + userID, err := auth.GetUserIDFrom(r) + if err != nil { + logrus.Error(err) + return err + } + // Get DB instance + db, err := database.GetDBFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + // Get group id from context + groupID, err := usergroups.GetGroupIDFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + stories, err := model.GetAllAuthorStoriesByStatus(db, groupID, userID, model.Rejected) + if err != nil { + logrus.Error(err) + return err + } + + controller.EncodeJSONResponse(w, storyviews.ListFrom(stories)) + return nil +} + func HandleRead(w http.ResponseWriter, r *http.Request) error { storyIDStr := chi.URLParam(r, "storyID") storyID, err := strconv.Atoi(storyIDStr) diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 88987e3..44317d8 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -28,7 +28,10 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler { // Skip auth in development mode if conf.Environment == envutils.ENV_DEVELOPMENT { return func(next http.Handler) http.Handler { - return next + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = injectUserIDToContext(r, 1) + next.ServeHTTP(w, r) + }) } } diff --git a/internal/permissiongroups/stories/stories.go b/internal/permissiongroups/stories/stories.go index 06ddcbd..039f753 100644 --- a/internal/permissiongroups/stories/stories.go +++ b/internal/permissiongroups/stories/stories.go @@ -20,6 +20,11 @@ func Read() permissions.PermissionGroup { GetRolePermission(userpermissions.CanReadStories) } +func Moderate() permissions.PermissionGroup { + return userpermissions. + GetRolePermission(userpermissions.CanModerateStories) +} + func Update(storyID uint) permissions.PermissionGroup { return permissions.AnyOf{ Groups: []permissions.PermissionGroup{ diff --git a/internal/permissions/users/permissions.go b/internal/permissions/users/permissions.go index bc0a2f7..0aa325b 100644 --- a/internal/permissions/users/permissions.go +++ b/internal/permissions/users/permissions.go @@ -13,8 +13,9 @@ const ( CanUpdateGroups Permission = "can_update_groups" CanDeleteGroups Permission = "can_delete_groups" - CanCreateStories Permission = "can_create_stories" - CanReadStories Permission = "can_read_stories" - CanUpdateStories Permission = "can_update_stories" - CanDeleteStories Permission = "can_delete_stories" + CanCreateStories Permission = "can_create_stories" + CanReadStories Permission = "can_read_stories" + CanUpdateStories Permission = "can_update_stories" + CanDeleteStories Permission = "can_delete_stories" + CanModerateStories Permission = "can_moderate_stories" ) diff --git a/internal/permissions/users/users.go b/internal/permissions/users/users.go index a6e8b8b..e330ced 100644 --- a/internal/permissions/users/users.go +++ b/internal/permissions/users/users.go @@ -23,6 +23,7 @@ func GetRolePermission(p Permission) *RolePermission { case // Additional permissions for moderators and administrators CanUpdateStories, + CanModerateStories, CanDeleteStories: return &RolePermission{ Permission: p, diff --git a/internal/router/router.go b/internal/router/router.go index 9c70a40..5e9071d 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -55,6 +55,10 @@ func Setup(config *config.Config, injectMiddleWares []func(http.Handler) http.Ha r.Use(usergroups.InjectUserGroupIntoContext) r.Route("/stories", func(r chi.Router) { r.Get("/", handleAPIError(stories.HandleList)) + r.Get("/draft", handleAPIError(stories.HandleListDraft)) + r.Get("/pending", handleAPIError(stories.HandleListPending)) + r.Get("/published", handleAPIError(stories.HandleListPublished)) + r.Get("/rejected", handleAPIError(stories.HandleListRejected)) r.Get("/{storyID}", handleAPIError(stories.HandleRead)) r.Put("/{storyID}", handleAPIError(stories.HandleUpdate)) r.Delete("/{storyID}", handleAPIError(stories.HandleDelete)) @@ -66,7 +70,7 @@ func Setup(config *config.Config, injectMiddleWares []func(http.Handler) http.Ha r.Get("/{userID}", handleAPIError(users.HandleRead)) r.Delete("/{userID}", handleAPIError(users.HandleDelete)) r.Post("/", handleAPIError(users.HandleCreate)) - r.Post("/batch", handleAPIError(usergroupscontroller.HandleBatchCreate)) + r.Put("/batch", handleAPIError(usergroupscontroller.HandleBatchCreate)) }) }) diff --git a/migrations/20240320000000-add_status.sql b/migrations/20240320000000-add_status.sql new file mode 100644 index 0000000..8cf8060 --- /dev/null +++ b/migrations/20240320000000-add_status.sql @@ -0,0 +1,9 @@ +-- +migrate Up + +ALTER TABLE stories + ADD COLUMN status INT; + +-- +migrate Down + +ALTER TABLE stories + DROP COLUMN status; \ No newline at end of file diff --git a/migrations/20240320000001-add_status_message.sql b/migrations/20240320000001-add_status_message.sql new file mode 100644 index 0000000..821694e --- /dev/null +++ b/migrations/20240320000001-add_status_message.sql @@ -0,0 +1,9 @@ +-- +migrate Up + +ALTER TABLE stories + ADD COLUMN status_message TEXT; + +-- +migrate Down + +ALTER TABLE stories + DROP COLUMN status_message; \ No newline at end of file diff --git a/model/stories.go b/model/stories.go index 1f1dea6..bbeee82 100644 --- a/model/stories.go +++ b/model/stories.go @@ -2,19 +2,31 @@ package model import ( "github.com/source-academy/stories-backend/internal/database" + groupenums "github.com/source-academy/stories-backend/internal/enums/groups" "gorm.io/gorm" "gorm.io/gorm/clause" ) +type StoryStatus int + +const ( + Draft StoryStatus = iota + Pending + Rejected + Published +) + type Story struct { gorm.Model - AuthorID uint - Author User - GroupID *uint // null means this is a public story - Group Group - Title string - Content string - PinOrder *int // nil if not pinned + AuthorID uint + Author User + GroupID *uint // null means this is a public story + Group Group + Title string + Content string + PinOrder *int // nil if not pinned + Status StoryStatus + StatusMessage *string } // Passing nil to omit the filtering and get all stories @@ -35,6 +47,71 @@ func GetAllStoriesInGroup(db *gorm.DB, groupID *uint) ([]Story, error) { return stories, nil } +func GetAllPublishedStories(db *gorm.DB, groupID *uint) ([]Story, error) { + var stories []Story + err := db. + Where("status = ?", int(Published)). + Where("group_id = ?", groupID). + Preload(clause.Associations). + // TODO: Abstract out the sorting logic + Order("pin_order ASC NULLS LAST, title ASC, content ASC"). + Find(&stories). + Error + if err != nil { + return stories, database.HandleDBError(err, "story") + } + return stories, nil +} + +func GetAllPendingStories(db *gorm.DB, groupID *uint) ([]Story, error) { + var stories []Story + err := db. + Where("status = ?", int(Pending)). + Where("group_id = ?", groupID). + Preload(clause.Associations). + // TODO: Abstract out the sorting logic + Order("pin_order ASC NULLS LAST, title ASC, content ASC"). + Find(&stories). + Error + if err != nil { + return stories, database.HandleDBError(err, "story") + } + return stories, nil +} + +func GetAllStoriesByStatus(db *gorm.DB, groupID *uint, status StoryStatus) ([]Story, error) { + var stories []Story + err := db. + Where("status = ?", int(status)). + Where("group_id = ?", groupID). + Preload(clause.Associations). + // TODO: Abstract out the sorting logic + Order("pin_order ASC NULLS LAST, title ASC, content ASC"). + Find(&stories). + Error + if err != nil { + return stories, database.HandleDBError(err, "story") + } + return stories, nil +} + +func GetAllAuthorStoriesByStatus(db *gorm.DB, groupID *uint, userID *int, status StoryStatus) ([]Story, error) { + var stories []Story + err := db. + Where("status = ?", int(status)). + Where("group_id = ?", groupID). + Where("author_id = ?", userID). + Preload(clause.Associations). + // TODO: Abstract out the sorting logic + Order("pin_order ASC NULLS LAST, title ASC, content ASC"). + Find(&stories). + Error + if err != nil { + return stories, database.HandleDBError(err, "story") + } + return stories, nil +} + func GetStoryByID(db *gorm.DB, id int) (Story, error) { var story Story err := db. @@ -58,8 +135,21 @@ func (s *Story) create(tx *gorm.DB) *gorm.DB { } func CreateStory(db *gorm.DB, story *Story) error { + // Check author's role + role, _ := GetUserRoleByID(db, story.AuthorID) + // Based on the TestCreateStory, "can create without group" seems to be the desired behaviour + // No group means no userGroup, which means no role, so an error shouldn't be thrown + // Set story status based on author's role + if !groupenums.IsRoleGreaterThan(role, groupenums.RoleStandard) { + story.Status = Draft + } else { + story.Status = Published + } err := db.Transaction(func(tx *gorm.DB) error { - return story.create(tx).Error + if err := tx.Create(story).Error; err != nil { + return err // Return the error directly + } + return nil }) if err != nil { return database.HandleDBError(err, "story") diff --git a/model/usergroups.go b/model/usergroups.go index 3bab68e..d40a9ea 100644 --- a/model/usergroups.go +++ b/model/usergroups.go @@ -31,6 +31,20 @@ func GetUserGroupByID(db *gorm.DB, userID uint, groupID uint) (UserGroup, error) return userGroup, nil } +func GetUserRoleByID(db *gorm.DB, userID uint) (groupenums.Role, error) { + var userGroup UserGroup + + err := db.Model(&userGroup). + Where(UserGroup{UserID: userID}). + First(&userGroup).Error + + if err != nil { + return userGroup.Role, database.HandleDBError(err, "userRole") + } + + return userGroup.Role, nil +} + func CreateUserGroup(db *gorm.DB, userGroup *UserGroup) error { err := db.Create(userGroup).Error if err != nil { diff --git a/params/stories/create.go b/params/stories/create.go index 3e8bf80..cc101d5 100644 --- a/params/stories/create.go +++ b/params/stories/create.go @@ -1,6 +1,7 @@ package storyparams import ( + "fmt" "github.com/source-academy/stories-backend/model" ) @@ -13,6 +14,18 @@ type Create struct { // TODO: Add some validation func (params *Create) Validate() error { + if params.AuthorID == 0 { + return fmt.Errorf("authorId is required and must be non-zero") + } + if params.Title == "" { + return fmt.Errorf("title is required and cannot be empty") + } + if params.Content == "" { + return fmt.Errorf("content is required and cannot be empty") + } + if params.PinOrder != nil && *params.PinOrder < 0 { + return fmt.Errorf("pinOrder, if set, must be non-negative") + } return nil } diff --git a/params/stories/create_test.go b/params/stories/create_test.go index f100c4d..8975470 100644 --- a/params/stories/create_test.go +++ b/params/stories/create_test.go @@ -7,7 +7,27 @@ import ( ) func TestValidate(t *testing.T) { - t.Run("should do nothing for now", func(t *testing.T) {}) + negativePinOrder := -1 + tests := []struct { + name string + params Create + wantErr bool + }{ + {"valid input", Create{AuthorID: 1, Title: "Test Title", Content: "Test Content"}, false}, + {"missing authorId", Create{Title: "Test Title", Content: "Test Content"}, true}, + {"empty title", Create{AuthorID: 1, Content: "Test Content"}, true}, + {"empty content", Create{AuthorID: 1, Title: "Test Title"}, true}, + {"negative pinOrder", Create{AuthorID: 1, Title: "Test Title", Content: "Test Content", PinOrder: &negativePinOrder}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.params.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } } func TestToModel(t *testing.T) { diff --git a/params/stories/update.go b/params/stories/update.go index 8eb14fe..be49751 100644 --- a/params/stories/update.go +++ b/params/stories/update.go @@ -5,9 +5,11 @@ import ( ) type Update struct { - Title string `json:"title"` - Content string `json:"content"` - PinOrder *int `json:"pinOrder"` + Title string `json:"title"` + Content string `json:"content"` + PinOrder *int `json:"pinOrder"` + Status int `json:"status"` + StatusMessage string `json:"statusMessage"` } func (params *Update) Validate() error { @@ -18,8 +20,10 @@ func (params *Update) Validate() error { func (params *Update) ToModel() *model.Story { return &model.Story{ - Title: params.Title, - Content: params.Content, - PinOrder: params.PinOrder, + Title: params.Title, + Content: params.Content, + PinOrder: params.PinOrder, + Status: model.StoryStatus(params.Status), + StatusMessage: ¶ms.StatusMessage, } }