diff --git a/stream.go b/stream.go index 5970877..86cd4db 100644 --- a/stream.go +++ b/stream.go @@ -28,10 +28,8 @@ type Stream struct { Errors chan error // Logger is a logger that, when set, will be used for logging debug messages Logger *log.Logger - // isClosed is a marker that the stream is/should be closed - isClosed bool - // isClosedMutex is a mutex protecting concurrent read/write access of isClosed - isClosedMutex sync.RWMutex + closeWhenClosed chan struct{} + closeOnce sync.Once } type SubscriptionError struct { @@ -69,6 +67,7 @@ func SubscribeWith(lastEventId string, client *http.Client, request *http.Reques retry: time.Millisecond * 3000, Events: make(chan Event), Errors: make(chan error), + closeWhenClosed: make(chan struct{}), } stream.c.CheckRedirect = checkRedirect @@ -82,25 +81,9 @@ func SubscribeWith(lastEventId string, client *http.Client, request *http.Reques // Close will close the stream. It is safe for concurrent access and can be called multiple times. func (stream *Stream) Close() { - if stream.isStreamClosed() { - return - } - - stream.markStreamClosed() - close(stream.Errors) - close(stream.Events) -} - -func (stream *Stream) isStreamClosed() bool { - stream.isClosedMutex.RLock() - defer stream.isClosedMutex.RUnlock() - return stream.isClosed -} - -func (stream *Stream) markStreamClosed() { - stream.isClosedMutex.Lock() - defer stream.isClosedMutex.Unlock() - stream.isClosed = true + stream.closeOnce.Do(func() { + close(stream.closeWhenClosed) + }) } // Go's http package doesn't copy headers across when it encounters @@ -141,56 +124,65 @@ func (stream *Stream) connect() (r io.ReadCloser, err error) { func (stream *Stream) stream(r io.ReadCloser) { defer r.Close() - // receives events until an error is encountered - stream.receiveEvents(r) - - // tries to reconnect and start the stream again - stream.retryRestartStream() -} - -func (stream *Stream) receiveEvents(r io.ReadCloser) { - dec := NewDecoder(r) - for { - ev, err := dec.Decode() - if stream.isStreamClosed() { - return - } - if err != nil { - stream.Errors <- err - return - } - - pub := ev.(*publication) - if pub.Retry() > 0 { - stream.retry = time.Duration(pub.Retry()) * time.Millisecond - } - if len(pub.Id()) > 0 { - stream.lastEventId = pub.Id() - } - stream.Events <- ev - } -} - -func (stream *Stream) retryRestartStream() { - backoff := stream.retry - for { - if stream.Logger != nil { - stream.Logger.Printf("Reconnecting in %0.4f secs\n", backoff.Seconds()) - } - time.Sleep(backoff) - if stream.isStreamClosed() { - return - } - // NOTE: because of the defer we're opening the new connection - // before closing the old one. Shouldn't be a problem in practice, - // but something to be aware of. - r, err := stream.connect() - if err == nil { - go stream.stream(r) - return + dec := NewDecoder(r) + + events := make(chan Event) + errors := make(chan error) + + go func() { + for { + ev, err := dec.Decode() + + if err != nil { + errors <- err + close(events) + return + } else { + events <- ev + } + } + }() + + for { + select { + case err := <-errors: + stream.Errors <- err + case ev, ok := <-events: + if !ok { + // tries to reconnect and start the stream again + backoff := stream.retry + for { + if stream.Logger != nil { + stream.Logger.Printf("Reconnecting in %0.4f secs\n", backoff.Seconds()) + } + time.Sleep(backoff) + // NOTE: because of the defer we're opening the new connection + // before closing the old one. Shouldn't be a problem in practice, + // but something to be aware of. + _, err := stream.connect() + if err == nil { + break + } + stream.Errors <- err + backoff *= 2 + } + } + pub := ev.(*publication) + if pub.Retry() > 0 { + stream.retry = time.Duration(pub.Retry()) * time.Millisecond + } + if len(pub.Id()) > 0 { + stream.lastEventId = pub.Id() + } + stream.Events <- ev + case _, ok := <-stream.closeWhenClosed: + if !ok { + close(stream.Errors) + close(stream.Events) + return + } + } } - stream.Errors <- err - backoff *= 2 } }