From 13312eaa0cd0f4730660e024a66dc08685a7294b Mon Sep 17 00:00:00 2001 From: Shaddoll Date: Wed, 30 Oct 2024 17:57:25 -0700 Subject: [PATCH] Improve coverage for list workflow handlers (#6451) --- service/frontend/api/handler.go | 464 ------------------ .../frontend/api/list_workflow_handlers.go | 367 ++++++++++++++ service/frontend/api/request_validator.go | 127 ++++- .../frontend/api/request_validator_mock.go | 70 +++ .../frontend/api/request_validator_test.go | 425 +++++++++++++++- 5 files changed, 987 insertions(+), 466 deletions(-) create mode 100644 service/frontend/api/list_workflow_handlers.go diff --git a/service/frontend/api/handler.go b/service/frontend/api/handler.go index 908f6282dc8..577f2a6c802 100644 --- a/service/frontend/api/handler.go +++ b/service/frontend/api/handler.go @@ -2703,371 +2703,6 @@ func (wh *WorkflowHandler) RequestCancelWorkflowExecution( return nil } -// ListOpenWorkflowExecutions - retrieves info for open workflow executions in a domain -func (wh *WorkflowHandler) ListOpenWorkflowExecutions( - ctx context.Context, - listRequest *types.ListOpenWorkflowExecutionsRequest, -) (resp *types.ListOpenWorkflowExecutionsResponse, retError error) { - if wh.isShuttingDown() { - return nil, validate.ErrShuttingDown - } - - if listRequest == nil { - return nil, validate.ErrRequestNotSet - } - - if listRequest.GetDomain() == "" { - return nil, validate.ErrDomainNotSet - } - - if listRequest.StartTimeFilter == nil { - return nil, &types.BadRequestError{Message: "StartTimeFilter is required"} - } - - if listRequest.StartTimeFilter.EarliestTime == nil { - return nil, &types.BadRequestError{Message: "EarliestTime in StartTimeFilter is required"} - } - - if listRequest.StartTimeFilter.LatestTime == nil { - return nil, &types.BadRequestError{Message: "LatestTime in StartTimeFilter is required"} - } - - if listRequest.StartTimeFilter.GetEarliestTime() > listRequest.StartTimeFilter.GetLatestTime() { - return nil, &types.BadRequestError{Message: "EarliestTime in StartTimeFilter should not be larger than LatestTime"} - } - - if listRequest.ExecutionFilter != nil && listRequest.TypeFilter != nil { - return nil, &types.BadRequestError{ - Message: "Only one of ExecutionFilter or TypeFilter is allowed"} - } - - if listRequest.GetMaximumPageSize() <= 0 { - listRequest.MaximumPageSize = int32(wh.config.VisibilityMaxPageSize(listRequest.GetDomain())) - } - - if wh.isListRequestPageSizeTooLarge(listRequest.GetMaximumPageSize(), listRequest.GetDomain()) { - return nil, &types.BadRequestError{ - Message: fmt.Sprintf("Pagesize is larger than allow %d", wh.config.ESIndexMaxResultWindow())} - } - - domain := listRequest.GetDomain() - domainID, err := wh.GetDomainCache().GetDomainID(domain) - if err != nil { - return nil, err - } - - baseReq := persistence.ListWorkflowExecutionsRequest{ - DomainUUID: domainID, - Domain: domain, - PageSize: int(listRequest.GetMaximumPageSize()), - NextPageToken: listRequest.NextPageToken, - EarliestTime: listRequest.StartTimeFilter.GetEarliestTime(), - LatestTime: listRequest.StartTimeFilter.GetLatestTime(), - } - - var persistenceResp *persistence.ListWorkflowExecutionsResponse - if listRequest.ExecutionFilter != nil { - if wh.config.DisableListVisibilityByFilter(domain) { - err = validate.ErrNoPermission - } else { - persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutionsByWorkflowID( - ctx, - &persistence.ListWorkflowExecutionsByWorkflowIDRequest{ - ListWorkflowExecutionsRequest: baseReq, - WorkflowID: listRequest.ExecutionFilter.GetWorkflowID(), - }) - } - wh.GetLogger().Debug("List open workflow with filter", - tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByID) - } else if listRequest.TypeFilter != nil { - if wh.config.DisableListVisibilityByFilter(domain) { - err = validate.ErrNoPermission - } else { - persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutionsByType( - ctx, - &persistence.ListWorkflowExecutionsByTypeRequest{ - ListWorkflowExecutionsRequest: baseReq, - WorkflowTypeName: listRequest.TypeFilter.GetName(), - }, - ) - } - wh.GetLogger().Debug("List open workflow with filter", - tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByType) - } else { - persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutions(ctx, &baseReq) - } - - if err != nil { - return nil, err - } - - resp = &types.ListOpenWorkflowExecutionsResponse{} - resp.Executions = persistenceResp.Executions - resp.NextPageToken = persistenceResp.NextPageToken - return resp, nil -} - -// ListArchivedWorkflowExecutions - retrieves archived info for closed workflow executions in a domain -func (wh *WorkflowHandler) ListArchivedWorkflowExecutions( - ctx context.Context, - listRequest *types.ListArchivedWorkflowExecutionsRequest, -) (resp *types.ListArchivedWorkflowExecutionsResponse, retError error) { - if wh.isShuttingDown() { - return nil, validate.ErrShuttingDown - } - - if listRequest == nil { - return nil, validate.ErrRequestNotSet - } - - if listRequest.GetDomain() == "" { - return nil, validate.ErrDomainNotSet - } - - if listRequest.GetPageSize() <= 0 { - listRequest.PageSize = int32(wh.config.VisibilityMaxPageSize(listRequest.GetDomain())) - } - - maxPageSize := wh.config.VisibilityArchivalQueryMaxPageSize() - if int(listRequest.GetPageSize()) > maxPageSize { - return nil, &types.BadRequestError{ - Message: fmt.Sprintf("Pagesize is larger than allowed %d", maxPageSize)} - } - - if !wh.GetArchivalMetadata().GetVisibilityConfig().ClusterConfiguredForArchival() { - return nil, &types.BadRequestError{Message: "Cluster is not configured for visibility archival"} - } - - if !wh.GetArchivalMetadata().GetVisibilityConfig().ReadEnabled() { - return nil, &types.BadRequestError{Message: "Cluster is not configured for reading archived visibility records"} - } - - entry, err := wh.GetDomainCache().GetDomain(listRequest.GetDomain()) - if err != nil { - return nil, err - } - - if entry.GetConfig().VisibilityArchivalStatus != types.ArchivalStatusEnabled { - return nil, &types.BadRequestError{Message: "Domain is not configured for visibility archival"} - } - - URI, err := archiver.NewURI(entry.GetConfig().VisibilityArchivalURI) - if err != nil { - return nil, err - } - - visibilityArchiver, err := wh.GetArchiverProvider().GetVisibilityArchiver(URI.Scheme(), service.Frontend) - if err != nil { - return nil, err - } - - archiverRequest := &archiver.QueryVisibilityRequest{ - DomainID: entry.GetInfo().ID, - PageSize: int(listRequest.GetPageSize()), - NextPageToken: listRequest.NextPageToken, - Query: listRequest.GetQuery(), - } - - archiverResponse, err := visibilityArchiver.Query(ctx, URI, archiverRequest) - if err != nil { - return nil, err - } - - // special handling of ExecutionTime for cron or retry - for _, execution := range archiverResponse.Executions { - if execution.GetExecutionTime() == 0 { - execution.ExecutionTime = common.Int64Ptr(execution.GetStartTime()) - } - } - - return &types.ListArchivedWorkflowExecutionsResponse{ - Executions: archiverResponse.Executions, - NextPageToken: archiverResponse.NextPageToken, - }, nil -} - -// ListClosedWorkflowExecutions - retrieves info for closed workflow executions in a domain -func (wh *WorkflowHandler) ListClosedWorkflowExecutions( - ctx context.Context, - listRequest *types.ListClosedWorkflowExecutionsRequest, -) (resp *types.ListClosedWorkflowExecutionsResponse, retError error) { - if wh.isShuttingDown() { - return nil, validate.ErrShuttingDown - } - - if listRequest == nil { - return nil, validate.ErrRequestNotSet - } - - if listRequest.GetDomain() == "" { - return nil, validate.ErrDomainNotSet - } - - if listRequest.StartTimeFilter == nil { - return nil, &types.BadRequestError{Message: "StartTimeFilter is required"} - } - - if listRequest.StartTimeFilter.EarliestTime == nil { - return nil, &types.BadRequestError{Message: "EarliestTime in StartTimeFilter is required"} - } - - if listRequest.StartTimeFilter.LatestTime == nil { - return nil, &types.BadRequestError{Message: "LatestTime in StartTimeFilter is required"} - } - - if listRequest.StartTimeFilter.GetEarliestTime() > listRequest.StartTimeFilter.GetLatestTime() { - return nil, &types.BadRequestError{Message: "EarliestTime in StartTimeFilter should not be larger than LatestTime"} - } - - filterCount := 0 - if listRequest.TypeFilter != nil { - filterCount++ - } - if listRequest.StatusFilter != nil { - filterCount++ - } - - if filterCount > 1 { - return nil, &types.BadRequestError{ - Message: "Only one of ExecutionFilter, TypeFilter or StatusFilter is allowed"} - } // If ExecutionFilter is provided with one of TypeFilter or StatusFilter, use ExecutionFilter and ignore other filter - - if listRequest.GetMaximumPageSize() <= 0 { - listRequest.MaximumPageSize = int32(wh.config.VisibilityMaxPageSize(listRequest.GetDomain())) - } - - if wh.isListRequestPageSizeTooLarge(listRequest.GetMaximumPageSize(), listRequest.GetDomain()) { - return nil, &types.BadRequestError{ - Message: fmt.Sprintf("Pagesize is larger than allow %d", wh.config.ESIndexMaxResultWindow())} - } - - domain := listRequest.GetDomain() - domainID, err := wh.GetDomainCache().GetDomainID(domain) - if err != nil { - return nil, err - } - - baseReq := persistence.ListWorkflowExecutionsRequest{ - DomainUUID: domainID, - Domain: domain, - PageSize: int(listRequest.GetMaximumPageSize()), - NextPageToken: listRequest.NextPageToken, - EarliestTime: listRequest.StartTimeFilter.GetEarliestTime(), - LatestTime: listRequest.StartTimeFilter.GetLatestTime(), - } - - var persistenceResp *persistence.ListWorkflowExecutionsResponse - if listRequest.ExecutionFilter != nil { - if wh.config.DisableListVisibilityByFilter(domain) { - err = validate.ErrNoPermission - } else { - persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByWorkflowID( - ctx, - &persistence.ListWorkflowExecutionsByWorkflowIDRequest{ - ListWorkflowExecutionsRequest: baseReq, - WorkflowID: listRequest.ExecutionFilter.GetWorkflowID(), - }, - ) - } - wh.GetLogger().Debug("List closed workflow with filter", - tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByID) - } else if listRequest.TypeFilter != nil { - if wh.config.DisableListVisibilityByFilter(domain) { - err = validate.ErrNoPermission - } else { - persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByType( - ctx, - &persistence.ListWorkflowExecutionsByTypeRequest{ - ListWorkflowExecutionsRequest: baseReq, - WorkflowTypeName: listRequest.TypeFilter.GetName(), - }, - ) - } - wh.GetLogger().Debug("List closed workflow with filter", - tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByType) - } else if listRequest.StatusFilter != nil { - if wh.config.DisableListVisibilityByFilter(domain) { - err = validate.ErrNoPermission - } else { - persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByStatus( - ctx, - &persistence.ListClosedWorkflowExecutionsByStatusRequest{ - ListWorkflowExecutionsRequest: baseReq, - Status: listRequest.GetStatusFilter(), - }, - ) - } - wh.GetLogger().Debug("List closed workflow with filter", - tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByStatus) - } else { - persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutions(ctx, &baseReq) - } - - if err != nil { - return nil, err - } - - resp = &types.ListClosedWorkflowExecutionsResponse{} - resp.Executions = persistenceResp.Executions - resp.NextPageToken = persistenceResp.NextPageToken - return resp, nil -} - -// ListWorkflowExecutions - retrieves info for workflow executions in a domain -func (wh *WorkflowHandler) ListWorkflowExecutions( - ctx context.Context, - listRequest *types.ListWorkflowExecutionsRequest, -) (resp *types.ListWorkflowExecutionsResponse, retError error) { - if wh.isShuttingDown() { - return nil, validate.ErrShuttingDown - } - - if listRequest == nil { - return nil, validate.ErrRequestNotSet - } - - if listRequest.GetDomain() == "" { - return nil, validate.ErrDomainNotSet - } - - if listRequest.GetPageSize() <= 0 { - listRequest.PageSize = int32(wh.config.VisibilityMaxPageSize(listRequest.GetDomain())) - } - - if wh.isListRequestPageSizeTooLarge(listRequest.GetPageSize(), listRequest.GetDomain()) { - return nil, &types.BadRequestError{ - Message: fmt.Sprintf("Pagesize is larger than allow %d", wh.config.ESIndexMaxResultWindow())} - } - - validatedQuery, err := wh.visibilityQueryValidator.ValidateQuery(listRequest.GetQuery()) - if err != nil { - return nil, err - } - - domain := listRequest.GetDomain() - domainID, err := wh.GetDomainCache().GetDomainID(domain) - if err != nil { - return nil, err - } - - req := &persistence.ListWorkflowExecutionsByQueryRequest{ - DomainUUID: domainID, - Domain: domain, - PageSize: int(listRequest.GetPageSize()), - NextPageToken: listRequest.NextPageToken, - Query: validatedQuery, - } - persistenceResp, err := wh.GetVisibilityManager().ListWorkflowExecutions(ctx, req) - if err != nil { - return nil, err - } - - resp = &types.ListWorkflowExecutionsResponse{} - resp.Executions = persistenceResp.Executions - resp.NextPageToken = persistenceResp.NextPageToken - return resp, nil -} - // RestartWorkflowExecution - retrieves info for an existing workflow then restarts it func (wh *WorkflowHandler) RestartWorkflowExecution(ctx context.Context, request *types.RestartWorkflowExecutionRequest) (resp *types.RestartWorkflowExecutionResponse, retError error) { if wh.isShuttingDown() { @@ -3127,105 +2762,6 @@ func (wh *WorkflowHandler) RestartWorkflowExecution(ctx context.Context, request return resp, nil } -// ScanWorkflowExecutions - retrieves info for large amount of workflow executions in a domain without order -func (wh *WorkflowHandler) ScanWorkflowExecutions( - ctx context.Context, - listRequest *types.ListWorkflowExecutionsRequest, -) (resp *types.ListWorkflowExecutionsResponse, retError error) { - if wh.isShuttingDown() { - return nil, validate.ErrShuttingDown - } - - if listRequest == nil { - return nil, validate.ErrRequestNotSet - } - - if listRequest.GetDomain() == "" { - return nil, validate.ErrDomainNotSet - } - - if listRequest.GetPageSize() <= 0 { - listRequest.PageSize = int32(wh.config.VisibilityMaxPageSize(listRequest.GetDomain())) - } - - if wh.isListRequestPageSizeTooLarge(listRequest.GetPageSize(), listRequest.GetDomain()) { - return nil, &types.BadRequestError{ - Message: fmt.Sprintf("Pagesize is larger than allow %d", wh.config.ESIndexMaxResultWindow())} - } - - validatedQuery, err := wh.visibilityQueryValidator.ValidateQuery(listRequest.GetQuery()) - if err != nil { - return nil, err - } - - domain := listRequest.GetDomain() - domainID, err := wh.GetDomainCache().GetDomainID(domain) - if err != nil { - return nil, err - } - - req := &persistence.ListWorkflowExecutionsByQueryRequest{ - DomainUUID: domainID, - Domain: domain, - PageSize: int(listRequest.GetPageSize()), - NextPageToken: listRequest.NextPageToken, - Query: validatedQuery, - } - persistenceResp, err := wh.GetVisibilityManager().ScanWorkflowExecutions(ctx, req) - if err != nil { - return nil, err - } - - resp = &types.ListWorkflowExecutionsResponse{} - resp.Executions = persistenceResp.Executions - resp.NextPageToken = persistenceResp.NextPageToken - return resp, nil -} - -// CountWorkflowExecutions - count number of workflow executions in a domain -func (wh *WorkflowHandler) CountWorkflowExecutions( - ctx context.Context, - countRequest *types.CountWorkflowExecutionsRequest, -) (resp *types.CountWorkflowExecutionsResponse, retError error) { - if wh.isShuttingDown() { - return nil, validate.ErrShuttingDown - } - - if countRequest == nil { - return nil, validate.ErrRequestNotSet - } - - if countRequest.GetDomain() == "" { - return nil, validate.ErrDomainNotSet - } - - validatedQuery, err := wh.visibilityQueryValidator.ValidateQuery(countRequest.GetQuery()) - if err != nil { - return nil, err - } - - domain := countRequest.GetDomain() - domainID, err := wh.GetDomainCache().GetDomainID(domain) - if err != nil { - return nil, err - } - - req := &persistence.CountWorkflowExecutionsRequest{ - DomainUUID: domainID, - Domain: domain, - Query: validatedQuery, - } - persistenceResp, err := wh.GetVisibilityManager().CountWorkflowExecutions(ctx, req) - if err != nil { - return nil, err - } - - resp = &types.CountWorkflowExecutionsResponse{ - Count: persistenceResp.Count, - } - return resp, nil -} - // GetSearchAttributes return valid indexed keys func (wh *WorkflowHandler) GetSearchAttributes(ctx context.Context) (resp *types.GetSearchAttributesResponse, retError error) { if wh.isShuttingDown() { diff --git a/service/frontend/api/list_workflow_handlers.go b/service/frontend/api/list_workflow_handlers.go new file mode 100644 index 00000000000..987d19feeea --- /dev/null +++ b/service/frontend/api/list_workflow_handlers.go @@ -0,0 +1,367 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package api + +import ( + "context" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/archiver" + "github.com/uber/cadence/common/log/tag" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/service" + "github.com/uber/cadence/common/types" + "github.com/uber/cadence/service/frontend/validate" +) + +// CountWorkflowExecutions - count number of workflow executions in a domain +func (wh *WorkflowHandler) CountWorkflowExecutions( + ctx context.Context, + countRequest *types.CountWorkflowExecutionsRequest, +) (resp *types.CountWorkflowExecutionsResponse, retError error) { + if wh.isShuttingDown() { + return nil, validate.ErrShuttingDown + } + if err := wh.requestValidator.ValidateCountWorkflowExecutionsRequest(ctx, countRequest); err != nil { + return nil, err + } + validatedQuery, err := wh.visibilityQueryValidator.ValidateQuery(countRequest.GetQuery()) + if err != nil { + return nil, err + } + + domain := countRequest.GetDomain() + domainID, err := wh.GetDomainCache().GetDomainID(domain) + if err != nil { + return nil, err + } + + req := &persistence.CountWorkflowExecutionsRequest{ + DomainUUID: domainID, + Domain: domain, + Query: validatedQuery, + } + persistenceResp, err := wh.GetVisibilityManager().CountWorkflowExecutions(ctx, req) + if err != nil { + return nil, err + } + + resp = &types.CountWorkflowExecutionsResponse{ + Count: persistenceResp.Count, + } + return resp, nil +} + +// ScanWorkflowExecutions - retrieves info for large amount of workflow executions in a domain without order +func (wh *WorkflowHandler) ScanWorkflowExecutions( + ctx context.Context, + listRequest *types.ListWorkflowExecutionsRequest, +) (resp *types.ListWorkflowExecutionsResponse, retError error) { + if wh.isShuttingDown() { + return nil, validate.ErrShuttingDown + } + if err := wh.requestValidator.ValidateListWorkflowExecutionsRequest(ctx, listRequest); err != nil { + return nil, err + } + validatedQuery, err := wh.visibilityQueryValidator.ValidateQuery(listRequest.GetQuery()) + if err != nil { + return nil, err + } + + domain := listRequest.GetDomain() + domainID, err := wh.GetDomainCache().GetDomainID(domain) + if err != nil { + return nil, err + } + + req := &persistence.ListWorkflowExecutionsByQueryRequest{ + DomainUUID: domainID, + Domain: domain, + PageSize: int(listRequest.GetPageSize()), + NextPageToken: listRequest.NextPageToken, + Query: validatedQuery, + } + persistenceResp, err := wh.GetVisibilityManager().ScanWorkflowExecutions(ctx, req) + if err != nil { + return nil, err + } + + resp = &types.ListWorkflowExecutionsResponse{} + resp.Executions = persistenceResp.Executions + resp.NextPageToken = persistenceResp.NextPageToken + return resp, nil +} + +// ListOpenWorkflowExecutions - retrieves info for open workflow executions in a domain +func (wh *WorkflowHandler) ListOpenWorkflowExecutions( + ctx context.Context, + listRequest *types.ListOpenWorkflowExecutionsRequest, +) (resp *types.ListOpenWorkflowExecutionsResponse, retError error) { + if wh.isShuttingDown() { + return nil, validate.ErrShuttingDown + } + if err := wh.requestValidator.ValidateListOpenWorkflowExecutionsRequest(ctx, listRequest); err != nil { + return nil, err + } + domain := listRequest.GetDomain() + domainID, err := wh.GetDomainCache().GetDomainID(domain) + if err != nil { + return nil, err + } + + baseReq := persistence.ListWorkflowExecutionsRequest{ + DomainUUID: domainID, + Domain: domain, + PageSize: int(listRequest.GetMaximumPageSize()), + NextPageToken: listRequest.NextPageToken, + EarliestTime: listRequest.StartTimeFilter.GetEarliestTime(), + LatestTime: listRequest.StartTimeFilter.GetLatestTime(), + } + + var persistenceResp *persistence.ListWorkflowExecutionsResponse + if listRequest.ExecutionFilter != nil { + if wh.config.DisableListVisibilityByFilter(domain) { + err = validate.ErrNoPermission + } else { + persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutionsByWorkflowID( + ctx, + &persistence.ListWorkflowExecutionsByWorkflowIDRequest{ + ListWorkflowExecutionsRequest: baseReq, + WorkflowID: listRequest.ExecutionFilter.GetWorkflowID(), + }) + } + wh.GetLogger().Debug("List open workflow with filter", + tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByID) + } else if listRequest.TypeFilter != nil { + if wh.config.DisableListVisibilityByFilter(domain) { + err = validate.ErrNoPermission + } else { + persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutionsByType( + ctx, + &persistence.ListWorkflowExecutionsByTypeRequest{ + ListWorkflowExecutionsRequest: baseReq, + WorkflowTypeName: listRequest.TypeFilter.GetName(), + }, + ) + } + wh.GetLogger().Debug("List open workflow with filter", + tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByType) + } else { + persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutions(ctx, &baseReq) + } + + if err != nil { + return nil, err + } + + resp = &types.ListOpenWorkflowExecutionsResponse{} + resp.Executions = persistenceResp.Executions + resp.NextPageToken = persistenceResp.NextPageToken + return resp, nil +} + +// ListArchivedWorkflowExecutions - retrieves archived info for closed workflow executions in a domain +func (wh *WorkflowHandler) ListArchivedWorkflowExecutions( + ctx context.Context, + listRequest *types.ListArchivedWorkflowExecutionsRequest, +) (resp *types.ListArchivedWorkflowExecutionsResponse, retError error) { + if wh.isShuttingDown() { + return nil, validate.ErrShuttingDown + } + if err := wh.requestValidator.ValidateListArchivedWorkflowExecutionsRequest(ctx, listRequest); err != nil { + return nil, err + } + if !wh.GetArchivalMetadata().GetVisibilityConfig().ClusterConfiguredForArchival() { + return nil, &types.BadRequestError{Message: "Cluster is not configured for visibility archival"} + } + + if !wh.GetArchivalMetadata().GetVisibilityConfig().ReadEnabled() { + return nil, &types.BadRequestError{Message: "Cluster is not configured for reading archived visibility records"} + } + + entry, err := wh.GetDomainCache().GetDomain(listRequest.GetDomain()) + if err != nil { + return nil, err + } + + if entry.GetConfig().VisibilityArchivalStatus != types.ArchivalStatusEnabled { + return nil, &types.BadRequestError{Message: "Domain is not configured for visibility archival"} + } + + URI, err := archiver.NewURI(entry.GetConfig().VisibilityArchivalURI) + if err != nil { + return nil, err + } + + visibilityArchiver, err := wh.GetArchiverProvider().GetVisibilityArchiver(URI.Scheme(), service.Frontend) + if err != nil { + return nil, err + } + + archiverRequest := &archiver.QueryVisibilityRequest{ + DomainID: entry.GetInfo().ID, + PageSize: int(listRequest.GetPageSize()), + NextPageToken: listRequest.NextPageToken, + Query: listRequest.GetQuery(), + } + + archiverResponse, err := visibilityArchiver.Query(ctx, URI, archiverRequest) + if err != nil { + return nil, err + } + + // special handling of ExecutionTime for cron or retry + for _, execution := range archiverResponse.Executions { + if execution.GetExecutionTime() == 0 { + execution.ExecutionTime = common.Int64Ptr(execution.GetStartTime()) + } + } + + return &types.ListArchivedWorkflowExecutionsResponse{ + Executions: archiverResponse.Executions, + NextPageToken: archiverResponse.NextPageToken, + }, nil +} + +// ListClosedWorkflowExecutions - retrieves info for closed workflow executions in a domain +func (wh *WorkflowHandler) ListClosedWorkflowExecutions( + ctx context.Context, + listRequest *types.ListClosedWorkflowExecutionsRequest, +) (resp *types.ListClosedWorkflowExecutionsResponse, retError error) { + if wh.isShuttingDown() { + return nil, validate.ErrShuttingDown + } + if err := wh.requestValidator.ValidateListClosedWorkflowExecutionsRequest(ctx, listRequest); err != nil { + return nil, err + } + domain := listRequest.GetDomain() + domainID, err := wh.GetDomainCache().GetDomainID(domain) + if err != nil { + return nil, err + } + + baseReq := persistence.ListWorkflowExecutionsRequest{ + DomainUUID: domainID, + Domain: domain, + PageSize: int(listRequest.GetMaximumPageSize()), + NextPageToken: listRequest.NextPageToken, + EarliestTime: listRequest.StartTimeFilter.GetEarliestTime(), + LatestTime: listRequest.StartTimeFilter.GetLatestTime(), + } + + var persistenceResp *persistence.ListWorkflowExecutionsResponse + if listRequest.ExecutionFilter != nil { + if wh.config.DisableListVisibilityByFilter(domain) { + err = validate.ErrNoPermission + } else { + persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByWorkflowID( + ctx, + &persistence.ListWorkflowExecutionsByWorkflowIDRequest{ + ListWorkflowExecutionsRequest: baseReq, + WorkflowID: listRequest.ExecutionFilter.GetWorkflowID(), + }, + ) + } + wh.GetLogger().Debug("List closed workflow with filter", + tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByID) + } else if listRequest.TypeFilter != nil { + if wh.config.DisableListVisibilityByFilter(domain) { + err = validate.ErrNoPermission + } else { + persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByType( + ctx, + &persistence.ListWorkflowExecutionsByTypeRequest{ + ListWorkflowExecutionsRequest: baseReq, + WorkflowTypeName: listRequest.TypeFilter.GetName(), + }, + ) + } + wh.GetLogger().Debug("List closed workflow with filter", + tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByType) + } else if listRequest.StatusFilter != nil { + if wh.config.DisableListVisibilityByFilter(domain) { + err = validate.ErrNoPermission + } else { + persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByStatus( + ctx, + &persistence.ListClosedWorkflowExecutionsByStatusRequest{ + ListWorkflowExecutionsRequest: baseReq, + Status: listRequest.GetStatusFilter(), + }, + ) + } + wh.GetLogger().Debug("List closed workflow with filter", + tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByStatus) + } else { + persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutions(ctx, &baseReq) + } + + if err != nil { + return nil, err + } + + resp = &types.ListClosedWorkflowExecutionsResponse{} + resp.Executions = persistenceResp.Executions + resp.NextPageToken = persistenceResp.NextPageToken + return resp, nil +} + +// ListWorkflowExecutions - retrieves info for workflow executions in a domain +func (wh *WorkflowHandler) ListWorkflowExecutions( + ctx context.Context, + listRequest *types.ListWorkflowExecutionsRequest, +) (resp *types.ListWorkflowExecutionsResponse, retError error) { + if wh.isShuttingDown() { + return nil, validate.ErrShuttingDown + } + if err := wh.requestValidator.ValidateListWorkflowExecutionsRequest(ctx, listRequest); err != nil { + return nil, err + } + validatedQuery, err := wh.visibilityQueryValidator.ValidateQuery(listRequest.GetQuery()) + if err != nil { + return nil, err + } + + domain := listRequest.GetDomain() + domainID, err := wh.GetDomainCache().GetDomainID(domain) + if err != nil { + return nil, err + } + + req := &persistence.ListWorkflowExecutionsByQueryRequest{ + DomainUUID: domainID, + Domain: domain, + PageSize: int(listRequest.GetPageSize()), + NextPageToken: listRequest.NextPageToken, + Query: validatedQuery, + } + persistenceResp, err := wh.GetVisibilityManager().ListWorkflowExecutions(ctx, req) + if err != nil { + return nil, err + } + + resp = &types.ListWorkflowExecutionsResponse{} + resp.Executions = persistenceResp.Executions + resp.NextPageToken = persistenceResp.NextPageToken + return resp, nil +} diff --git a/service/frontend/api/request_validator.go b/service/frontend/api/request_validator.go index 4ad92e2f09e..364c9e659c9 100644 --- a/service/frontend/api/request_validator.go +++ b/service/frontend/api/request_validator.go @@ -26,6 +26,7 @@ package api import ( "context" + "fmt" "github.com/uber/cadence/common" "github.com/uber/cadence/common/log" @@ -43,6 +44,11 @@ type ( ValidateListTaskListPartitionsRequest(context.Context, *types.ListTaskListPartitionsRequest) error ValidateGetTaskListsByDomainRequest(context.Context, *types.GetTaskListsByDomainRequest) error ValidateResetStickyTaskListRequest(context.Context, *types.ResetStickyTaskListRequest) error + ValidateCountWorkflowExecutionsRequest(context.Context, *types.CountWorkflowExecutionsRequest) error + ValidateListWorkflowExecutionsRequest(context.Context, *types.ListWorkflowExecutionsRequest) error + ValidateListOpenWorkflowExecutionsRequest(context.Context, *types.ListOpenWorkflowExecutionsRequest) error + ValidateListArchivedWorkflowExecutionsRequest(context.Context, *types.ListArchivedWorkflowExecutionsRequest) error + ValidateListClosedWorkflowExecutionsRequest(context.Context, *types.ListClosedWorkflowExecutionsRequest) error } requestValidatorImpl struct { @@ -64,7 +70,6 @@ func (v *requestValidatorImpl) validateTaskList(t *types.TaskList, scope metrics if t == nil || t.GetName() == "" { return validate.ErrTaskListNotSet } - if !common.IsValidIDLength( t.GetName(), scope, @@ -79,6 +84,11 @@ func (v *requestValidatorImpl) validateTaskList(t *types.TaskList, scope metrics return nil } +func (v *requestValidatorImpl) isListRequestPageSizeTooLarge(pageSize int32, domain string) bool { + return common.IsAdvancedVisibilityReadingEnabled(v.config.EnableReadVisibilityFromES(domain), v.config.IsAdvancedVisConfigExist) && + pageSize > int32(v.config.ESIndexMaxResultWindow()) +} + func (v *requestValidatorImpl) ValidateRefreshWorkflowTasksRequest(ctx context.Context, req *types.RefreshWorkflowTasksRequest) error { if req == nil { return validate.ErrRequestNotSet @@ -132,3 +142,118 @@ func (v *requestValidatorImpl) ValidateResetStickyTaskListRequest(ctx context.Co wfExecution := resetRequest.GetExecution() return validate.CheckExecution(wfExecution) } + +func (v *requestValidatorImpl) ValidateCountWorkflowExecutionsRequest(ctx context.Context, countRequest *types.CountWorkflowExecutionsRequest) error { + if countRequest == nil { + return validate.ErrRequestNotSet + } + if countRequest.GetDomain() == "" { + return validate.ErrDomainNotSet + } + return nil +} + +func (v *requestValidatorImpl) ValidateListWorkflowExecutionsRequest(ctx context.Context, listRequest *types.ListWorkflowExecutionsRequest) error { + if listRequest == nil { + return validate.ErrRequestNotSet + } + if listRequest.GetDomain() == "" { + return validate.ErrDomainNotSet + } + if listRequest.GetPageSize() <= 0 { + listRequest.PageSize = int32(v.config.VisibilityMaxPageSize(listRequest.GetDomain())) + } + if v.isListRequestPageSizeTooLarge(listRequest.GetPageSize(), listRequest.GetDomain()) { + return &types.BadRequestError{Message: fmt.Sprintf("Pagesize is larger than allow %d", v.config.ESIndexMaxResultWindow())} + } + return nil +} + +func (v *requestValidatorImpl) ValidateListOpenWorkflowExecutionsRequest(ctx context.Context, listRequest *types.ListOpenWorkflowExecutionsRequest) error { + if listRequest == nil { + return validate.ErrRequestNotSet + } + if listRequest.GetDomain() == "" { + return validate.ErrDomainNotSet + } + if listRequest.StartTimeFilter == nil { + return &types.BadRequestError{Message: "StartTimeFilter is required"} + } + if listRequest.StartTimeFilter.EarliestTime == nil { + return &types.BadRequestError{Message: "EarliestTime in StartTimeFilter is required"} + } + if listRequest.StartTimeFilter.LatestTime == nil { + return &types.BadRequestError{Message: "LatestTime in StartTimeFilter is required"} + } + if listRequest.StartTimeFilter.GetEarliestTime() > listRequest.StartTimeFilter.GetLatestTime() { + return &types.BadRequestError{Message: "EarliestTime in StartTimeFilter should not be larger than LatestTime"} + } + if listRequest.ExecutionFilter != nil && listRequest.TypeFilter != nil { + return &types.BadRequestError{Message: "Only one of ExecutionFilter or TypeFilter is allowed"} + } + if listRequest.GetMaximumPageSize() <= 0 { + listRequest.MaximumPageSize = int32(v.config.VisibilityMaxPageSize(listRequest.GetDomain())) + } + if v.isListRequestPageSizeTooLarge(listRequest.GetMaximumPageSize(), listRequest.GetDomain()) { + return &types.BadRequestError{Message: fmt.Sprintf("Pagesize is larger than allow %d", v.config.ESIndexMaxResultWindow())} + } + return nil +} + +func (v *requestValidatorImpl) ValidateListArchivedWorkflowExecutionsRequest(ctx context.Context, listRequest *types.ListArchivedWorkflowExecutionsRequest) error { + if listRequest == nil { + return validate.ErrRequestNotSet + } + if listRequest.GetDomain() == "" { + return validate.ErrDomainNotSet + } + if listRequest.GetPageSize() <= 0 { + listRequest.PageSize = int32(v.config.VisibilityMaxPageSize(listRequest.GetDomain())) + } + maxPageSize := v.config.VisibilityArchivalQueryMaxPageSize() + if int(listRequest.GetPageSize()) > maxPageSize { + return &types.BadRequestError{Message: fmt.Sprintf("Pagesize is larger than allowed %d", maxPageSize)} + } + return nil +} + +func (v *requestValidatorImpl) ValidateListClosedWorkflowExecutionsRequest(ctx context.Context, listRequest *types.ListClosedWorkflowExecutionsRequest) error { + if listRequest == nil { + return validate.ErrRequestNotSet + } + if listRequest.GetDomain() == "" { + return validate.ErrDomainNotSet + } + if listRequest.StartTimeFilter == nil { + return &types.BadRequestError{Message: "StartTimeFilter is required"} + } + if listRequest.StartTimeFilter.EarliestTime == nil { + return &types.BadRequestError{Message: "EarliestTime in StartTimeFilter is required"} + } + if listRequest.StartTimeFilter.LatestTime == nil { + return &types.BadRequestError{Message: "LatestTime in StartTimeFilter is required"} + } + if listRequest.StartTimeFilter.GetEarliestTime() > listRequest.StartTimeFilter.GetLatestTime() { + return &types.BadRequestError{Message: "EarliestTime in StartTimeFilter should not be larger than LatestTime"} + } + filterCount := 0 + if listRequest.ExecutionFilter != nil { + filterCount++ + } + if listRequest.TypeFilter != nil { + filterCount++ + } + if listRequest.StatusFilter != nil { + filterCount++ + } + if filterCount > 1 { + return &types.BadRequestError{Message: "Only one of ExecutionFilter, TypeFilter or StatusFilter is allowed"} + } // If ExecutionFilter is provided with one of TypeFilter or StatusFilter, use ExecutionFilter and ignore other filter + if listRequest.GetMaximumPageSize() <= 0 { + listRequest.MaximumPageSize = int32(v.config.VisibilityMaxPageSize(listRequest.GetDomain())) + } + if v.isListRequestPageSizeTooLarge(listRequest.GetMaximumPageSize(), listRequest.GetDomain()) { + return &types.BadRequestError{Message: fmt.Sprintf("Pagesize is larger than allow %d", v.config.ESIndexMaxResultWindow())} + } + return nil +} diff --git a/service/frontend/api/request_validator_mock.go b/service/frontend/api/request_validator_mock.go index 240e2536d25..658ab6f639e 100644 --- a/service/frontend/api/request_validator_mock.go +++ b/service/frontend/api/request_validator_mock.go @@ -58,6 +58,20 @@ func (m *MockRequestValidator) EXPECT() *MockRequestValidatorMockRecorder { return m.recorder } +// ValidateCountWorkflowExecutionsRequest mocks base method. +func (m *MockRequestValidator) ValidateCountWorkflowExecutionsRequest(arg0 context.Context, arg1 *types.CountWorkflowExecutionsRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateCountWorkflowExecutionsRequest", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateCountWorkflowExecutionsRequest indicates an expected call of ValidateCountWorkflowExecutionsRequest. +func (mr *MockRequestValidatorMockRecorder) ValidateCountWorkflowExecutionsRequest(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateCountWorkflowExecutionsRequest", reflect.TypeOf((*MockRequestValidator)(nil).ValidateCountWorkflowExecutionsRequest), arg0, arg1) +} + // ValidateDescribeTaskListRequest mocks base method. func (m *MockRequestValidator) ValidateDescribeTaskListRequest(arg0 context.Context, arg1 *types.DescribeTaskListRequest) error { m.ctrl.T.Helper() @@ -86,6 +100,48 @@ func (mr *MockRequestValidatorMockRecorder) ValidateGetTaskListsByDomainRequest( return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateGetTaskListsByDomainRequest", reflect.TypeOf((*MockRequestValidator)(nil).ValidateGetTaskListsByDomainRequest), arg0, arg1) } +// ValidateListArchivedWorkflowExecutionsRequest mocks base method. +func (m *MockRequestValidator) ValidateListArchivedWorkflowExecutionsRequest(arg0 context.Context, arg1 *types.ListArchivedWorkflowExecutionsRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateListArchivedWorkflowExecutionsRequest", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateListArchivedWorkflowExecutionsRequest indicates an expected call of ValidateListArchivedWorkflowExecutionsRequest. +func (mr *MockRequestValidatorMockRecorder) ValidateListArchivedWorkflowExecutionsRequest(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateListArchivedWorkflowExecutionsRequest", reflect.TypeOf((*MockRequestValidator)(nil).ValidateListArchivedWorkflowExecutionsRequest), arg0, arg1) +} + +// ValidateListClosedWorkflowExecutionsRequest mocks base method. +func (m *MockRequestValidator) ValidateListClosedWorkflowExecutionsRequest(arg0 context.Context, arg1 *types.ListClosedWorkflowExecutionsRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateListClosedWorkflowExecutionsRequest", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateListClosedWorkflowExecutionsRequest indicates an expected call of ValidateListClosedWorkflowExecutionsRequest. +func (mr *MockRequestValidatorMockRecorder) ValidateListClosedWorkflowExecutionsRequest(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateListClosedWorkflowExecutionsRequest", reflect.TypeOf((*MockRequestValidator)(nil).ValidateListClosedWorkflowExecutionsRequest), arg0, arg1) +} + +// ValidateListOpenWorkflowExecutionsRequest mocks base method. +func (m *MockRequestValidator) ValidateListOpenWorkflowExecutionsRequest(arg0 context.Context, arg1 *types.ListOpenWorkflowExecutionsRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateListOpenWorkflowExecutionsRequest", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateListOpenWorkflowExecutionsRequest indicates an expected call of ValidateListOpenWorkflowExecutionsRequest. +func (mr *MockRequestValidatorMockRecorder) ValidateListOpenWorkflowExecutionsRequest(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateListOpenWorkflowExecutionsRequest", reflect.TypeOf((*MockRequestValidator)(nil).ValidateListOpenWorkflowExecutionsRequest), arg0, arg1) +} + // ValidateListTaskListPartitionsRequest mocks base method. func (m *MockRequestValidator) ValidateListTaskListPartitionsRequest(arg0 context.Context, arg1 *types.ListTaskListPartitionsRequest) error { m.ctrl.T.Helper() @@ -100,6 +156,20 @@ func (mr *MockRequestValidatorMockRecorder) ValidateListTaskListPartitionsReques return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateListTaskListPartitionsRequest", reflect.TypeOf((*MockRequestValidator)(nil).ValidateListTaskListPartitionsRequest), arg0, arg1) } +// ValidateListWorkflowExecutionsRequest mocks base method. +func (m *MockRequestValidator) ValidateListWorkflowExecutionsRequest(arg0 context.Context, arg1 *types.ListWorkflowExecutionsRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateListWorkflowExecutionsRequest", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateListWorkflowExecutionsRequest indicates an expected call of ValidateListWorkflowExecutionsRequest. +func (mr *MockRequestValidatorMockRecorder) ValidateListWorkflowExecutionsRequest(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateListWorkflowExecutionsRequest", reflect.TypeOf((*MockRequestValidator)(nil).ValidateListWorkflowExecutionsRequest), arg0, arg1) +} + // ValidateRefreshWorkflowTasksRequest mocks base method. func (m *MockRequestValidator) ValidateRefreshWorkflowTasksRequest(arg0 context.Context, arg1 *types.RefreshWorkflowTasksRequest) error { m.ctrl.T.Helper() diff --git a/service/frontend/api/request_validator_test.go b/service/frontend/api/request_validator_test.go index 20f993814d1..6335f8efa7e 100644 --- a/service/frontend/api/request_validator_test.go +++ b/service/frontend/api/request_validator_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/uber/cadence/common" "github.com/uber/cadence/common/dynamicconfig" "github.com/uber/cadence/common/log/testlogger" "github.com/uber/cadence/common/metrics" @@ -46,7 +47,7 @@ func setupMocksForRequestValidator(t *testing.T) (*requestValidatorImpl, *mockDe logger, ), numHistoryShards, - false, + true, "hostname", ) deps := &mockDeps{ @@ -393,3 +394,425 @@ func TestValidateValidateResetStickyTaskListRequest(t *testing.T) { }) } } + +func TestValidateCountWorkflowExecutionsRequest(t *testing.T) { + testCases := []struct { + name string + req *types.CountWorkflowExecutionsRequest + expectError bool + expectedError string + }{ + { + name: "success", + req: &types.CountWorkflowExecutionsRequest{ + Domain: "domain", + }, + expectError: false, + }, + { + name: "not set", + req: nil, + expectError: true, + expectedError: "Request is nil.", + }, + { + name: "domain not set", + req: &types.CountWorkflowExecutionsRequest{ + Domain: "", + }, + expectError: true, + expectedError: "Domain not set on request.", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, _ := setupMocksForRequestValidator(t) + + err := v.ValidateCountWorkflowExecutionsRequest(context.Background(), tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateListWorkflowExecutionsRequest(t *testing.T) { + testCases := []struct { + name string + req *types.ListWorkflowExecutionsRequest + expectError bool + expectedError string + }{ + { + name: "success", + req: &types.ListWorkflowExecutionsRequest{ + Domain: "domain", + }, + expectError: false, + }, + { + name: "not set", + req: nil, + expectError: true, + expectedError: "Request is nil.", + }, + { + name: "domain not set", + req: &types.ListWorkflowExecutionsRequest{ + Domain: "", + }, + expectError: true, + expectedError: "Domain not set on request.", + }, + { + name: "page size too large", + req: &types.ListWorkflowExecutionsRequest{ + Domain: "domain", + PageSize: 101, + }, + expectError: true, + expectedError: "Pagesize is larger than allow 100", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, deps := setupMocksForRequestValidator(t) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.FrontendESIndexMaxResultWindow, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.FrontendVisibilityMaxPageSize, 10)) + + err := v.ValidateListWorkflowExecutionsRequest(context.Background(), tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + assert.Equal(t, int32(10), tc.req.GetPageSize()) + } + }) + } +} + +func TestValidateListOpenWorkflowExecutionsRequest(t *testing.T) { + testCases := []struct { + name string + req *types.ListOpenWorkflowExecutionsRequest + expectError bool + expectedError string + }{ + { + name: "success", + req: &types.ListOpenWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + LatestTime: common.Ptr(int64(2)), + }, + MaximumPageSize: 0, + }, + expectError: false, + }, + { + name: "not set", + req: nil, + expectError: true, + expectedError: "Request is nil.", + }, + { + name: "domain not set", + req: &types.ListOpenWorkflowExecutionsRequest{ + Domain: "", + }, + expectError: true, + expectedError: "Domain not set on request.", + }, + { + name: "startTimeFilter not set", + req: &types.ListOpenWorkflowExecutionsRequest{ + Domain: "domain", + }, + expectError: true, + expectedError: "StartTimeFilter is required", + }, + { + name: "Earliest time not set", + req: &types.ListOpenWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{}, + }, + expectError: true, + expectedError: "EarliestTime in StartTimeFilter is required", + }, + { + name: "Latest time not set", + req: &types.ListOpenWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + }, + }, + expectError: true, + expectedError: "LatestTime in StartTimeFilter is required", + }, + { + name: "earliest time later than latest time", + req: &types.ListOpenWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(3)), + LatestTime: common.Ptr(int64(2)), + }, + }, + expectError: true, + expectedError: "EarliestTime in StartTimeFilter should not be larger than LatestTime", + }, + { + name: "both execution and type filter are specified", + req: &types.ListOpenWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + LatestTime: common.Ptr(int64(2)), + }, + ExecutionFilter: &types.WorkflowExecutionFilter{}, + TypeFilter: &types.WorkflowTypeFilter{}, + }, + expectError: true, + expectedError: "Only one of ExecutionFilter or TypeFilter is allowed", + }, + { + name: "page size too large", + req: &types.ListOpenWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + LatestTime: common.Ptr(int64(2)), + }, + MaximumPageSize: 101, + }, + expectError: true, + expectedError: "Pagesize is larger than allow 100", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, deps := setupMocksForRequestValidator(t) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.FrontendESIndexMaxResultWindow, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.FrontendVisibilityMaxPageSize, 10)) + + err := v.ValidateListOpenWorkflowExecutionsRequest(context.Background(), tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + assert.Equal(t, int32(10), tc.req.GetMaximumPageSize()) + } + }) + } +} + +func TestValidateListClosedWorkflowExecutionsRequest(t *testing.T) { + testCases := []struct { + name string + req *types.ListClosedWorkflowExecutionsRequest + expectError bool + expectedError string + }{ + { + name: "success", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + LatestTime: common.Ptr(int64(2)), + }, + MaximumPageSize: 0, + }, + expectError: false, + }, + { + name: "not set", + req: nil, + expectError: true, + expectedError: "Request is nil.", + }, + { + name: "domain not set", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "", + }, + expectError: true, + expectedError: "Domain not set on request.", + }, + { + name: "startTimeFilter not set", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + }, + expectError: true, + expectedError: "StartTimeFilter is required", + }, + { + name: "Earliest time not set", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{}, + }, + expectError: true, + expectedError: "EarliestTime in StartTimeFilter is required", + }, + { + name: "Latest time not set", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + }, + }, + expectError: true, + expectedError: "LatestTime in StartTimeFilter is required", + }, + { + name: "earliest time later than latest time", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(3)), + LatestTime: common.Ptr(int64(2)), + }, + }, + expectError: true, + expectedError: "EarliestTime in StartTimeFilter should not be larger than LatestTime", + }, + { + name: "both execution and type filter are specified", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + LatestTime: common.Ptr(int64(2)), + }, + ExecutionFilter: &types.WorkflowExecutionFilter{}, + TypeFilter: &types.WorkflowTypeFilter{}, + }, + expectError: true, + expectedError: "Only one of ExecutionFilter, TypeFilter or StatusFilter is allowed", + }, + { + name: "both execution and status filter are specified", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + LatestTime: common.Ptr(int64(2)), + }, + ExecutionFilter: &types.WorkflowExecutionFilter{}, + StatusFilter: types.WorkflowExecutionCloseStatusFailed.Ptr(), + }, + expectError: true, + expectedError: "Only one of ExecutionFilter, TypeFilter or StatusFilter is allowed", + }, + { + name: "both type and status filter are specified", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + LatestTime: common.Ptr(int64(2)), + }, + TypeFilter: &types.WorkflowTypeFilter{}, + StatusFilter: types.WorkflowExecutionCloseStatusFailed.Ptr(), + }, + expectError: true, + expectedError: "Only one of ExecutionFilter, TypeFilter or StatusFilter is allowed", + }, + { + name: "page size too large", + req: &types.ListClosedWorkflowExecutionsRequest{ + Domain: "domain", + StartTimeFilter: &types.StartTimeFilter{ + EarliestTime: common.Ptr(int64(1)), + LatestTime: common.Ptr(int64(2)), + }, + MaximumPageSize: 101, + }, + expectError: true, + expectedError: "Pagesize is larger than allow 100", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, deps := setupMocksForRequestValidator(t) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.FrontendESIndexMaxResultWindow, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.FrontendVisibilityMaxPageSize, 10)) + + err := v.ValidateListClosedWorkflowExecutionsRequest(context.Background(), tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + assert.Equal(t, int32(10), tc.req.GetMaximumPageSize()) + } + }) + } +} + +func TestValidateListArchivedWorkflowExecutionsRequest(t *testing.T) { + testCases := []struct { + name string + req *types.ListArchivedWorkflowExecutionsRequest + expectError bool + expectedError string + }{ + { + name: "success", + req: &types.ListArchivedWorkflowExecutionsRequest{ + Domain: "domain", + }, + expectError: false, + }, + { + name: "not set", + req: nil, + expectError: true, + expectedError: "Request is nil.", + }, + { + name: "domain not set", + req: &types.ListArchivedWorkflowExecutionsRequest{ + Domain: "", + }, + expectError: true, + expectedError: "Domain not set on request.", + }, + { + name: "page size too large", + req: &types.ListArchivedWorkflowExecutionsRequest{ + Domain: "domain", + PageSize: 101, + }, + expectError: true, + expectedError: "Pagesize is larger than allowed 100", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, deps := setupMocksForRequestValidator(t) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.VisibilityArchivalQueryMaxPageSize, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.FrontendVisibilityMaxPageSize, 10)) + + err := v.ValidateListArchivedWorkflowExecutionsRequest(context.Background(), tc.req) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + assert.Equal(t, int32(10), tc.req.GetPageSize()) + } + }) + } +}