Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions stream.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package eventsource

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -32,6 +33,11 @@ type Stream struct {
isClosed bool
// isClosedMutex is a mutex protecting concurrent read/write access of isClosed
isClosedMutex sync.RWMutex
// streamWaitGroup drops to 0 when all the goroutines that process
// events have ended.
streamWaitGroup sync.WaitGroup
// cancelRequest cancels the context on the outbound request.
cancelRequest func()
}

type SubscriptionError struct {
Expand Down Expand Up @@ -62,21 +68,25 @@ func SubscribeWithRequest(lastEventId string, request *http.Request) (*Stream, e
// SubscribeWith takes a http client and request providing customization over both headers and
// control over the http client settings (timeouts, tls, etc)
func SubscribeWith(lastEventId string, client *http.Client, request *http.Request) (*Stream, error) {
ctx, cancelRequest := context.WithCancel(request.Context())
request = request.WithContext(ctx)

stream := &Stream{
c: client,
req: request,
lastEventId: lastEventId,
retry: time.Millisecond * 3000,
Events: make(chan Event),
Errors: make(chan error),
c: client,
req: request,
cancelRequest: cancelRequest,
lastEventId: lastEventId,
retry: time.Millisecond * 3000,
Events: make(chan Event),
Errors: make(chan error),
}
stream.c.CheckRedirect = checkRedirect

r, err := stream.connect()
if err != nil {
return nil, err
}
go stream.stream(r)
stream.stream(r)
return stream, nil
}

Expand All @@ -87,6 +97,13 @@ func (stream *Stream) Close() {
}

stream.markStreamClosed()

// Cancel the request and wait for the goroutine that processes the
// response to end. This ensures that nothing further will be written
// to the output channels.
stream.cancelRequest()
stream.streamWaitGroup.Wait()

close(stream.Errors)
close(stream.Events)
}
Expand Down Expand Up @@ -139,13 +156,23 @@ func (stream *Stream) connect() (r io.ReadCloser, err error) {
}

func (stream *Stream) stream(r io.ReadCloser) {
defer r.Close()
stream.streamWaitGroup.Add(1)

go func() {
defer stream.streamWaitGroup.Done()
defer r.Close()

// receives events until an error is encountered
stream.receiveEvents(r)
// receives events until an error is encountered
stream.receiveEvents(r)

// If the stream was closed, don't attempt to reconnect.
if stream.isStreamClosed() {
return
}

// tries to reconnect and start the stream again
stream.retryRestartStream()
// tries to reconnect and start the stream again
stream.retryRestartStream()
}()
}

func (stream *Stream) receiveEvents(r io.ReadCloser) {
Expand Down Expand Up @@ -187,7 +214,7 @@ func (stream *Stream) retryRestartStream() {
// but something to be aware of.
r, err := stream.connect()
if err == nil {
go stream.stream(r)
stream.stream(r)
return
}
stream.Errors <- err
Expand Down
73 changes: 73 additions & 0 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,79 @@ func TestStreamClose(t *testing.T) {
}
}

func TestStreamCloseWithEvents(t *testing.T) {
server := NewServer()
httpServer := httptest.NewServer(server.Handler(eventChannelName))
// The server has to be closed before the httpServer is closed.
// Otherwise the httpServer has still an open connection and it can not close.
defer httpServer.Close()
defer server.Close()

stream := mustSubscribe(t, httpServer.URL, "")

publishedEvent := &publication{id: "123"}
server.Publish([]string{eventChannelName}, publishedEvent)

time.Sleep(100 * time.Millisecond)

eventsC := drainEventChannel(stream.Events)

stream.Close()

select {
case receivedEvents := <-eventsC:
if len(receivedEvents) != 1 {
t.Fatalf("got %d events after close, want %d", len(receivedEvents), 1)
}

if !reflect.DeepEqual(receivedEvents[0], publishedEvent) {
t.Errorf("got event %+v, want %+v", receivedEvents[0], publishedEvent)
}
case <-time.After(timeToWaitForEvent):
t.Fatalf("Timed out waiting for stream.Events channel to close")
}
}

func drainEventChannel(c <-chan Event) <-chan []Event {
eventsC := make(chan []Event, 1)

go func() {
defer close(eventsC)

events := []Event{}
for event := range c {
events = append(events, event)
}

eventsC <- events
}()

return eventsC
}

func TestStreamCloseIsImmediate(t *testing.T) {
server := NewServer()
httpServer := httptest.NewServer(server.Handler(eventChannelName))
// The server has to be closed before the httpServer is closed.
// Otherwise the httpServer has still an open connection and it can not close.
defer httpServer.Close()
defer server.Close()

stream := mustSubscribe(t, httpServer.URL, "")

done := make(chan struct{})
go func() {
stream.Close()
close(done)
}()

select {
case <-done:
case <-time.After(time.Second):
t.Error("Timed out waiting for Close")
}
}

func mustSubscribe(t *testing.T, url, lastEventId string) *Stream {
stream, err := Subscribe(url, lastEventId)
if err != nil {
Expand Down