diff --git a/twitter/streams.go b/twitter/streams.go index 9d1e386..6a96ad6 100644 --- a/twitter/streams.go +++ b/twitter/streams.go @@ -156,6 +156,7 @@ type Stream struct { group *sync.WaitGroup body io.Closer expected bool + mutex sync.Mutex } // newStream creates a Stream and starts a goroutine to retry connecting and @@ -176,6 +177,7 @@ func newStream(client *http.Client, req *http.Request, expBackoff, aggExpBackoff // Stop signals retry and receiver to stop, closes the Messages channel, and // blocks until done. func (s *Stream) Stop() { + s.mutex.Lock() s.expected = true close(s.done) // Scanner does not have a Stop() or take a done channel, so for low volume @@ -184,6 +186,7 @@ func (s *Stream) Stop() { if s.body != nil { s.body.Close() } + s.mutex.Unlock() // block until the retry goroutine stops s.group.Wait() } @@ -191,7 +194,10 @@ func (s *Stream) Stop() { // ExpectedStop indicates whether Stream halting was due to an expected Stop() // or some error condition. func (s *Stream) ExpectedStop() bool { - return s.expected + s.mutex.Lock() + result := s.expected + s.mutex.Unlock() + return result } // retry retries making the given http.Request and receiving the response @@ -213,7 +219,9 @@ func (s *Stream) retry(req *http.Request, expBackOff backoff.BackOff, aggExpBack } // when err is nil, resp contains a non-nil Body which must be closed defer resp.Body.Close() + s.mutex.Lock() s.body = resp.Body + s.mutex.Unlock() switch resp.StatusCode { case 200: // receive stream response Body, handles closing diff --git a/twitter/streams_test.go b/twitter/streams_test.go index 80b3a5e..f46daf8 100644 --- a/twitter/streams_test.go +++ b/twitter/streams_test.go @@ -128,12 +128,16 @@ func TestStream_Filter(t *testing.T) { stream, err := client.Streams.Filter(streamFilterParams) // assert that the expected messages are received assert.NoError(t, err) - defer stream.Stop() for message := range stream.Messages { demux.Handle(message) } expectedCounts := &counter{all: 3, other: 3} assert.Equal(t, expectedCounts, counts) + + // test ExpectedStop + assert.False(t, stream.ExpectedStop()) + stream.Stop() + assert.True(t, stream.ExpectedStop()) } func TestStream_Sample(t *testing.T) {