99package mysql
1010
1111import (
12- "database/sql/driver"
1312 "fmt"
1413 "io"
1514 "os"
2120 readerRegister map [string ]func () io.Reader
2221)
2322
24- func init () {
25- fileRegister = make (map [string ]bool )
26- readerRegister = make (map [string ]func () io.Reader )
27- }
28-
2923// RegisterLocalFile adds the given file to the file whitelist,
3024// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
3125// Alternatively you can allow the use of all local files with
@@ -38,6 +32,11 @@ func init() {
3832// ...
3933//
4034func RegisterLocalFile (filePath string ) {
35+ // lazy map init
36+ if fileRegister == nil {
37+ fileRegister = make (map [string ]bool )
38+ }
39+
4140 fileRegister [strings .Trim (filePath , `"` )] = true
4241}
4342
@@ -62,6 +61,11 @@ func DeregisterLocalFile(filePath string) {
6261// ...
6362//
6463func RegisterReaderHandler (name string , handler func () io.Reader ) {
64+ // lazy map init
65+ if readerRegister == nil {
66+ readerRegister = make (map [string ]func () io.Reader )
67+ }
68+
6569 readerRegister [name ] = handler
6670}
6771
@@ -71,71 +75,81 @@ func DeregisterReaderHandler(name string) {
7175 delete (readerRegister , name )
7276}
7377
78+ func deferredClose (err * error , closer io.Closer ) {
79+ closeErr := closer .Close ()
80+ if * err == nil {
81+ * err = closeErr
82+ }
83+ }
84+
7485func (mc * mysqlConn ) handleInFileRequest (name string ) (err error ) {
7586 var rdr io.Reader
76- data := make ( []byte , 4 + mc . maxWriteSize )
87+ var data []byte
7788
7889 if strings .HasPrefix (name , "Reader::" ) { // io.Reader
7990 name = name [8 :]
80- handler , inMap := readerRegister [name ]
81- if handler != nil {
91+ if handler , inMap := readerRegister [name ]; inMap {
8292 rdr = handler ()
83- }
84- if rdr == nil {
85- if ! inMap {
86- err = fmt .Errorf ("Reader '%s' is not registered" , name )
93+ if rdr != nil {
94+ data = make ([]byte , 4 + mc .maxWriteSize )
95+
96+ if cl , ok := rdr .(io.Closer ); ok {
97+ defer deferredClose (& err , cl )
98+ }
8799 } else {
88100 err = fmt .Errorf ("Reader '%s' is <nil>" , name )
89101 }
102+ } else {
103+ err = fmt .Errorf ("Reader '%s' is not registered" , name )
90104 }
91105 } else { // File
92106 name = strings .Trim (name , `"` )
93107 if mc .cfg .allowAllFiles || fileRegister [name ] {
94- rdr , err = os .Open (name )
108+ var file * os.File
109+ var fi os.FileInfo
110+
111+ if file , err = os .Open (name ); err == nil {
112+ defer deferredClose (& err , file )
113+
114+ // get file size
115+ if fi , err = file .Stat (); err == nil {
116+ rdr = file
117+ if fileSize := int (fi .Size ()); fileSize <= mc .maxWriteSize {
118+ data = make ([]byte , 4 + fileSize )
119+ } else if fileSize <= mc .maxPacketAllowed {
120+ data = make ([]byte , 4 + mc .maxWriteSize )
121+ } else {
122+ err = fmt .Errorf ("Local File '%s' too large: Size: %d, Max: %d" , name , fileSize , mc .maxPacketAllowed )
123+ }
124+ }
125+ }
95126 } else {
96127 err = fmt .Errorf ("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" , name )
97128 }
98129 }
99130
100- if rdc , ok := rdr .(io.ReadCloser ); ok {
101- defer func () {
102- if err == nil {
103- err = rdc .Close ()
104- } else {
105- rdc .Close ()
106- }
107- }()
108- }
109-
110131 // send content packets
111- var ioErr error
112132 if err == nil {
113133 var n int
114- for err == nil && ioErr == nil {
134+ for err == nil {
115135 n , err = rdr .Read (data [4 :])
116136 if n > 0 {
117- ioErr = mc .writePacket (data [:4 + n ])
137+ if ioErr := mc .writePacket (data [:4 + n ]); ioErr != nil {
138+ return ioErr
139+ }
118140 }
119141 }
120142 if err == io .EOF {
121143 err = nil
122144 }
123- if ioErr != nil {
124- errLog .Print (ioErr .Error ())
125- return driver .ErrBadConn
126- }
127145 }
128146
129147 // send empty packet (termination)
130- ioErr = mc .writePacket ([]byte {
131- 0x00 ,
132- 0x00 ,
133- 0x00 ,
134- mc .sequence ,
135- })
136- if ioErr != nil {
137- errLog .Print (ioErr .Error ())
138- return driver .ErrBadConn
148+ if data == nil {
149+ data = make ([]byte , 4 )
150+ }
151+ if ioErr := mc .writePacket (data [:4 ]); ioErr != nil {
152+ return ioErr
139153 }
140154
141155 // read OK packet
0 commit comments