Skip to content

Commit

Permalink
Merge pull request #4 from qntfy/expected-stop-and-tests
Browse files Browse the repository at this point in the history
 Add stream mutex to prevent data race
  • Loading branch information
JoshuaC215 authored Feb 7, 2019
2 parents b1f5777 + 0cd5ed7 commit d85a731
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
10 changes: 9 additions & 1 deletion twitter/streams.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ type Stream struct {
group *sync.WaitGroup
body io.Closer
expected bool
mutex sync.Mutex
}

// newStream creates a Stream and starts a goroutine to retry connecting and
Expand All @@ -176,6 +177,7 @@ func newStream(client *http.Client, req *http.Request, expBackoff, aggExpBackoff
// Stop signals retry and receiver to stop, closes the Messages channel, and
// blocks until done.
func (s *Stream) Stop() {
s.mutex.Lock()
s.expected = true
close(s.done)
// Scanner does not have a Stop() or take a done channel, so for low volume
Expand All @@ -184,14 +186,18 @@ func (s *Stream) Stop() {
if s.body != nil {
s.body.Close()
}
s.mutex.Unlock()
// block until the retry goroutine stops
s.group.Wait()
}

// ExpectedStop indicates whether Stream halting was due to an expected Stop()
// or some error condition.
func (s *Stream) ExpectedStop() bool {
return s.expected
s.mutex.Lock()
result := s.expected
s.mutex.Unlock()
return result
}

// retry retries making the given http.Request and receiving the response
Expand All @@ -213,7 +219,9 @@ func (s *Stream) retry(req *http.Request, expBackOff backoff.BackOff, aggExpBack
}
// when err is nil, resp contains a non-nil Body which must be closed
defer resp.Body.Close()
s.mutex.Lock()
s.body = resp.Body
s.mutex.Unlock()
switch resp.StatusCode {
case 200:
// receive stream response Body, handles closing
Expand Down
6 changes: 5 additions & 1 deletion twitter/streams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,16 @@ func TestStream_Filter(t *testing.T) {
stream, err := client.Streams.Filter(streamFilterParams)
// assert that the expected messages are received
assert.NoError(t, err)
defer stream.Stop()
for message := range stream.Messages {
demux.Handle(message)
}
expectedCounts := &counter{all: 3, other: 3}
assert.Equal(t, expectedCounts, counts)

// test ExpectedStop
assert.False(t, stream.ExpectedStop())
stream.Stop()
assert.True(t, stream.ExpectedStop())
}

func TestStream_Sample(t *testing.T) {
Expand Down

0 comments on commit d85a731

Please sign in to comment.