Skip to content

Commit a1483a5

Browse files
authored
Preserve type for binary protocol parameters in prepared statements (#1081)
1 parent 884cd37 commit a1483a5

File tree

4 files changed

+63
-1
lines changed

4 files changed

+63
-1
lines changed

client/stmt.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ func (s *Stmt) write(args ...interface{}) error {
156156
case []byte:
157157
paramTypes[i] = []byte{mysql.MYSQL_TYPE_STRING}
158158
paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...)
159+
case mysql.TypedBytes:
160+
paramTypes[i] = []byte{v.Type}
161+
paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v.Bytes))), v.Bytes...)
159162
case json.RawMessage:
160163
paramTypes[i] = []byte{mysql.MYSQL_TYPE_STRING}
161164
paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...)

mysql/typed_bytes.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package mysql
2+
3+
// TypedBytes preserves the original MySQL type alongside the raw bytes
4+
// for binary protocol parameters that are length-encoded.
5+
type TypedBytes struct {
6+
Type byte // Original MySQL type
7+
Bytes []byte // Raw bytes
8+
}

server/stmt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte)
300300
}
301301

302302
if !isNull {
303-
args[i] = v
303+
args[i] = mysql.TypedBytes{Type: tp, Bytes: v}
304304
continue
305305
} else {
306306
args[i] = nil

server/stmt_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,54 @@ func TestStmtPrepareWithPreparedStmt(t *testing.T) {
9797
require.NoError(t, err)
9898
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, columnFields[0].Type)
9999
}
100+
101+
func TestBindStmtArgsTypedBytes(t *testing.T) {
102+
testcases := []struct {
103+
name string
104+
paramType byte
105+
paramValue []byte
106+
expectType byte
107+
expectBytes []byte
108+
}{
109+
{
110+
name: "DATETIME",
111+
paramType: mysql.MYSQL_TYPE_DATETIME,
112+
paramValue: []byte{0x07, 0xe8, 0x07, 0x06, 0x0f, 0x0e, 0x1e, 0x2d},
113+
expectType: mysql.MYSQL_TYPE_DATETIME,
114+
expectBytes: []byte{0xe8, 0x07, 0x06, 0x0f, 0x0e, 0x1e, 0x2d},
115+
},
116+
{
117+
name: "VARCHAR",
118+
paramType: mysql.MYSQL_TYPE_VARCHAR,
119+
paramValue: []byte{0x05, 'h', 'e', 'l', 'l', 'o'},
120+
expectType: mysql.MYSQL_TYPE_VARCHAR,
121+
expectBytes: []byte("hello"),
122+
},
123+
{
124+
name: "BLOB",
125+
paramType: mysql.MYSQL_TYPE_BLOB,
126+
paramValue: []byte{0x04, 0x00, 0x01, 0x02, 0x03},
127+
expectType: mysql.MYSQL_TYPE_BLOB,
128+
expectBytes: []byte{0x00, 0x01, 0x02, 0x03},
129+
},
130+
}
131+
132+
for _, tc := range testcases {
133+
t.Run(tc.name, func(t *testing.T) {
134+
c := &Conn{}
135+
s := &Stmt{Args: make([]interface{}, 1)}
136+
s.Params = 1
137+
138+
nullBitmap := []byte{0x00}
139+
paramTypes := []byte{tc.paramType, 0x00}
140+
141+
err := c.bindStmtArgs(s, nullBitmap, paramTypes, tc.paramValue)
142+
require.NoError(t, err)
143+
144+
tv, ok := s.Args[0].(mysql.TypedBytes)
145+
require.True(t, ok, "expected TypedBytes, got %T", s.Args[0])
146+
require.Equal(t, tc.expectType, tv.Type)
147+
require.Equal(t, tc.expectBytes, tv.Bytes)
148+
})
149+
}
150+
}

0 commit comments

Comments
 (0)