Skip to content

Commit 885bd0d

Browse files
chore: address PR comments
- compression_test.go: add a `Writer` -> `Write` transition - compression_test.go: Claude sucks at variable names apparently - write.go: make sure to hold a lock 🤦 - write.go: combine two identical if branches - write.go: rename `writeFull` because the name was bad
1 parent eef8394 commit 885bd0d

File tree

2 files changed

+55
-33
lines changed

2 files changed

+55
-33
lines changed

compress_test.go

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -174,22 +174,32 @@ func TestWriteThenWriterContextTakeover(t *testing.T) {
174174
p []byte
175175
err error
176176
}
177-
readCh := make(chan readResult, 2)
177+
readCh := make(chan readResult, 3)
178178
go func() {
179-
for range 2 {
179+
for range 3 {
180180
typ, p, err := server.Read(ctx)
181181
readCh <- readResult{typ, p, err}
182182
}
183183
}()
184184

185-
// 2. `Write` API
185+
// We want to verify mixing `Write` and `Writer` usages still work.
186+
//
187+
// To this end, we call them in this order:
188+
// - `Write`
189+
// - `Writer`
190+
// - `Write`
191+
//
192+
// This verifies that it works for a `Write` followed by a `Writer`
193+
// as well as a `Writer` followed by a `Write`.
194+
195+
// 1. `Write` API
186196
err := client.Write(ctx, MessageText, msg1)
187197
assert.Success(t, err)
188198

189199
r := <-readCh
190200
assert.Success(t, r.err)
191-
assert.Equal(t, "msg1 type", MessageText, r.typ)
192-
assert.Equal(t, "msg1 content", string(msg1), string(r.p))
201+
assert.Equal(t, "Write type", MessageText, r.typ)
202+
assert.Equal(t, "Write content", string(msg1), string(r.p))
193203

194204
// 2. `Writer` API
195205
w, err := client.Writer(ctx, MessageBinary)
@@ -200,8 +210,17 @@ func TestWriteThenWriterContextTakeover(t *testing.T) {
200210

201211
r = <-readCh
202212
assert.Success(t, r.err)
203-
assert.Equal(t, "msg2 type", MessageBinary, r.typ)
204-
assert.Equal(t, "msg2 content", string(msg2), string(r.p))
213+
assert.Equal(t, "Writer type", MessageBinary, r.typ)
214+
assert.Equal(t, "Writer content", string(msg2), string(r.p))
215+
216+
// 3. `Write` API again
217+
err = client.Write(ctx, MessageText, msg1)
218+
assert.Success(t, err)
219+
220+
r = <-readCh
221+
assert.Success(t, r.err)
222+
assert.Equal(t, "Write type", MessageText, r.typ)
223+
assert.Equal(t, "Write content", string(msg1), string(r.p))
205224
}
206225

207226
// TestCompressionDictionaryPreserved verifies that context takeover mode
@@ -212,32 +231,30 @@ func TestCompressionDictionaryPreserved(t *testing.T) {
212231

213232
msg := []byte(strings.Repeat(`{"type":"event","data":"value"}`, 50))
214233

215-
// Test with context takeover
216-
clientConn1, serverConn1 := net.Pipe()
217-
defer clientConn1.Close()
218-
defer serverConn1.Close()
234+
takeoverClient, takeoverServer := net.Pipe()
235+
defer takeoverClient.Close()
236+
defer takeoverServer.Close()
219237

220238
withTakeover := newConn(connConfig{
221-
rwc: clientConn1,
239+
rwc: takeoverClient,
222240
client: true,
223241
copts: CompressionContextTakeover.opts(),
224242
flateThreshold: 64,
225-
br: bufio.NewReader(clientConn1),
226-
bw: bufio.NewWriterSize(clientConn1, 4096),
243+
br: bufio.NewReader(takeoverClient),
244+
bw: bufio.NewWriterSize(takeoverClient, 4096),
227245
})
228246

229-
// Test without context takeover
230-
clientConn2, serverConn2 := net.Pipe()
231-
defer clientConn2.Close()
232-
defer serverConn2.Close()
247+
noTakeoverClient, noTakeoverServer := net.Pipe()
248+
defer noTakeoverClient.Close()
249+
defer noTakeoverServer.Close()
233250

234251
withoutTakeover := newConn(connConfig{
235-
rwc: clientConn2,
252+
rwc: noTakeoverClient,
236253
client: true,
237254
copts: CompressionNoContextTakeover.opts(),
238255
flateThreshold: 64,
239-
br: bufio.NewReader(clientConn2),
240-
bw: bufio.NewWriterSize(clientConn2, 4096),
256+
br: bufio.NewReader(noTakeoverClient),
257+
bw: bufio.NewWriterSize(noTakeoverClient, 4096),
241258
})
242259

243260
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@@ -246,8 +263,8 @@ func TestCompressionDictionaryPreserved(t *testing.T) {
246263
// Capture compressed sizes for both modes
247264
var withTakeoverSizes, withoutTakeoverSizes []int64
248265

249-
reader1 := bufio.NewReader(serverConn1)
250-
reader2 := bufio.NewReader(serverConn2)
266+
reader1 := bufio.NewReader(takeoverServer)
267+
reader2 := bufio.NewReader(noTakeoverServer)
251268
readBuf := make([]byte, 8)
252269

253270
// Send 3 identical messages each

write.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,13 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
106106
return 0, err
107107
}
108108

109-
if !c.flate() {
110-
defer c.msgWriter.mu.unlock()
111-
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
112-
}
113-
114-
if len(p) < c.flateThreshold {
109+
if !c.flate() || len(p) < c.flateThreshold {
115110
defer c.msgWriter.mu.unlock()
116111
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
117112
}
118113

119114
defer c.msgWriter.mu.unlock()
120-
return c.msgWriter.writeFull(ctx, p)
115+
return c.msgWriter.writeCompressedFrame(ctx, p)
121116
}
122117

123118
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
@@ -143,8 +138,18 @@ func (mw *msgWriter) putFlateWriter() {
143138
}
144139
}
145140

146-
// writeFull compresses and writes p as a single frame.
147-
func (mw *msgWriter) writeFull(ctx context.Context, p []byte) (int, error) {
141+
// writeCompressedFrame compresses and writes p as a single frame.
142+
func (mw *msgWriter) writeCompressedFrame(ctx context.Context, p []byte) (int, error) {
143+
err := mw.writeMu.lock(mw.ctx)
144+
if err != nil {
145+
return 0, fmt.Errorf("failed to write: %w", err)
146+
}
147+
defer mw.writeMu.unlock()
148+
149+
if mw.closed {
150+
return 0, errors.New("cannot use closed writer")
151+
}
152+
148153
mw.ensureFlate()
149154

150155
buf := bpool.Get()
@@ -158,7 +163,7 @@ func (mw *msgWriter) writeFull(ctx context.Context, p []byte) (int, error) {
158163
mw.trimWriter.w = origWriter
159164
}()
160165

161-
_, err := mw.flateWriter.Write(p)
166+
_, err = mw.flateWriter.Write(p)
162167
if err != nil {
163168
return 0, fmt.Errorf("failed to compress: %w", err)
164169
}

0 commit comments

Comments
 (0)