Skip to content

Commit f05d80c

Browse files
fix: transmit in single frame when compression enabled
Closes #435 ### Problem When compression was enabled, `Conn.Write` sent messages across many small frames due to the flate library's internal `bufferFlushSize` (240 bytes). Each flush triggered a `writeFrame` call, producing alternating ~236 and 4 byte frames. This broke clients like Unreal Engine that only process one frame per tick. ### Solution `Conn.Write` now compresses the entire message into a buffer first, then transmits it as a single frame. Messages below `flateThreshold` bypass compression and are sent uncompressed in a single frame. For `CompressionContextTakeover` mode, the flateWriter destination is restored after buffered compression to ensure subsequent `Writer()` streaming calls work correctly. ### Changes - **write.go**: Buffer compressed output before transmission - **compress_test.go**: Added regression tests for single-frame behavior and Write/Writer interop
1 parent 8bf6dd2 commit f05d80c

File tree

2 files changed

+185
-5
lines changed

2 files changed

+185
-5
lines changed

compress_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
package websocket
44

55
import (
6+
"bufio"
67
"bytes"
78
"compress/flate"
9+
"context"
810
"io"
11+
"net"
912
"strings"
1013
"testing"
14+
"time"
1115

1216
"github.com/coder/websocket/internal/test/assert"
1317
"github.com/coder/websocket/internal/test/xrand"
@@ -59,3 +63,142 @@ func BenchmarkFlateReader(b *testing.B) {
5963
io.ReadAll(r)
6064
}
6165
}
66+
67+
// TestWriteSingleFrameCompressed verifies that Conn.Write sends compressed
68+
// messages in a single frame instead of multiple frames, and that messages
69+
// below the flateThreshold are sent uncompressed.
70+
// This is a regression test for https://github.com/coder/websocket/issues/435
71+
func TestWriteSingleFrameCompressed(t *testing.T) {
72+
t.Parallel()
73+
74+
var (
75+
flateThreshold = 64
76+
77+
largeMsg = []byte(strings.Repeat("hello world ", 100)) // ~1200 bytes, above threshold
78+
smallMsg = []byte("small message") // 13 bytes, below threshold
79+
)
80+
81+
testCases := []struct {
82+
name string
83+
mode CompressionMode
84+
msg []byte
85+
wantRsv1 bool // true = compressed, false = uncompressed
86+
}{
87+
{"ContextTakeover/AboveThreshold", CompressionContextTakeover, largeMsg, true},
88+
{"NoContextTakeover/AboveThreshold", CompressionNoContextTakeover, largeMsg, true},
89+
{"ContextTakeover/BelowThreshold", CompressionContextTakeover, smallMsg, false},
90+
{"NoContextTakeover/BelowThreshold", CompressionNoContextTakeover, smallMsg, false},
91+
}
92+
93+
for _, tc := range testCases {
94+
t.Run(tc.name, func(t *testing.T) {
95+
t.Parallel()
96+
97+
clientConn, serverConn := net.Pipe()
98+
defer clientConn.Close()
99+
defer serverConn.Close()
100+
101+
c := newConn(connConfig{
102+
rwc: clientConn,
103+
client: true,
104+
copts: tc.mode.opts(),
105+
flateThreshold: flateThreshold,
106+
br: bufio.NewReader(clientConn),
107+
bw: bufio.NewWriterSize(clientConn, 4096),
108+
})
109+
110+
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
111+
defer cancel()
112+
113+
writeDone := make(chan error, 1)
114+
go func() {
115+
writeDone <- c.Write(ctx, MessageText, tc.msg)
116+
}()
117+
118+
reader := bufio.NewReader(serverConn)
119+
readBuf := make([]byte, 8)
120+
121+
h, err := readFrameHeader(reader, readBuf)
122+
assert.Success(t, err)
123+
124+
_, err = io.CopyN(io.Discard, reader, h.payloadLength)
125+
assert.Success(t, err)
126+
127+
assert.Equal(t, "opcode", opText, h.opcode)
128+
assert.Equal(t, "rsv1 (compressed)", tc.wantRsv1, h.rsv1)
129+
assert.Equal(t, "fin", true, h.fin)
130+
131+
err = <-writeDone
132+
assert.Success(t, err)
133+
})
134+
}
135+
}
136+
137+
// TestWriteThenWriterContextTakeover verifies that using Conn.Write followed by
138+
// Conn.Writer works correctly with context takeover enabled. This tests that
139+
// the flateWriter destination is properly restored after Conn.Write redirects
140+
// it to a temporary buffer.
141+
func TestWriteThenWriterContextTakeover(t *testing.T) {
142+
t.Parallel()
143+
144+
clientConn, serverConn := net.Pipe()
145+
defer clientConn.Close()
146+
defer serverConn.Close()
147+
148+
client := newConn(connConfig{
149+
rwc: clientConn,
150+
client: true,
151+
copts: CompressionContextTakeover.opts(),
152+
flateThreshold: 64,
153+
br: bufio.NewReader(clientConn),
154+
bw: bufio.NewWriterSize(clientConn, 4096),
155+
})
156+
157+
server := newConn(connConfig{
158+
rwc: serverConn,
159+
client: false,
160+
copts: CompressionContextTakeover.opts(),
161+
flateThreshold: 64,
162+
br: bufio.NewReader(serverConn),
163+
bw: bufio.NewWriterSize(serverConn, 4096),
164+
})
165+
166+
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
167+
defer cancel()
168+
169+
msg1 := []byte(strings.Repeat("first message ", 100))
170+
msg2 := []byte(strings.Repeat("second message ", 100))
171+
172+
type readResult struct {
173+
typ MessageType
174+
p []byte
175+
err error
176+
}
177+
readCh := make(chan readResult, 2)
178+
go func() {
179+
for range 2 {
180+
typ, p, err := server.Read(ctx)
181+
readCh <- readResult{typ, p, err}
182+
}
183+
}()
184+
185+
// First message: Write() redirects flateWriter to temp buffer
186+
assert.Success(t, client.Write(ctx, MessageText, msg1))
187+
188+
r := <-readCh
189+
assert.Success(t, r.err)
190+
assert.Equal(t, "msg1 type", MessageText, r.typ)
191+
assert.Equal(t, "msg1 content", string(msg1), string(r.p))
192+
193+
// Second message: Writer() streaming API
194+
w, err := client.Writer(ctx, MessageBinary)
195+
assert.Success(t, err)
196+
_, err = w.Write(msg2)
197+
assert.Success(t, err)
198+
assert.Success(t, w.Close())
199+
200+
r = <-readCh
201+
assert.Success(t, r.err)
202+
assert.Equal(t, "msg2 type", MessageBinary, r.typ)
203+
assert.Equal(t, "msg2 content", string(msg2), string(r.p))
204+
}

