Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 250 additions & 0 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
package websocket

import (
"bufio"
"bytes"
"compress/flate"
"context"
"io"
"net"
"strings"
"testing"
"time"

"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
Expand Down Expand Up @@ -59,3 +63,249 @@ func BenchmarkFlateReader(b *testing.B) {
io.ReadAll(r)
}
}

// TestWriteSingleFrameCompressed verifies that Conn.Write sends compressed
// messages in a single frame instead of multiple frames, and that messages
// below the flateThreshold are sent uncompressed.
// This is a regression test for https://github.com/coder/websocket/issues/435
// TestWriteSingleFrameCompressed verifies that Conn.Write sends compressed
// messages in a single frame instead of multiple frames, and that messages
// below the flateThreshold are sent uncompressed.
// This is a regression test for https://github.com/coder/websocket/issues/435
func TestWriteSingleFrameCompressed(t *testing.T) {
	t.Parallel()

	var (
		flateThreshold = 64

		largeMsg = []byte(strings.Repeat("hello world ", 100))
		smallMsg = []byte("small message")
	)

	testCases := []struct {
		name     string
		mode     CompressionMode
		msg      []byte
		wantRsv1 bool // true = compressed, false = uncompressed
	}{
		{"ContextTakeover/AboveThreshold", CompressionContextTakeover, largeMsg, true},
		{"NoContextTakeover/AboveThreshold", CompressionNoContextTakeover, largeMsg, true},
		{"ContextTakeover/BelowThreshold", CompressionContextTakeover, smallMsg, false},
		{"NoContextTakeover/BelowThreshold", CompressionNoContextTakeover, smallMsg, false},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			clientConn, serverConn := net.Pipe()
			defer clientConn.Close()
			defer serverConn.Close()

			c := newConn(connConfig{
				rwc:            clientConn,
				client:         true,
				copts:          tc.mode.opts(),
				flateThreshold: flateThreshold,
				br:             bufio.NewReader(clientConn),
				bw:             bufio.NewWriterSize(clientConn, 4096),
			})

			// Generous deadline: 100ms is flaky under a loaded CI scheduler,
			// and the test finishes as soon as the frame below is parsed.
			ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
			defer cancel()

			// Conn.Write blocks on net.Pipe until the peer reads, so run it
			// concurrently with the frame parsing below.
			writeDone := make(chan error, 1)
			go func() {
				writeDone <- c.Write(ctx, MessageText, tc.msg)
			}()

			reader := bufio.NewReader(serverConn)
			readBuf := make([]byte, 8)

			h, err := readFrameHeader(reader, readBuf)
			assert.Success(t, err)

			// Drain the payload so the writer goroutine can complete.
			_, err = io.CopyN(io.Discard, reader, h.payloadLength)
			assert.Success(t, err)

			// A single compressed frame: correct opcode, rsv1 set only when
			// compression is expected, and fin set (no continuation frames).
			assert.Equal(t, "opcode", opText, h.opcode)
			assert.Equal(t, "rsv1 (compressed)", tc.wantRsv1, h.rsv1)
			assert.Equal(t, "fin", true, h.fin)

			err = <-writeDone
			assert.Success(t, err)
		})
	}
}

