
Commit 8d3545a

fix: transmit in single frame when compression enabled (#552)
Closes #435

When compression was enabled, `Conn.Write` sent messages across many small frames due to the flate library's internal `bufferFlushSize` (240 bytes): each flush triggered a `writeFrame` call, producing alternating ~236-byte and 4-byte frames. `Conn.Write` now compresses the entire message into a buffer first, then transmits it as a single frame. Messages below `flateThreshold` bypass compression and are sent uncompressed in a single frame.
1 parent 8bf6dd2 commit 8d3545a
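
To make the buffering approach concrete, here is a minimal, self-contained sketch using only the standard library. `compressToSingleFramePayload` is a hypothetical helper, not the library's code: it compresses the whole message up front, flushes once, and trims the trailing 4-byte `0x00 0x00 0xff 0xff` flate tail that RFC 7692 (§7.2.1) requires permessage-deflate senders to strip, yielding a payload suitable for exactly one frame.

package main

import (
	"bytes"
	"compress/flate"
	"fmt"
)

// compressToSingleFramePayload (hypothetical helper) compresses p in full
// before any frame is written, so the caller can emit exactly one frame
// instead of one frame per internal flate flush.
func compressToSingleFramePayload(p []byte) ([]byte, error) {
	var buf bytes.Buffer
	fw, err := flate.NewWriter(&buf, flate.BestSpeed)
	if err != nil {
		return nil, err
	}
	if _, err := fw.Write(p); err != nil {
		return nil, err
	}
	// Flush emits a sync block terminated by 0x00 0x00 0xff 0xff.
	if err := fw.Flush(); err != nil {
		return nil, err
	}
	out := buf.Bytes()
	// RFC 7692 §7.2.1: strip the trailing 4-byte flate tail.
	return out[:len(out)-4], nil
}

func main() {
	msg := bytes.Repeat([]byte("hello world "), 100)
	payload, err := compressToSingleFramePayload(msg)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%d-byte message -> one frame with a %d-byte payload\n", len(msg), len(payload))
}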

File tree

2 files changed (+305, -10 lines)


compress_test.go

Lines changed: 250 additions & 0 deletions
@@ -3,11 +3,15 @@
 package websocket
 
 import (
+	"bufio"
 	"bytes"
 	"compress/flate"
+	"context"
 	"io"
+	"net"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/coder/websocket/internal/test/assert"
 	"github.com/coder/websocket/internal/test/xrand"
@@ -59,3 +63,249 @@ func BenchmarkFlateReader(b *testing.B) {
 		io.ReadAll(r)
 	}
 }
+
+// TestWriteSingleFrameCompressed verifies that Conn.Write sends compressed
+// messages in a single frame instead of multiple frames, and that messages
+// below the flateThreshold are sent uncompressed.
+// This is a regression test for https://github.com/coder/websocket/issues/435
+func TestWriteSingleFrameCompressed(t *testing.T) {
+	t.Parallel()
+
+	var (
+		flateThreshold = 64
+
+		largeMsg = []byte(strings.Repeat("hello world ", 100))
+		smallMsg = []byte("small message")
+	)
+
+	testCases := []struct {
+		name     string
+		mode     CompressionMode
+		msg      []byte
+		wantRsv1 bool // true = compressed, false = uncompressed
+	}{
+		{"ContextTakeover/AboveThreshold", CompressionContextTakeover, largeMsg, true},
+		{"NoContextTakeover/AboveThreshold", CompressionNoContextTakeover, largeMsg, true},
+		{"ContextTakeover/BelowThreshold", CompressionContextTakeover, smallMsg, false},
+		{"NoContextTakeover/BelowThreshold", CompressionNoContextTakeover, smallMsg, false},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			clientConn, serverConn := net.Pipe()
+			defer clientConn.Close()
+			defer serverConn.Close()
+
+			c := newConn(connConfig{
+				rwc:            clientConn,
+				client:         true,
+				copts:          tc.mode.opts(),
+				flateThreshold: flateThreshold,
+				br:             bufio.NewReader(clientConn),
+				bw:             bufio.NewWriterSize(clientConn, 4096),
+			})
+
+			ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
+			defer cancel()
+
+			writeDone := make(chan error, 1)
+			go func() {
+				writeDone <- c.Write(ctx, MessageText, tc.msg)
+			}()
+
+			reader := bufio.NewReader(serverConn)
+			readBuf := make([]byte, 8)
+
+			h, err := readFrameHeader(reader, readBuf)
+			assert.Success(t, err)
+
+			_, err = io.CopyN(io.Discard, reader, h.payloadLength)
+			assert.Success(t, err)
+
+			assert.Equal(t, "opcode", opText, h.opcode)
+			assert.Equal(t, "rsv1 (compressed)", tc.wantRsv1, h.rsv1)
+			assert.Equal(t, "fin", true, h.fin)
+
+			err = <-writeDone
+			assert.Success(t, err)
+		})
+	}
+}
+
+// TestWriteThenWriterContextTakeover verifies that using Conn.Write followed by
+// Conn.Writer works correctly with context takeover enabled. This tests that
+// the flateWriter destination is properly restored after Conn.Write redirects
+// it to a temporary buffer.
+func TestWriteThenWriterContextTakeover(t *testing.T) {
+	t.Parallel()
+
+	clientConn, serverConn := net.Pipe()
+	defer clientConn.Close()
+	defer serverConn.Close()
+
+	client := newConn(connConfig{
+		rwc:            clientConn,
+		client:         true,
+		copts:          CompressionContextTakeover.opts(),
+		flateThreshold: 64,
+		br:             bufio.NewReader(clientConn),
+		bw:             bufio.NewWriterSize(clientConn, 4096),
+	})
+
+	server := newConn(connConfig{
+		rwc:            serverConn,
+		client:         false,
+		copts:          CompressionContextTakeover.opts(),
+		flateThreshold: 64,
+		br:             bufio.NewReader(serverConn),
+		bw:             bufio.NewWriterSize(serverConn, 4096),
+	})
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
+	defer cancel()
+
+	msg1 := []byte(strings.Repeat("first message ", 100))
+	msg2 := []byte(strings.Repeat("second message ", 100))
+
+	type readResult struct {
+		typ MessageType
+		p   []byte
+		err error
+	}
+	readCh := make(chan readResult, 3)
+	go func() {
+		for range 3 {
+			typ, p, err := server.Read(ctx)
+			readCh <- readResult{typ, p, err}
+		}
+	}()
+
+	// We want to verify mixing `Write` and `Writer` usages still work.
+	//
+	// To this end, we call them in this order:
+	// - `Write`
+	// - `Writer`
+	// - `Write`
+	//
+	// This verifies that it works for a `Write` followed by a `Writer`
+	// as well as a `Writer` followed by a `Write`.
+
+	// 1. `Write` API
+	err := client.Write(ctx, MessageText, msg1)
+	assert.Success(t, err)
+
+	r := <-readCh
+	assert.Success(t, r.err)
+	assert.Equal(t, "Write type", MessageText, r.typ)
+	assert.Equal(t, "Write content", string(msg1), string(r.p))
+
+	// 2. `Writer` API
+	w, err := client.Writer(ctx, MessageBinary)
+	assert.Success(t, err)
+	_, err = w.Write(msg2)
+	assert.Success(t, err)
+	assert.Success(t, w.Close())
+
+	r = <-readCh
+	assert.Success(t, r.err)
+	assert.Equal(t, "Writer type", MessageBinary, r.typ)
+	assert.Equal(t, "Writer content", string(msg2), string(r.p))
+
+	// 3. `Write` API again
+	err = client.Write(ctx, MessageText, msg1)
+	assert.Success(t, err)
+
+	r = <-readCh
+	assert.Success(t, r.err)
+	assert.Equal(t, "Write type", MessageText, r.typ)
+	assert.Equal(t, "Write content", string(msg1), string(r.p))
+}
+
+// TestCompressionDictionaryPreserved verifies that context takeover mode
+// preserves the compression dictionary across Conn.Write calls, resulting
+// in better compression for consecutive similar messages.
+func TestCompressionDictionaryPreserved(t *testing.T) {
+	t.Parallel()
+
+	msg := []byte(strings.Repeat(`{"type":"event","data":"value"}`, 50))
+
+	takeoverClient, takeoverServer := net.Pipe()
+	defer takeoverClient.Close()
+	defer takeoverServer.Close()
+
+	withTakeover := newConn(connConfig{
+		rwc:            takeoverClient,
+		client:         true,
+		copts:          CompressionContextTakeover.opts(),
+		flateThreshold: 64,
+		br:             bufio.NewReader(takeoverClient),
+		bw:             bufio.NewWriterSize(takeoverClient, 4096),
+	})
+
+	noTakeoverClient, noTakeoverServer := net.Pipe()
+	defer noTakeoverClient.Close()
+	defer noTakeoverServer.Close()
+
+	withoutTakeover := newConn(connConfig{
+		rwc:            noTakeoverClient,
+		client:         true,
+		copts:          CompressionNoContextTakeover.opts(),
+		flateThreshold: 64,
+		br:             bufio.NewReader(noTakeoverClient),
+		bw:             bufio.NewWriterSize(noTakeoverClient, 4096),
+	})
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+
+	// Capture compressed sizes for both modes
+	var withTakeoverSizes, withoutTakeoverSizes []int64
+
+	reader1 := bufio.NewReader(takeoverServer)
+	reader2 := bufio.NewReader(noTakeoverServer)
+	readBuf := make([]byte, 8)
+
+	// Send 3 identical messages each
+	for range 3 {
+		// With context takeover
+		writeDone1 := make(chan error, 1)
+		go func() {
+			writeDone1 <- withTakeover.Write(ctx, MessageText, msg)
+		}()
+
+		h1, err := readFrameHeader(reader1, readBuf)
+		assert.Success(t, err)
+
+		_, err = io.CopyN(io.Discard, reader1, h1.payloadLength)
+		assert.Success(t, err)
+
+		withTakeoverSizes = append(withTakeoverSizes, h1.payloadLength)
+		assert.Success(t, <-writeDone1)
+
+		// Without context takeover
+		writeDone2 := make(chan error, 1)
+		go func() {
+			writeDone2 <- withoutTakeover.Write(ctx, MessageText, msg)
+		}()
+
+		h2, err := readFrameHeader(reader2, readBuf)
+		assert.Success(t, err)
+
+		_, err = io.CopyN(io.Discard, reader2, h2.payloadLength)
+		assert.Success(t, err)
+
+		withoutTakeoverSizes = append(withoutTakeoverSizes, h2.payloadLength)
+		assert.Success(t, <-writeDone2)
+	}
+
+	// With context takeover, the 2nd and 3rd messages should be smaller than
+	// without context takeover (dictionary helps compress repeated patterns).
+	// The first message will be similar size for both modes since there's no
+	// prior dictionary. But subsequent messages benefit from context takeover.
+	if withTakeoverSizes[2] >= withoutTakeoverSizes[2] {
+		t.Errorf("context takeover should compress better: with=%d, without=%d",
+			withTakeoverSizes[2], withoutTakeoverSizes[2])
+	}
+}
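
For a standalone illustration of the effect TestCompressionDictionaryPreserved checks, the same behavior can be reproduced with nothing but compress/flate (none of this is library code): a writer that keeps its sliding window between messages, as context takeover does, compresses a repeated message far better the second time than a fresh writer can.

package main

import (
	"bytes"
	"compress/flate"
	"fmt"
)

func main() {
	msg := bytes.Repeat([]byte(`{"type":"event","data":"value"}`), 50)

	var out bytes.Buffer
	fw, err := flate.NewWriter(&out, flate.BestSpeed)
	if err != nil {
		panic(err)
	}

	// First message: no prior history, so this size matches what a
	// no-context-takeover connection would produce for every message.
	fw.Write(msg)
	fw.Flush()
	first := out.Len()

	// Second message on the same writer: the retained window acts as a
	// dictionary, which is exactly what context takeover preserves
	// across messages on a connection.
	fw.Write(msg)
	fw.Flush()
	second := out.Len() - first

	fmt.Printf("first: %d bytes, second (with dictionary): %d bytes\n", first, second)
}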

write.go

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
 	"net"
 	"time"
 
+	"github.com/coder/websocket/internal/bpool"
 	"github.com/coder/websocket/internal/errd"
 	"github.com/coder/websocket/internal/util"
 )
@@ -100,23 +101,17 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
 }
 
 func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
-	mw, err := c.writer(ctx, typ)
+	err := c.msgWriter.reset(ctx, typ)
 	if err != nil {
 		return 0, err
 	}
+	defer c.msgWriter.mu.unlock()
 
-	if !c.flate() {
-		defer c.msgWriter.mu.unlock()
+	if !c.flate() || len(p) < c.flateThreshold {
 		return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
 	}
 
-	n, err := mw.Write(p)
-	if err != nil {
-		return n, err
-	}
-
-	err = mw.Close()
-	return n, err
+	return c.msgWriter.writeCompressedFrame(ctx, p)
 }
 
 func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
@@ -142,6 +137,56 @@
 	}
 }
 
+// writeCompressedFrame compresses and writes p as a single frame.
+func (mw *msgWriter) writeCompressedFrame(ctx context.Context, p []byte) (int, error) {
+	err := mw.writeMu.lock(mw.ctx)
+	if err != nil {
+		return 0, fmt.Errorf("failed to write: %w", err)
+	}
+	defer mw.writeMu.unlock()
+
+	if mw.closed {
+		return 0, errors.New("cannot use closed writer")
+	}
+
+	mw.ensureFlate()
+
+	buf := bpool.Get()
+	defer bpool.Put(buf)
+
+	// Buffer compressed output so we can write as
+	// a single frame instead of chunked frames.
+	origWriter := mw.trimWriter.w
+	mw.trimWriter.w = buf
+	defer func() {
+		mw.trimWriter.w = origWriter
+	}()
+
+	_, err = mw.flateWriter.Write(p)
+	if err != nil {
+		return 0, fmt.Errorf("failed to compress: %w", err)
+	}
+
+	err = mw.flateWriter.Flush()
+	if err != nil {
+		return 0, fmt.Errorf("failed to flush compression: %w", err)
+	}
+
+	mw.trimWriter.reset()
+
+	if !mw.flateContextTakeover() {
+		mw.putFlateWriter()
+	}
+
+	mw.closed = true
+
+	_, err = mw.c.writeFrame(ctx, true, true, mw.opcode, buf.Bytes())
+	if err != nil {
+		return 0, err
+	}
+	return len(p), nil
+}
+
 // Write writes the given bytes to the WebSocket connection.
 func (mw *msgWriter) Write(p []byte) (_ int, err error) {
 	err = mw.writeMu.lock(mw.ctx)
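
None of this changes the public surface; callers reach the new single-frame path through the ordinary options. Below is a sketch of a client that exercises it, assuming the `CompressionMode` and `CompressionThreshold` fields documented on `DialOptions` and a hypothetical endpoint URL:

package main

import (
	"context"
	"time"

	"github.com/coder/websocket"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", &websocket.DialOptions{
		// Negotiate permessage-deflate with context takeover.
		CompressionMode: websocket.CompressionContextTakeover,
		// Messages smaller than this are sent uncompressed.
		CompressionThreshold: 64,
	})
	if err != nil {
		panic(err)
	}
	defer c.CloseNow()

	// With this fix, the compressed message below goes out as one frame
	// instead of a run of small chunked frames.
	msg := []byte("repeat repeat repeat repeat repeat repeat repeat repeat")
	if err := c.Write(ctx, websocket.MessageText, msg); err != nil {
		panic(err)
	}

	c.Close(websocket.StatusNormalClosure, "")
}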
