233 changes: 233 additions & 0 deletions compress_test.go
@@ -3,11 +3,15 @@
package websocket

import (
"bufio"
"bytes"
"compress/flate"
"context"
"io"
"net"
"strings"
"testing"
"time"

"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
@@ -59,3 +63,232 @@ func BenchmarkFlateReader(b *testing.B) {
io.ReadAll(r)
}
}

// TestWriteSingleFrameCompressed verifies that Conn.Write sends compressed
// messages in a single frame instead of multiple frames, and that messages
// below the flateThreshold are sent uncompressed.
// This is a regression test for https://github.com/coder/websocket/issues/435
func TestWriteSingleFrameCompressed(t *testing.T) {
t.Parallel()

var (
flateThreshold = 64

largeMsg = []byte(strings.Repeat("hello world ", 100))
smallMsg = []byte("small message")
)

testCases := []struct {
name string
mode CompressionMode
msg []byte
wantRsv1 bool // true = compressed, false = uncompressed
}{
{"ContextTakeover/AboveThreshold", CompressionContextTakeover, largeMsg, true},
{"NoContextTakeover/AboveThreshold", CompressionNoContextTakeover, largeMsg, true},
{"ContextTakeover/BelowThreshold", CompressionContextTakeover, smallMsg, false},
{"NoContextTakeover/BelowThreshold", CompressionNoContextTakeover, smallMsg, false},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

clientConn, serverConn := net.Pipe()
defer clientConn.Close()
defer serverConn.Close()

c := newConn(connConfig{
rwc: clientConn,
client: true,
copts: tc.mode.opts(),
flateThreshold: flateThreshold,
br: bufio.NewReader(clientConn),
bw: bufio.NewWriterSize(clientConn, 4096),
})

ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()

writeDone := make(chan error, 1)
go func() {
writeDone <- c.Write(ctx, MessageText, tc.msg)
}()

reader := bufio.NewReader(serverConn)
readBuf := make([]byte, 8)

h, err := readFrameHeader(reader, readBuf)
assert.Success(t, err)

_, err = io.CopyN(io.Discard, reader, h.payloadLength)
assert.Success(t, err)

assert.Equal(t, "opcode", opText, h.opcode)
assert.Equal(t, "rsv1 (compressed)", tc.wantRsv1, h.rsv1)
assert.Equal(t, "fin", true, h.fin)

err = <-writeDone
assert.Success(t, err)
})
}
}
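
For context (illustration only, not code from this PR): the sketch below shows how a caller would opt into the compression path exercised by the test above. The DialOptions fields CompressionMode and CompressionThreshold are assumed from the package's exported API rather than taken from this diff, and the URL is a placeholder.

package main

import (
	"context"
	"strings"
	"time"

	"github.com/coder/websocket"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Assumed exported options; names are not taken from this diff.
	c, _, err := websocket.Dial(ctx, "wss://example.com/ws", &websocket.DialOptions{
		CompressionMode:      websocket.CompressionContextTakeover,
		CompressionThreshold: 64, // messages shorter than this are sent uncompressed
	})
	if err != nil {
		return
	}
	defer c.Close(websocket.StatusNormalClosure, "")

	// With this PR, a message at or above the threshold is compressed and
	// written by Conn.Write as a single frame.
	_ = c.Write(ctx, websocket.MessageText, []byte(strings.Repeat("hello world ", 100)))
}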

// TestWriteThenWriterContextTakeover verifies that using Conn.Write followed by
// Conn.Writer works correctly with context takeover enabled. This tests that
// the flateWriter destination is properly restored after Conn.Write redirects
// it to a temporary buffer.
func TestWriteThenWriterContextTakeover(t *testing.T) {
t.Parallel()

clientConn, serverConn := net.Pipe()
defer clientConn.Close()
defer serverConn.Close()

client := newConn(connConfig{
rwc: clientConn,
client: true,
copts: CompressionContextTakeover.opts(),
flateThreshold: 64,
br: bufio.NewReader(clientConn),
bw: bufio.NewWriterSize(clientConn, 4096),
})

server := newConn(connConfig{
rwc: serverConn,
client: false,
copts: CompressionContextTakeover.opts(),
flateThreshold: 64,
br: bufio.NewReader(serverConn),
bw: bufio.NewWriterSize(serverConn, 4096),
})

ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
defer cancel()

msg1 := []byte(strings.Repeat("first message ", 100))
msg2 := []byte(strings.Repeat("second message ", 100))

type readResult struct {
typ MessageType
p []byte
err error
}
readCh := make(chan readResult, 2)
go func() {
for range 2 {
typ, p, err := server.Read(ctx)
readCh <- readResult{typ, p, err}
}
}()

// 1. `Write` API
err := client.Write(ctx, MessageText, msg1)
assert.Success(t, err)

r := <-readCh
assert.Success(t, r.err)
assert.Equal(t, "msg1 type", MessageText, r.typ)
assert.Equal(t, "msg1 content", string(msg1), string(r.p))

// 2. `Writer` API
w, err := client.Writer(ctx, MessageBinary)
assert.Success(t, err)
_, err = w.Write(msg2)
assert.Success(t, err)
assert.Success(t, w.Close())

r = <-readCh
assert.Success(t, r.err)
assert.Equal(t, "msg2 type", MessageBinary, r.typ)
assert.Equal(t, "msg2 content", string(msg2), string(r.p))
}

// TestCompressionDictionaryPreserved verifies that context takeover mode
// preserves the compression dictionary across Conn.Write calls, resulting
// in better compression for consecutive similar messages.
func TestCompressionDictionaryPreserved(t *testing.T) {
t.Parallel()

msg := []byte(strings.Repeat(`{"type":"event","data":"value"}`, 50))

// Test with context takeover
clientConn1, serverConn1 := net.Pipe()
defer clientConn1.Close()
defer serverConn1.Close()

withTakeover := newConn(connConfig{
rwc: clientConn1,
client: true,
copts: CompressionContextTakeover.opts(),
flateThreshold: 64,
br: bufio.NewReader(clientConn1),
bw: bufio.NewWriterSize(clientConn1, 4096),
})

// Test without context takeover
clientConn2, serverConn2 := net.Pipe()
defer clientConn2.Close()
defer serverConn2.Close()

withoutTakeover := newConn(connConfig{
rwc: clientConn2,
client: true,
copts: CompressionNoContextTakeover.opts(),
flateThreshold: 64,
br: bufio.NewReader(clientConn2),
bw: bufio.NewWriterSize(clientConn2, 4096),
})

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

// Capture compressed sizes for both modes
var withTakeoverSizes, withoutTakeoverSizes []int64

reader1 := bufio.NewReader(serverConn1)
reader2 := bufio.NewReader(serverConn2)
readBuf := make([]byte, 8)

// Send 3 identical messages each
for range 3 {
// With context takeover
writeDone1 := make(chan error, 1)
go func() {
writeDone1 <- withTakeover.Write(ctx, MessageText, msg)
}()

h1, err := readFrameHeader(reader1, readBuf)
assert.Success(t, err)

_, err = io.CopyN(io.Discard, reader1, h1.payloadLength)
assert.Success(t, err)

withTakeoverSizes = append(withTakeoverSizes, h1.payloadLength)
assert.Success(t, <-writeDone1)

// Without context takeover
writeDone2 := make(chan error, 1)
go func() {
writeDone2 <- withoutTakeover.Write(ctx, MessageText, msg)
}()

h2, err := readFrameHeader(reader2, readBuf)
assert.Success(t, err)

_, err = io.CopyN(io.Discard, reader2, h2.payloadLength)
assert.Success(t, err)

withoutTakeoverSizes = append(withoutTakeoverSizes, h2.payloadLength)
assert.Success(t, <-writeDone2)
}

// With context takeover, later messages should be smaller than without it,
// because the preserved dictionary helps compress repeated patterns. The
// first message is a similar size in both modes since there is no prior
// dictionary, so assert on the third message.
if withTakeoverSizes[2] >= withoutTakeoverSizes[2] {
t.Errorf("context takeover should compress better: with=%d, without=%d",
withTakeoverSizes[2], withoutTakeoverSizes[2])
}
}
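
A standalone illustration of the behavior TestCompressionDictionaryPreserved asserts (not code from this PR): reusing one compress/flate writer across messages preserves the sliding-window dictionary, so repeated payloads compress far better than with a fresh writer per message, which is the difference between context takeover and no context takeover.

package main

import (
	"bytes"
	"compress/flate"
	"fmt"
	"strings"
)

func main() {
	msg := []byte(strings.Repeat(`{"type":"event","data":"value"}`, 50))

	// Context takeover: a single writer keeps its window across messages.
	var takeover bytes.Buffer
	fw, _ := flate.NewWriter(&takeover, flate.BestSpeed)
	for i := 0; i < 3; i++ {
		before := takeover.Len()
		fw.Write(msg)
		fw.Flush()
		fmt.Printf("takeover msg %d: %d compressed bytes\n", i+1, takeover.Len()-before)
	}

	// No context takeover: every message starts with an empty window.
	for i := 0; i < 3; i++ {
		var out bytes.Buffer
		nw, _ := flate.NewWriter(&out, flate.BestSpeed)
		nw.Write(msg)
		nw.Flush()
		fmt.Printf("no-takeover msg %d: %d compressed bytes\n", i+1, out.Len())
	}
}

The second and third takeover messages come out only a few bytes long because the entire previous payload is already in the dictionary, which is the size gap the test measures via frame payload lengths.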
53 changes: 47 additions & 6 deletions write.go
@@ -14,6 +14,7 @@ import (
"net"
"time"

"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
@@ -100,7 +101,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
}

func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
mw, err := c.writer(ctx, typ)
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return 0, err
}
@@ -110,13 +111,13 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}

n, err := mw.Write(p)
if err != nil {
return n, err
if len(p) < c.flateThreshold {
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}

err = mw.Close()
return n, err
defer c.msgWriter.mu.unlock()
return c.msgWriter.writeFull(ctx, p)
}

func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
@@ -142,6 +143,46 @@ func (mw *msgWriter) putFlateWriter() {
}
}

// writeFull compresses and writes p as a single frame.
func (mw *msgWriter) writeFull(ctx context.Context, p []byte) (int, error) {
mw.ensureFlate()

buf := bpool.Get()
defer bpool.Put(buf)

// Buffer compressed output so we can write as a single frame instead
// of chunked frames.
origWriter := mw.trimWriter.w
mw.trimWriter.w = buf
defer func() {
mw.trimWriter.w = origWriter
}()

_, err := mw.flateWriter.Write(p)
if err != nil {
return 0, fmt.Errorf("failed to compress: %w", err)
}

err = mw.flateWriter.Flush()
if err != nil {
return 0, fmt.Errorf("failed to flush compression: %w", err)
}

mw.trimWriter.reset()

if !mw.flateContextTakeover() {
mw.putFlateWriter()
}

mw.closed = true

_, err = mw.c.writeFrame(ctx, true, true, mw.opcode, buf.Bytes())
if err != nil {
return 0, err
}
return len(p), nil
}

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
err = mw.writeMu.lock(mw.ctx)
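
One more illustration related to writeFull and the trimWriter reset above (not code from this PR): a flate sync flush terminates its output with the empty stored block 0x00 0x00 0xff 0xff, and permessage-deflate (RFC 7692) strips those four octets from each message before framing, which is the tail the trim writer exists to drop.

package main

import (
	"bytes"
	"compress/flate"
	"fmt"
)

func main() {
	var buf bytes.Buffer
	fw, _ := flate.NewWriter(&buf, flate.BestSpeed)
	fw.Write([]byte("hello world"))
	fw.Flush()

	out := buf.Bytes()
	// The last four bytes of a sync flush are the empty stored block that
	// RFC 7692 requires be removed before the compressed message is framed.
	fmt.Printf("deflate tail: % x\n", out[len(out)-4:]) // prints: 00 00 ff ff
}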