// TestWriteThenWriterContextTakeover verifies that using Conn.Write followed by
// Conn.Writer works correctly with context takeover enabled. This tests that
// the flateWriter destination is properly restored after Conn.Write redirects
// it to a temporary buffer.
// TestWriteThenWriterContextTakeover verifies that using Conn.Write followed by
// Conn.Writer works correctly with context takeover enabled. This tests that
// the flateWriter destination is properly restored after Conn.Write redirects
// it to a temporary buffer.
func TestWriteThenWriterContextTakeover(t *testing.T) {
	t.Parallel()

	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()

	client := newConn(connConfig{
		rwc:            clientConn,
		client:         true,
		copts:          CompressionContextTakeover.opts(),
		flateThreshold: 64,
		br:             bufio.NewReader(clientConn),
		bw:             bufio.NewWriterSize(clientConn, 4096),
	})

	server := newConn(connConfig{
		rwc:            serverConn,
		client:         false,
		copts:          CompressionContextTakeover.opts(),
		flateThreshold: 64,
		br:             bufio.NewReader(serverConn),
		bw:             bufio.NewWriterSize(serverConn, 4096),
	})

	// Generous deadline: 500ms for three round-trips over net.Pipe is flaky
	// under a loaded CI scheduler.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	msg1 := []byte(strings.Repeat("first message ", 100))
	msg2 := []byte(strings.Repeat("second message ", 100))

	type readResult struct {
		typ MessageType
		p   []byte
		err error
	}
	// Buffered so the reader goroutine never blocks on send and always exits.
	readCh := make(chan readResult, 3)
	go func() {
		for range 3 {
			typ, p, err := server.Read(ctx)
			readCh <- readResult{typ, p, err}
		}
	}()

	// We want to verify mixing `Write` and `Writer` usages still work.
	//
	// To this end, we call them in this order:
	// - `Write`
	// - `Writer`
	// - `Write`
	//
	// This verifies that it works for a `Write` followed by a `Writer`
	// as well as a `Writer` followed by a `Write`.

	// 1. `Write` API
	err := client.Write(ctx, MessageText, msg1)
	assert.Success(t, err)

	r := <-readCh
	assert.Success(t, r.err)
	assert.Equal(t, "Write type", MessageText, r.typ)
	assert.Equal(t, "Write content", string(msg1), string(r.p))

	// 2. `Writer` API
	w, err := client.Writer(ctx, MessageBinary)
	assert.Success(t, err)
	_, err = w.Write(msg2)
	assert.Success(t, err)
	assert.Success(t, w.Close())

	r = <-readCh
	assert.Success(t, r.err)
	assert.Equal(t, "Writer type", MessageBinary, r.typ)
	assert.Equal(t, "Writer content", string(msg2), string(r.p))

	// 3. `Write` API again
	err = client.Write(ctx, MessageText, msg1)
	assert.Success(t, err)

	r = <-readCh
	assert.Success(t, r.err)
	assert.Equal(t, "Write type", MessageText, r.typ)
	assert.Equal(t, "Write content", string(msg1), string(r.p))
}

// TestCompressionDictionaryPreserved verifies that context takeover mode
// preserves the compression dictionary across Conn.Write calls, resulting
// in better compression for consecutive similar messages.
// TestCompressionDictionaryPreserved verifies that context takeover mode
// preserves the compression dictionary across Conn.Write calls, resulting
// in better compression for consecutive similar messages.
func TestCompressionDictionaryPreserved(t *testing.T) {
	t.Parallel()

	msg := []byte(strings.Repeat(`{"type":"event","data":"value"}`, 50))

	takeoverClient, takeoverServer := net.Pipe()
	defer takeoverClient.Close()
	defer takeoverServer.Close()

	withTakeover := newConn(connConfig{
		rwc:            takeoverClient,
		client:         true,
		copts:          CompressionContextTakeover.opts(),
		flateThreshold: 64,
		br:             bufio.NewReader(takeoverClient),
		bw:             bufio.NewWriterSize(takeoverClient, 4096),
	})

	noTakeoverClient, noTakeoverServer := net.Pipe()
	defer noTakeoverClient.Close()
	defer noTakeoverServer.Close()

	withoutTakeover := newConn(connConfig{
		rwc:            noTakeoverClient,
		client:         true,
		copts:          CompressionNoContextTakeover.opts(),
		flateThreshold: 64,
		br:             bufio.NewReader(noTakeoverClient),
		bw:             bufio.NewWriterSize(noTakeoverClient, 4096),
	})

	// Generous deadline: six write/read round-trips over net.Pipe can exceed
	// one second under a loaded CI scheduler.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// Capture compressed sizes for both modes
	var withTakeoverSizes, withoutTakeoverSizes []int64

	reader1 := bufio.NewReader(takeoverServer)
	reader2 := bufio.NewReader(noTakeoverServer)
	readBuf := make([]byte, 8)

	// Send 3 identical messages each
	for range 3 {
		// With context takeover
		writeDone1 := make(chan error, 1)
		go func() {
			writeDone1 <- withTakeover.Write(ctx, MessageText, msg)
		}()

		h1, err := readFrameHeader(reader1, readBuf)
		assert.Success(t, err)

		_, err = io.CopyN(io.Discard, reader1, h1.payloadLength)
		assert.Success(t, err)

		withTakeoverSizes = append(withTakeoverSizes, h1.payloadLength)
		assert.Success(t, <-writeDone1)

		// Without context takeover
		writeDone2 := make(chan error, 1)
		go func() {
			writeDone2 <- withoutTakeover.Write(ctx, MessageText, msg)
		}()

		h2, err := readFrameHeader(reader2, readBuf)
		assert.Success(t, err)

		_, err = io.CopyN(io.Discard, reader2, h2.payloadLength)
		assert.Success(t, err)

		withoutTakeoverSizes = append(withoutTakeoverSizes, h2.payloadLength)
		assert.Success(t, <-writeDone2)
	}

	// With context takeover, the 2nd and 3rd messages should be smaller
	// than without context takeover (dictionary helps compress repeated patterns).
	// The first message will be similar size for both modes since there's no
	// prior dictionary. But subsequent messages benefit from context takeover.
	if withTakeoverSizes[2] >= withoutTakeoverSizes[2] {
		t.Errorf("context takeover should compress better: with=%d, without=%d",
			withTakeoverSizes[2], withoutTakeoverSizes[2])
	}
}
64 changes: 55 additions & 9 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net"
"time"

