Skip to content

Commit

Permalink
Add support for Live Queries (#88)
Browse files Browse the repository at this point in the history
Co-authored-by: ElecTwix <[email protected]>
Co-authored-by: Hugh Kaznowski <[email protected]>
  • Loading branch information
3 people authored Oct 17, 2023
1 parent 4e3ec82 commit 78c52c6
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 22 deletions.
16 changes: 12 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package surrealdb
import (
"fmt"

"github.com/surrealdb/surrealdb.go/pkg/model"

"github.com/surrealdb/surrealdb.go/pkg/constants"
"github.com/surrealdb/surrealdb.go/pkg/websocket"
)
Expand Down Expand Up @@ -57,12 +59,13 @@ func (db *DB) Authenticate(token string) (interface{}, error) {

// --------------------------------------------------

func (db *DB) Live(table string) (interface{}, error) {
return db.send("live", table)
func (db *DB) Live(table string) (string, error) {
id, err := db.send("live", table)
return id.(string), err
}

func (db *DB) Kill(query string) (interface{}, error) {
return db.send("kill", query)
func (db *DB) Kill(liveQueryID string) (interface{}, error) {
return db.send("kill", liveQueryID)
}

func (db *DB) Let(key string, val interface{}) (interface{}, error) {
Expand Down Expand Up @@ -109,6 +112,11 @@ func (db *DB) Insert(what string, data interface{}) (interface{}, error) {
return db.send("insert", what, data)
}

// LiveNotifications returns a channel for live query.
func (db *DB) LiveNotifications(liveQueryID string) (chan model.Notification, error) {
return db.ws.LiveNotifications(liveQueryID)
}

// --------------------------------------------------
// Private methods
// --------------------------------------------------
Expand Down
52 changes: 52 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/surrealdb/surrealdb.go/pkg/model"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/surrealdb/surrealdb.go"
Expand Down Expand Up @@ -124,6 +127,55 @@ func signin(s *SurrealDBTestSuite) interface{} {
return signin
}

func (s *SurrealDBTestSuite) TestLiveViaMethod() {
live, err := s.db.Live("users")

defer func() {
_, err = s.db.Kill(live)
s.Require().NoError(err)
}()

notifications, er := s.db.LiveNotifications(live)
// create a user
s.Require().NoError(er)
_, e := s.db.Create("users", map[string]interface{}{
"username": "johnny",
"password": "123",
})
s.Require().NoError(e)
notification := <-notifications
s.Require().Equal(model.CreateAction, notification.Action)
s.Require().Equal(live, notification.ID)
}

func (s *SurrealDBTestSuite) TestLiveViaQuery() {
liveResponse, err := s.db.Query("LIVE SELECT * FROM users", map[string]interface{}{})
assert.NoError(s.T(), err)
responseArray, ok := liveResponse.([]interface{})
assert.True(s.T(), ok)
singleResponse := responseArray[0].(map[string]interface{})
liveIDStruct, ok := singleResponse["result"]
assert.True(s.T(), ok)
liveID := liveIDStruct.(string)

defer func() {
_, err = s.db.Kill(liveID)
s.Require().NoError(err)
}()

notifications, er := s.db.LiveNotifications(liveID)
// create a user
s.Require().NoError(er)
_, e := s.db.Create("users", map[string]interface{}{
"username": "johnny",
"password": "123",
})
s.Require().NoError(e)
notification := <-notifications
s.Require().Equal(model.CreateAction, notification.Action)
s.Require().Equal(liveID, notification.ID)
}

func (s *SurrealDBTestSuite) TestDelete() {
userData, err := s.db.Create("users", testUser{
Username: "johnny",
Expand Down
11 changes: 10 additions & 1 deletion internal/mock/mock.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package mock

import "github.com/surrealdb/surrealdb.go/pkg/websocket"
import (
"errors"

"github.com/surrealdb/surrealdb.go/pkg/model"
"github.com/surrealdb/surrealdb.go/pkg/websocket"
)

type ws struct {
}
Expand All @@ -17,6 +22,10 @@ func (w *ws) Close() error {
return nil
}

func (w *ws) LiveNotifications(id string) (chan model.Notification, error) {
return nil, errors.New("live queries are unimplemented for mocks")
}

func Create() *ws {
return &ws{}
}
168 changes: 151 additions & 17 deletions pkg/gorilla/gorilla.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ import (
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"sync"
"time"

"github.com/surrealdb/surrealdb.go/pkg/model"

gorilla "github.com/gorilla/websocket"
"github.com/surrealdb/surrealdb.go/internal/rpc"
"github.com/surrealdb/surrealdb.go/pkg/logger"
Expand Down Expand Up @@ -35,15 +39,19 @@ type WebSocket struct {
responseChannels map[string]chan rpc.RPCResponse
responseChannelsLock sync.RWMutex

notificationChannels map[string]chan model.Notification
notificationChannelsLock sync.RWMutex

close chan int
}

func Create() *WebSocket {
return &WebSocket{
Conn: nil,
close: make(chan int),
responseChannels: make(map[string]chan rpc.RPCResponse),
Timeout: DefaultTimeout * time.Second,
Conn: nil,
close: make(chan int),
responseChannels: make(map[string]chan rpc.RPCResponse),
notificationChannels: make(map[string]chan model.Notification),
Timeout: DefaultTimeout * time.Second,
}
}

Expand Down Expand Up @@ -103,6 +111,15 @@ func (ws *WebSocket) Close() error {
return ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, ""))
}

func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan model.Notification, error) {
c, err := ws.createNotificationChannel(liveQueryID)
if err != nil {
ws.logger.Logger.Err(err)
ws.logger.LogChannel <- err.Error()
}
return c, err
}

var (
ErrIDInUse = errors.New("id already in use")
ErrTimeout = errors.New("timeout")
Expand All @@ -123,6 +140,20 @@ func (ws *WebSocket) createResponseChannel(id string) (chan rpc.RPCResponse, err
return ch, nil
}

func (ws *WebSocket) createNotificationChannel(liveQueryID string) (chan model.Notification, error) {
ws.notificationChannelsLock.Lock()
defer ws.notificationChannelsLock.Unlock()

if _, ok := ws.notificationChannels[liveQueryID]; ok {
return nil, fmt.Errorf("%w: %v", ErrIDInUse, liveQueryID)
}

ch := make(chan model.Notification)
ws.notificationChannels[liveQueryID] = ch

return ch, nil
}

func (ws *WebSocket) removeResponseChannel(id string) {
ws.responseChannelsLock.Lock()
defer ws.responseChannelsLock.Unlock()
Expand All @@ -136,6 +167,13 @@ func (ws *WebSocket) getResponseChannel(id string) (chan rpc.RPCResponse, bool)
return ch, ok
}

func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) {
ws.notificationChannelsLock.RLock()
defer ws.notificationChannelsLock.RUnlock()
ch, ok := ws.notificationChannels[id]
return ch, ok
}

func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) {
id := rand.String(RequestIDLength)
request := &rpc.RPCRequest{
Expand All @@ -159,15 +197,16 @@ func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, err
select {
case <-timeout:
return nil, ErrTimeout
case res := <-responseChan:
case res, open := <-responseChan:
if !open {
return nil, errors.New("channel closed")
}
if res.ID != id {
return nil, ErrInvalidResponseID
}

if res.Error != nil {
return nil, res.Error
}

return res.Result, nil
}
}
Expand All @@ -177,7 +216,6 @@ func (ws *WebSocket) read(v interface{}) error {
if err != nil {
return err
}

return json.Unmarshal(data, v)
}

Expand Down Expand Up @@ -206,16 +244,112 @@ func (ws *WebSocket) initialize() {
ws.logger.LogChannel <- err.Error()
continue
}
responseChan, ok := ws.getResponseChannel(fmt.Sprintf("%v", res.ID))
if !ok {
err = errors.New("ResponseChannel is not ok")
ws.logger.Logger.Err(err)
ws.logger.LogChannel <- err.Error()
continue
}
responseChan <- res
close(responseChan)
go ws.handleResponse(res)
}
}
}()
}

func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {
if res.ID != nil && res.ID != "" {
// Try to resolve message as response to query
responseChan, ok := ws.getResponseChannel(fmt.Sprintf("%v", res.ID))
if !ok {
err := fmt.Errorf("unavailable ResponseChannel %+v", res.ID)
ws.logger.Logger.Err(err)
ws.logger.LogChannel <- err.Error()
return
}
defer close(responseChan)
responseChan <- res
} else {
// Try to resolve response as live query notification
mappedRes, _ := res.Result.(map[string]interface{})
resolvedID, ok := mappedRes["id"]
if !ok {
err := fmt.Errorf("response did not contain an 'id' field")
ws.logger.Logger.With().Str("result", fmt.Sprintf("%s", res.Result)).Err(err)
ws.logger.LogChannel <- err.Error()
return
}
var notification model.Notification
err := unmarshalMapToStruct(mappedRes, &notification)
if err != nil {
ws.logger.Logger.With().Str("result", fmt.Sprintf("%s", res.Result)).Err(err)
ws.logger.LogChannel <- err.Error()
return
}
LiveNotificationChan, ok := ws.getLiveChannel(notification.ID)
if !ok {
err := fmt.Errorf("unavailable ResponseChannel %+v", resolvedID)
ws.logger.Logger.Err(err)
ws.logger.LogChannel <- err.Error()
return
}
LiveNotificationChan <- notification
}
}

func unmarshalMapToStruct(data map[string]interface{}, outStruct interface{}) error {
outValue := reflect.ValueOf(outStruct)
if outValue.Kind() != reflect.Ptr || outValue.Elem().Kind() != reflect.Struct {
return fmt.Errorf("outStruct must be a pointer to a struct")
}

structValue := outValue.Elem()
structType := structValue.Type()

for i := 0; i < structValue.NumField(); i++ {
field := structType.Field(i)
fieldName := field.Name
jsonTag := field.Tag.Get("json")
if jsonTag != "" {
fieldName = jsonTag
}
mapValue, ok := data[fieldName]
if !ok {
return fmt.Errorf("missing field in map: %s", fieldName)
}

fieldValue := structValue.Field(i)
if !fieldValue.CanSet() {
return fmt.Errorf("cannot set field: %s", fieldName)
}

if mapValue == nil {
// Handle nil values appropriately for your struct fields
// For simplicity, we skip nil values in this example
continue
}

// Type conversion based on the field type
switch fieldValue.Kind() {
case reflect.String:
fieldValue.SetString(fmt.Sprint(mapValue))
case reflect.Int:
intVal, err := strconv.Atoi(fmt.Sprint(mapValue))
if err != nil {
return err
}
fieldValue.SetInt(int64(intVal))
case reflect.Bool:
boolVal, err := strconv.ParseBool(fmt.Sprint(mapValue))
if err != nil {
return err
}
fieldValue.SetBool(boolVal)
case reflect.Map:
mapVal, ok := mapValue.(map[string]interface{})
if !ok {
return fmt.Errorf("mapValue for property %s is not a map[string]interface{}", fieldName)
}
fieldValue.Set(reflect.ValueOf(mapVal))

// Add cases for other types as needed
default:
return fmt.Errorf("unsupported field type: %s", fieldName)
}
}

return nil
}
15 changes: 15 additions & 0 deletions pkg/model/notification.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package model

type Notification struct {
ID string `json:"id"`
Action Action `json:"action"`
Result map[string]interface{} `json:"result"`
}

type Action string

const (
CreateAction Action = "CREATE"
UpdateAction Action = "UPDATE"
DeleteAction Action = "DELETE"
)
5 changes: 5 additions & 0 deletions pkg/websocket/websocket.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package websocket

import (
"github.com/surrealdb/surrealdb.go/pkg/model"
)

type WebSocket interface {
Connect(url string) (WebSocket, error)
Send(method string, params []interface{}) (interface{}, error)
Close() error
LiveNotifications(id string) (chan model.Notification, error)
}

0 comments on commit 78c52c6

Please sign in to comment.