@@ -750,6 +750,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
750750 )
751751 }
752752
753+ const minPktLen = 4 + 1 + 4 + 1 + 4
753754 mc := stmt .mc
754755
755756 // Reset packet-sequence
@@ -758,7 +759,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
758759 var data []byte
759760
760761 if len (args ) == 0 {
761- data = mc .buf .takeBuffer (4 + 1 + 4 + 1 + 4 )
762+ data = mc .buf .takeBuffer (minPktLen )
762763 } else {
763764 data = mc .buf .takeCompleteBuffer ()
764765 }
@@ -787,34 +788,50 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
787788 data [13 ] = 0x00
788789
789790 if len (args ) > 0 {
790- // NULL-bitmap [(len(args)+7)/8 bytes]
791- nullMask := uint64 (0 )
792-
793- pos := 4 + 1 + 4 + 1 + 4 + ((len (args ) + 7 ) >> 3 )
791+ pos := minPktLen
792+
793+ var nullMask []byte
794+ if maskLen , typesLen := (len (args )+ 7 )/ 8 , 1 + 2 * len (args ); pos + maskLen + typesLen >= len (data ) {
795+ // buffer has to be extended but we don't know by how much so
796+ // we depend on append after all data with known sizes fit.
797+ // We stop at that because we deal with a lot of columns here
798+ // which makes the required allocation size hard to guess.
799+ tmp := make ([]byte , pos + maskLen + typesLen )
800+ copy (tmp [:pos ], data [:pos ])
801+ data = tmp
802+ nullMask = data [pos : pos + maskLen ]
803+ pos += maskLen
804+ } else {
805+ nullMask = data [pos : pos + maskLen ]
806+ for i := 0 ; i < maskLen ; i ++ {
807+ nullMask [i ] = 0
808+ }
809+ pos += maskLen
810+ }
794811
795812 // newParameterBoundFlag 1 [1 byte]
796813 data [pos ] = 0x01
797814 pos ++
798815
799816 // type of each parameter [len(args)*2 bytes]
800817 paramTypes := data [pos :]
801- pos += ( len (args ) << 1 )
818+ pos += len (args ) * 2
802819
803820 // value of each parameter [n bytes]
804821 paramValues := data [pos :pos ]
805822 valuesCap := cap (paramValues )
806823
807- for i := range args {
824+ for i , arg := range args {
808825 // build NULL-bitmap
809- if args [ i ] == nil {
810- nullMask |= 1 << uint (i )
826+ if arg == nil {
827+ nullMask [ i / 8 ] |= 1 << ( uint (i ) & 7 )
811828 paramTypes [i + i ] = fieldTypeNULL
812829 paramTypes [i + i + 1 ] = 0x00
813830 continue
814831 }
815832
816833 // cache types and values
817- switch v := args [ i ] .(type ) {
834+ switch v := arg .(type ) {
818835 case int64 :
819836 paramTypes [i + i ] = fieldTypeLongLong
820837 paramTypes [i + i + 1 ] = 0x00
@@ -877,7 +894,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
877894 }
878895
879896 // Handle []byte(nil) as a NULL value
880- nullMask |= 1 << uint (i )
897+ nullMask [ i / 8 ] |= 1 << ( uint (i ) & 7 )
881898 paramTypes [i + i ] = fieldTypeNULL
882899 paramTypes [i + i + 1 ] = 0x00
883900
@@ -913,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
913930 paramValues = append (paramValues , val ... )
914931
915932 default :
916- return fmt .Errorf ("Can't convert type: %T" , args [ i ] )
933+ return fmt .Errorf ("Can't convert type: %T" , arg )
917934 }
918935 }
919936
@@ -926,11 +943,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
926943
927944 pos += len (paramValues )
928945 data = data [:pos ]
929-
930- // Convert nullMask to bytes
931- for i , max := 0 , (stmt .paramCount + 7 )>> 3 ; i < max ; i ++ {
932- data [i + 14 ] = byte (nullMask >> uint (i << 3 ))
933- }
934946 }
935947
936948 return mc .writePacket (data )
0 commit comments