Skip to content

Commit

Permalink
Merge pull request #3142 from Ayrtat/fix/wsclient_hang
Browse files Browse the repository at this point in the history
rpcclient: fix wsclient hang on making request
  • Loading branch information
roman-khimov authored Oct 16, 2023
2 parents 1aabdc9 + ff19189 commit 38f77c3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
5 changes: 3 additions & 2 deletions pkg/rpcclient/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func NewInternal(ctx context.Context, register InternalHook) (*Internal, error)
Client: Client{},

shutdown: make(chan struct{}),
done: make(chan struct{}),
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
subscriptions: make(map[string]notificationReceiver),
receivers: make(map[any][]string),
},
Expand Down Expand Up @@ -63,7 +64,7 @@ eventloop:
c.notifySubscribers(ntf)
}
}
close(c.done)
close(c.readerDone)
c.ctxCancel()
// ctx is cancelled, server is notified and will finish soon.
drainloop:
Expand Down
41 changes: 29 additions & 12 deletions pkg/rpcclient/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ type WSClient struct {

ws *websocket.Conn
wsOpts WSOptions
done chan struct{}
readerDone chan struct{}
writerDone chan struct{}
requests chan *neorpc.Request
shutdown chan struct{}
closeCalled atomic.Bool
Expand Down Expand Up @@ -425,7 +426,8 @@ func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, err
ws: ws,
wsOpts: opts,
shutdown: make(chan struct{}),
done: make(chan struct{}),
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
respChannels: make(map[uint64]chan *neorpc.Response),
requests: make(chan *neorpc.Request),
subscriptions: make(map[string]notificationReceiver),
Expand Down Expand Up @@ -457,7 +459,7 @@ func (c *WSClient) Close() {
// Call to cancel will send signal to all users of Context().
c.Client.ctxCancel()
}
<-c.done
<-c.readerDone
}

func (c *WSClient) wsReader() {
Expand Down Expand Up @@ -551,7 +553,7 @@ readloop:
if connCloseErr != nil {
c.setCloseErr(connCloseErr)
}
close(c.done)
close(c.readerDone)
c.respLock.Lock()
for _, ch := range c.respChannels {
close(ch)
Expand Down Expand Up @@ -583,13 +585,14 @@ func (c *WSClient) wsWriter() {
pingTicker := time.NewTicker(wsPingPeriod)
defer c.ws.Close()
defer pingTicker.Stop()
defer close(c.writerDone)
var connCloseErr error
writeloop:
for {
select {
case <-c.shutdown:
return
case <-c.done:
case <-c.readerDone:
return
case req, ok := <-c.requests:
if !ok {
Expand Down Expand Up @@ -660,28 +663,42 @@ func (c *WSClient) getResponseChannel(id uint64) chan *neorpc.Response {
return c.respChannels[id]
}

// closeErrOrConnLost returns the error that may occur either in wsReader or wsWriter.
// If wsReader or wsWriter do not set the error, it returns ErrWSConnLost.
func (c *WSClient) closeErrOrConnLost() (err error) {
err = ErrWSConnLost
if closeErr := c.GetError(); closeErr != nil {
err = closeErr
}
return
}

func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
ch := make(chan *neorpc.Response)
c.respLock.Lock()
select {
case <-c.done:
case <-c.readerDone:
c.respLock.Unlock()
return nil, fmt.Errorf("%w: before registering response channel", ErrWSConnLost)
return nil, fmt.Errorf("%w: before registering response channel", c.closeErrOrConnLost())
default:
c.respChannels[r.ID] = ch
c.respLock.Unlock()
}
select {
case <-c.done:
return nil, fmt.Errorf("%w: before sending the request", ErrWSConnLost)
case <-c.readerDone:
return nil, fmt.Errorf("%w: before sending the request", c.closeErrOrConnLost())
case <-c.writerDone:
return nil, fmt.Errorf("%w: before sending the request", c.closeErrOrConnLost())
case c.requests <- r:
}
select {
case <-c.done:
return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost)
case <-c.readerDone:
return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
case <-c.writerDone:
return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
case resp, ok := <-ch:
if !ok {
return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost)
return nil, fmt.Errorf("%w: while waiting for the response", c.closeErrOrConnLost())
}
c.unregisterRespChannel(r.ID)
return resp, nil
Expand Down

0 comments on commit 38f77c3

Please sign in to comment.