Skip to content

Commit eef8394

Browse files
chore: PR comments
- humanize some comments/remove lots of unneeded ones - refactor write - ensure we close?
1 parent f05d80c commit eef8394

File tree

2 files changed

+137
-43
lines changed

2 files changed

+137
-43
lines changed

compress_test.go

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ func TestWriteSingleFrameCompressed(t *testing.T) {
7474
var (
7575
flateThreshold = 64
7676

77-
largeMsg = []byte(strings.Repeat("hello world ", 100)) // ~1200 bytes, above threshold
78-
smallMsg = []byte("small message") // 13 bytes, below threshold
77+
largeMsg = []byte(strings.Repeat("hello world ", 100))
78+
smallMsg = []byte("small message")
7979
)
8080

8181
testCases := []struct {
@@ -182,15 +182,16 @@ func TestWriteThenWriterContextTakeover(t *testing.T) {
182182
}
183183
}()
184184

185-
// First message: Write() redirects flateWriter to temp buffer
186-
assert.Success(t, client.Write(ctx, MessageText, msg1))
185+
// 2. `Write` API
186+
err := client.Write(ctx, MessageText, msg1)
187+
assert.Success(t, err)
187188

188189
r := <-readCh
189190
assert.Success(t, r.err)
190191
assert.Equal(t, "msg1 type", MessageText, r.typ)
191192
assert.Equal(t, "msg1 content", string(msg1), string(r.p))
192193

193-
// Second message: Writer() streaming API
194+
// 2. `Writer` API
194195
w, err := client.Writer(ctx, MessageBinary)
195196
assert.Success(t, err)
196197
_, err = w.Write(msg2)
@@ -202,3 +203,92 @@ func TestWriteThenWriterContextTakeover(t *testing.T) {
202203
assert.Equal(t, "msg2 type", MessageBinary, r.typ)
203204
assert.Equal(t, "msg2 content", string(msg2), string(r.p))
204205
}
206+
207+
// TestCompressionDictionaryPreserved verifies that context takeover mode
208+
// preserves the compression dictionary across Conn.Write calls, resulting
209+
// in better compression for consecutive similar messages.
210+
func TestCompressionDictionaryPreserved(t *testing.T) {
211+
t.Parallel()
212+
213+
msg := []byte(strings.Repeat(`{"type":"event","data":"value"}`, 50))
214+
215+
// Test with context takeover
216+
clientConn1, serverConn1 := net.Pipe()
217+
defer clientConn1.Close()
218+
defer serverConn1.Close()
219+
220+
withTakeover := newConn(connConfig{
221+
rwc: clientConn1,
222+
client: true,
223+
copts: CompressionContextTakeover.opts(),
224+
flateThreshold: 64,
225+
br: bufio.NewReader(clientConn1),
226+
bw: bufio.NewWriterSize(clientConn1, 4096),
227+
})
228+
229+
// Test without context takeover
230+
clientConn2, serverConn2 := net.Pipe()
231+
defer clientConn2.Close()
232+
defer serverConn2.Close()
233+
234+
withoutTakeover := newConn(connConfig{
235+
rwc: clientConn2,
236+
client: true,
237+
copts: CompressionNoContextTakeover.opts(),
238+
flateThreshold: 64,
239+
br: bufio.NewReader(clientConn2),
240+
bw: bufio.NewWriterSize(clientConn2, 4096),
241+
})
242+
243+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
244+
defer cancel()
245+
246+
// Capture compressed sizes for both modes
247+
var withTakeoverSizes, withoutTakeoverSizes []int64
248+
249+
reader1 := bufio.NewReader(serverConn1)
250+
reader2 := bufio.NewReader(serverConn2)
251+
readBuf := make([]byte, 8)
252+
253+
// Send 3 identical messages each
254+
for range 3 {
255+
// With context takeover
256+
writeDone1 := make(chan error, 1)
257+
go func() {
258+
writeDone1 <- withTakeover.Write(ctx, MessageText, msg)
259+
}()
260+
261+
h1, err := readFrameHeader(reader1, readBuf)
262+
assert.Success(t, err)
263+
264+
_, err = io.CopyN(io.Discard, reader1, h1.payloadLength)
265+
assert.Success(t, err)
266+
267+
withTakeoverSizes = append(withTakeoverSizes, h1.payloadLength)
268+
assert.Success(t, <-writeDone1)
269+
270+
// Without context takeover
271+
writeDone2 := make(chan error, 1)
272+
go func() {
273+
writeDone2 <- withoutTakeover.Write(ctx, MessageText, msg)
274+
}()
275+
276+
h2, err := readFrameHeader(reader2, readBuf)
277+
assert.Success(t, err)
278+
279+
_, err = io.CopyN(io.Discard, reader2, h2.payloadLength)
280+
assert.Success(t, err)
281+
282+
withoutTakeoverSizes = append(withoutTakeoverSizes, h2.payloadLength)
283+
assert.Success(t, <-writeDone2)
284+
}
285+
286+
// With context takeover, the 2nd and 3rd messages should be smaller
287+
// than without context takeover (dictionary helps compress repeated patterns).
288+
// The first message will be similar size for both modes since there's no
289+
// prior dictionary. But subsequent messages benefit from context takeover.
290+
if withTakeoverSizes[2] >= withoutTakeoverSizes[2] {
291+
t.Errorf("context takeover should compress better: with=%d, without=%d",
292+
withTakeoverSizes[2], withoutTakeoverSizes[2])
293+
}
294+
}

write.go

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
101101
}
102102