"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
Expand Down Expand Up @@ -100,23 +101,18 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
}

func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
mw, err := c.writer(ctx, typ)
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return 0, err
}

if !c.flate() {
if !c.flate() || len(p) < c.flateThreshold {
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}

n, err := mw.Write(p)
if err != nil {
return n, err
}

err = mw.Close()
return n, err
defer c.msgWriter.mu.unlock()
return c.msgWriter.writeCompressedFrame(ctx, p)
}

func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
Expand All @@ -142,6 +138,56 @@ func (mw *msgWriter) putFlateWriter() {
}
}

// writeCompressedFrame compresses and writes p as a single frame.
//
// Unlike the streaming msgWriter.Write path, the compressed output is
// accumulated in a pooled buffer and sent as one fin frame, so the peer
// never sees continuation frames for a Conn.Write call.
// Returns len(p) on success (bytes of the uncompressed input consumed).
func (mw *msgWriter) writeCompressedFrame(ctx context.Context, p []byte) (int, error) {
	err := mw.writeMu.lock(mw.ctx)
	if err != nil {
		return 0, fmt.Errorf("failed to write: %w", err)
	}
	defer mw.writeMu.unlock()

	if mw.closed {
		return 0, errors.New("cannot use closed writer")
	}

	mw.ensureFlate()

	// Pooled scratch buffer; returned to the pool after the frame is written.
	buf := bpool.Get()
	defer bpool.Put(buf)

	// Buffer compressed output so we can write as a single frame instead
	// of chunked frames.
	origWriter := mw.trimWriter.w
	mw.trimWriter.w = buf
	defer func() {
		// Restore the flate destination so a later Conn.Writer call streams
		// to the connection again.
		mw.trimWriter.w = origWriter
	}()

	_, err = mw.flateWriter.Write(p)
	if err != nil {
		return 0, fmt.Errorf("failed to compress: %w", err)
	}

	err = mw.flateWriter.Flush()
	if err != nil {
		return 0, fmt.Errorf("failed to flush compression: %w", err)
	}

	// NOTE(review): presumably re-arms the trailing-bytes trimmer for the
	// next message — confirm against trimWriter's implementation.
	mw.trimWriter.reset()

	if !mw.flateContextTakeover() {
		// Without context takeover the dictionary must not persist across
		// messages, so release the flate writer back to its pool.
		mw.putFlateWriter()
	}

	// Mark the logical message writer as consumed; reset() re-arms it for
	// the next message.
	mw.closed = true

	// Single frame: fin=true, rsv1=true (compressed).
	_, err = mw.c.writeFrame(ctx, true, true, mw.opcode, buf.Bytes())
	if err != nil {
		return 0, err
	}
	return len(p), nil
}

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
err = mw.writeMu.lock(mw.ctx)
Expand Down
Loading