Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement parallel handling of batch query #218 #219

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 50 additions & 30 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"strconv"
"strings"
"sync"

"github.com/nautilus/graphql"
)
Expand All @@ -27,6 +28,8 @@ type HTTPOperation struct {
} `json:"extensions"`
}

type setResultFunc func(r map[string]interface{})

func formatErrors(err error) map[string]interface{} {
return formatErrorsWithCode(nil, err, "UNKNOWN_ERROR")
}
Expand Down Expand Up @@ -70,12 +73,14 @@ func (g *Gateway) GraphQLHandler(w http.ResponseWriter, r *http.Request) {
/// Handle the operations regardless of the request method

// we have to respond to each operation in the right order
results := []map[string]interface{}{}
results := make([]map[string]interface{}, len(operations))
opWg := new(sync.WaitGroup)
opMutex := new(sync.Mutex)

// the status code to report
statusCode := http.StatusOK

for _, operation := range operations {
for opNum, operation := range operations {
// there might be a query plan cache key embedded in the operation
cacheKey := ""
if operation.Extensions.QueryPlanCache != nil {
Expand All @@ -85,10 +90,8 @@ func (g *Gateway) GraphQLHandler(w http.ResponseWriter, r *http.Request) {
// if there is no query or cache key
if operation.Query == "" && cacheKey == "" {
statusCode = http.StatusUnprocessableEntity
results = append(
results,
formatErrorsWithCode(nil, errors.New("could not find query body"), "BAD_USER_INPUT"),
)
results[opNum] = formatErrorsWithCode(nil, errors.New("could not find query body"), "BAD_USER_INPUT")

continue
}

Expand Down Expand Up @@ -116,32 +119,12 @@ func (g *Gateway) GraphQLHandler(w http.ResponseWriter, r *http.Request) {
return
}

// fire the query with the request context passed through to execution
result, err := g.Execute(requestContext, plan)
if err != nil {
results = append(results, formatErrorsWithCode(result, err, "INTERNAL_SERVER_ERROR"))

continue
}

// the result for this operation
payload := map[string]interface{}{"data": result}

// if there was a cache key associated with this query
if requestContext.CacheKey != "" {
// embed the cache key in the response
payload["extensions"] = map[string]interface{}{
"persistedQuery": map[string]interface{}{
"sha265Hash": requestContext.CacheKey,
"version": "1",
},
}
}

// add this result to the list
results = append(results, payload)
opWg.Add(1)
go g.executeRequest(requestContext, plan, opWg, g.setResultFunc(opNum, results, opMutex))
}

opWg.Wait()

// the final result depends on whether we are executing in batch mode or not
var finalResponse interface{}
if batchMode {
Expand All @@ -165,6 +148,43 @@ func (g *Gateway) GraphQLHandler(w http.ResponseWriter, r *http.Request) {
emitResponse(w, statusCode, string(response))
}

func (g *Gateway) setResultFunc(opNum int, results []map[string]interface{}, opMutex *sync.Mutex) setResultFunc {
return func(r map[string]interface{}) {
opMutex.Lock()
defer opMutex.Unlock()
results[opNum] = r
}
}

func (g *Gateway) executeRequest(requestContext *RequestContext, plan QueryPlanList, opWg *sync.WaitGroup, setResult setResultFunc) {
defer opWg.Done()

// fire the query with the request context passed through to execution
result, err := g.Execute(requestContext, plan)
if err != nil {
setResult(formatErrorsWithCode(result, err, "INTERNAL_SERVER_ERROR"))

return
}

// the result for this operation
payload := map[string]interface{}{"data": result}

// if there was a cache key associated with this query
if requestContext.CacheKey != "" {
// embed the cache key in the response
payload["extensions"] = map[string]interface{}{
"persistedQuery": map[string]interface{}{
"sha265Hash": requestContext.CacheKey,
"version": "1",
},
}
}

// add this result to the list
setResult(payload)
}

// Parses request to operations (single or batch mode).
// Returns an error and an error status code if the request is invalid.
func parseRequest(r *http.Request) (operations []*HTTPOperation, batchMode bool, errStatusCode int, payloadErr error) {
Expand Down
93 changes: 93 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strconv"
"strings"
"testing"
"time"

"golang.org/x/net/html"

Expand Down Expand Up @@ -971,6 +972,98 @@ func TestGraphQLHandler_postBatchWithMultipleFiles(t *testing.T) {
assert.Equal(t, http.StatusOK, result.StatusCode)
}

func TestGraphQLHandler_postBatchParallel(t *testing.T) {
t.Parallel()
schema, err := graphql.LoadSchema(`
type Query {
queryA: String!
queryB: String!
}
`)
assert.NoError(t, err)

// create gateway schema we can test against
gateway, err := New([]*graphql.RemoteSchema{
{Schema: schema, URL: "url-file-upload"},
}, WithExecutor(ExecutorFunc(
func(ec *ExecutionContext) (map[string]interface{}, error) {
if ec.Plan.Operation.Name == "queryAOperation" {
time.Sleep(50 * time.Millisecond)
return map[string]interface{}{
"queryA": "resultA",
}, nil
}
if ec.Plan.Operation.Name == "queryBOperation" {
return map[string]interface{}{
"queryB": "resultB",
}, nil
}

assert.Fail(t, "unexpected operation name", ec.Plan.Operation.Name)
return nil, nil
},
)))

if err != nil {
t.Error(err.Error())
return
}

request := httptest.NewRequest("POST", "/graphql", strings.NewReader(`[
{
"query": "query queryAOperation { queryA }",
"variables": null
},
{
"query": "query queryBOperation { queryB }",
"variables": null
}
]`))

// a recorder so we can check what the handler responded with
responseRecorder := httptest.NewRecorder()

// call the http hander
gateway.GraphQLHandler(responseRecorder, request)

// make sure we got correct order in response
response := responseRecorder.Result()
assert.Equal(t, http.StatusOK, response.StatusCode)

// read the body
body, err := io.ReadAll(response.Body)
assert.NoError(t, response.Body.Close())

if err != nil {
t.Error(err.Error())
return
}

result := []map[string]interface{}{}
err = json.Unmarshal(body, &result)
if err != nil {
t.Error(err.Error())
return
}

// we should have gotten 2 responses
if !assert.Len(t, result, 2) {
return
}

// make sure there were no errors in the first query
if firstQuery := result[0]; assert.Nil(t, firstQuery["errors"]) {
// make sure it has the right id
assert.Equal(t, map[string]interface{}{"queryA": "resultA"}, firstQuery["data"])
}

// make sure there were no errors in the second query
if secondQuery := result[1]; assert.Nil(t, secondQuery["errors"]) {
// make sure it has the right id
assert.Equal(t, map[string]interface{}{"queryB": "resultB"}, secondQuery["data"])
}
}

func TestGraphQLHandler_postFilesWithError(t *testing.T) {
t.Parallel()
schema, err := graphql.LoadSchema(`
Expand Down
Loading