103103
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
104-
_, err := c.writer(ctx, typ)
104+
err := c.msgWriter.reset(ctx, typ)
105105
if err != nil {
106106
return 0, err
107107
}
@@ -111,49 +111,13 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
111111
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
112112
}
113113

114-
// Below threshold: write uncompressed in single frame.
115114
if len(p) < c.flateThreshold {
116115
defer c.msgWriter.mu.unlock()
117116
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
118117
}
119118

120-
// Compress into buffer, then write as single frame.
121119
defer c.msgWriter.mu.unlock()
122-
123-
buf := bpool.Get()
124-
defer bpool.Put(buf)
125-
126-
c.msgWriter.ensureFlate()
127-
fw := c.msgWriter.flateWriter
128-
fw.Reset(buf)
129-
130-
_, err = fw.Write(p)
131-
if err != nil {
132-
return 0, fmt.Errorf("failed to compress: %w", err)
133-
}
134-
135-
err = fw.Flush()
136-
if err != nil {
137-
return 0, fmt.Errorf("failed to flush compression: %w", err)
138-
}
139-
140-
if !c.msgWriter.flateContextTakeover() {
141-
c.msgWriter.putFlateWriter()
142-
} else {
143-
// Restore flateWriter destination for subsequent Writer() API calls.
144-
fw.Reset(c.msgWriter.trimWriter)
145-
}
146-
147-
// Remove deflate tail bytes (last 4 bytes: \x00\x00\xff\xff).
148-
// See RFC 7692 section 7.2.1.
149-
compressed := buf.Bytes()
150-
compressed = compressed[:len(compressed)-4]
151-
152-
_, err = c.writeFrame(ctx, true, true, c.msgWriter.opcode, compressed)
153-
if err != nil {
154-
return 0, err
155-
}
156-
return len(p), nil
120+
return c.msgWriter.writeFull(ctx, p)
157121
}
158122

159123
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
@@ -179,6 +143,46 @@ func (mw *msgWriter) putFlateWriter() {
179143
}
180144
}
181145

146+
// writeFull compresses and writes p as a single frame.
147+
func (mw *msgWriter) writeFull(ctx context.Context, p []byte) (int, error) {
148+
mw.ensureFlate()
149+
150+
buf := bpool.Get()
151+
defer bpool.Put(buf)
152+
153+
// Buffer compressed output so we can write as a single frame instead
154+
// of chunked frames.
155+
origWriter := mw.trimWriter.w
156+
mw.trimWriter.w = buf
157+
defer func() {
158+
mw.trimWriter.w = origWriter
159+
}()
160+
161+
_, err := mw.flateWriter.Write(p)
162+
if err != nil {
163+
return 0, fmt.Errorf("failed to compress: %w", err)
164+
}
165+
166+
err = mw.flateWriter.Flush()
167+
if err != nil {
168+
return 0, fmt.Errorf("failed to flush compression: %w", err)
169+
}
170+
171+
mw.trimWriter.reset()
172+
173+
if !mw.flateContextTakeover() {
174+
mw.putFlateWriter()
175+
}
176+
177+
mw.closed = true
178+
179+
_, err = mw.c.writeFrame(ctx, true, true, mw.opcode, buf.Bytes())
180+
if err != nil {
181+
return 0, err
182+
}
183+
return len(p), nil
184+
}
185+
182186
// Write writes the given bytes to the WebSocket connection.
183187
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
184188
err = mw.writeMu.lock(mw.ctx)

0 commit comments

Comments
 (0)