write.go

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"net"
1515
"time"
1616

17+
"github.com/coder/websocket/internal/bpool"
1718
"github.com/coder/websocket/internal/errd"
1819
"github.com/coder/websocket/internal/util"
1920
)
@@ -100,7 +101,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
100101
}
101102

102103
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
103-
mw, err := c.writer(ctx, typ)
104+
_, err := c.writer(ctx, typ)
104105
if err != nil {
105106
return 0, err
106107
}
@@ -110,13 +111,49 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
110111
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
111112
}
112113

113-
n, err := mw.Write(p)
114+
// Below threshold: write uncompressed in single frame.
115+
if len(p) < c.flateThreshold {
116+
defer c.msgWriter.mu.unlock()
117+
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
118+
}
119+
120+
// Compress into buffer, then write as single frame.
121+
defer c.msgWriter.mu.unlock()
122+
123+
buf := bpool.Get()
124+
defer bpool.Put(buf)
125+
126+
c.msgWriter.ensureFlate()
127+
fw := c.msgWriter.flateWriter
128+
fw.Reset(buf)
129+
130+
_, err = fw.Write(p)
114131
if err != nil {
115-
return n, err
132+
return 0, fmt.Errorf("failed to compress: %w", err)
116133
}
117134

118-
err = mw.Close()
119-
return n, err
135+
err = fw.Flush()
136+
if err != nil {
137+
return 0, fmt.Errorf("failed to flush compression: %w", err)
138+
}
139+
140+
if !c.msgWriter.flateContextTakeover() {
141+
c.msgWriter.putFlateWriter()
142+
} else {
143+
// Restore flateWriter destination for subsequent Writer() API calls.
144+
fw.Reset(c.msgWriter.trimWriter)
145+
}
146+
147+
// Remove deflate tail bytes (last 4 bytes: \x00\x00\xff\xff).
148+
// See RFC 7692 section 7.2.1.
149+
compressed := buf.Bytes()
150+
compressed = compressed[:len(compressed)-4]
151+
152+
_, err = c.writeFrame(ctx, true, true, c.msgWriter.opcode, compressed)
153+
if err != nil {
154+
return 0, err
155+
}
156+
return len(p), nil
120157
}
121158

122159
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {

0 commit comments

Comments
 (0)