@@ -17,20 +17,21 @@ var ErrInsertIntoDuplicateColumn = errors.NewKind("duplicate column name %v")
1717var ErrInsertIntoNonexistentColumn = errors .NewKind ("invalid column name %v" )
1818var ErrInsertIntoNonNullableDefaultNullColumn = errors .NewKind ("column name '%v' is non-nullable but attempted to set default value of null" )
1919var ErrInsertIntoNonNullableProvidedNull = errors .NewKind ("column name '%v' is non-nullable but attempted to set a value of null" )
20+ var ErrInsertIntoIncompatibleTypes = errors .NewKind ("cannot convert type %s to %s" )
2021
2122// InsertInto is a node describing the insertion into some table.
2223type InsertInto struct {
2324 BinaryNode
24- Columns []string
25- IsReplace bool
25+ ColumnNames []string
26+ IsReplace bool
2627}
2728
2829// NewInsertInto creates an InsertInto node.
2930func NewInsertInto (dst , src sql.Node , isReplace bool , cols []string ) * InsertInto {
3031 return & InsertInto {
31- BinaryNode : BinaryNode {Left : dst , Right : src },
32- Columns : cols ,
33- IsReplace : isReplace ,
32+ BinaryNode : BinaryNode {Left : dst , Right : src },
33+ ColumnNames : cols ,
34+ IsReplace : isReplace ,
3435 }
3536}
3637
@@ -83,13 +84,12 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
8384 }
8485
8586 dstSchema := p .Left .Schema ()
86- projExprs := make ([]sql.Expression , len (dstSchema ))
8787
8888 // If no columns are given, we assume the full schema in order
89- if len (p .Columns ) == 0 {
90- p .Columns = make ([]string , len (dstSchema ))
89+ if len (p .ColumnNames ) == 0 {
90+ p .ColumnNames = make ([]string , len (dstSchema ))
9191 for i , f := range dstSchema {
92- p .Columns [i ] = f .Name
92+ p .ColumnNames [i ] = f .Name
9393 }
9494 } else {
9595 err = p .validateColumns (dstSchema )
@@ -103,9 +103,10 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
103103 return 0 , err
104104 }
105105
106+ projExprs := make ([]sql.Expression , len (dstSchema ))
106107 for i , f := range dstSchema {
107108 found := false
108- for j , col := range p .Columns {
109+ for j , col := range p .ColumnNames {
109110 if f .Name == col {
110111 projExprs [i ] = expression .NewGetField (j , f .Type , f .Name , f .Nullable )
111112 found = true
@@ -121,9 +122,12 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
121122 }
122123 }
123124
124- proj := NewProject (projExprs , p .Right )
125+ rowSource , err := p .rowSource (projExprs )
126+ if err != nil {
127+ return 0 , err
128+ }
125129
126- iter , err := proj .RowIter (ctx )
130+ iter , err := rowSource .RowIter (ctx )
127131 if err != nil {
128132 return 0 , err
129133 }
@@ -145,11 +149,11 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
145149 return i , err
146150 }
147151
148- // Convert integer values in row to specified type in schema
152+ // Convert values to the destination schema type
149153 for colIdx , oldValue := range row {
150154 dstColType := projExprs [colIdx ].Type ()
151155
152- if sql . IsInteger ( dstColType ) && oldValue != nil {
156+ if oldValue != nil {
153157 newValue , err := dstColType .Convert (oldValue )
154158 if err != nil {
155159 return i , err
@@ -185,6 +189,20 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
185189 return i , nil
186190}
187191
192+ func (p * InsertInto ) rowSource (projExprs []sql.Expression ) (sql.Node , error ) {
193+ switch n := p .Right .(type ) {
194+ case * Values :
195+ return NewProject (projExprs , n ), nil
196+ case * ResolvedTable , * Project , * InnerJoin :
197+ if err := assertCompatibleSchemas (projExprs , n .Schema ()); err != nil {
198+ return nil , err
199+ }
200+ return NewProject (projExprs , n ), nil
201+ default :
202+ return nil , ErrInsertIntoUnsupportedValues .New (n )
203+ }
204+ }
205+
188206// RowIter implements the Node interface.
189207func (p * InsertInto ) RowIter (ctx * sql.Context ) (sql.RowIter , error ) {
190208 n , err := p .Execute (ctx )
@@ -201,15 +219,15 @@ func (p *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) {
201219 return nil , sql .ErrInvalidChildrenNumber .New (p , len (children ), 2 )
202220 }
203221
204- return NewInsertInto (children [0 ], children [1 ], p .IsReplace , p .Columns ), nil
222+ return NewInsertInto (children [0 ], children [1 ], p .IsReplace , p .ColumnNames ), nil
205223}
206224
207225func (p InsertInto ) String () string {
208226 pr := sql .NewTreePrinter ()
209227 if p .IsReplace {
210- _ = pr .WriteNode ("Replace(%s)" , strings .Join (p .Columns , ", " ))
228+ _ = pr .WriteNode ("Replace(%s)" , strings .Join (p .ColumnNames , ", " ))
211229 } else {
212- _ = pr .WriteNode ("Insert(%s)" , strings .Join (p .Columns , ", " ))
230+ _ = pr .WriteNode ("Insert(%s)" , strings .Join (p .ColumnNames , ", " ))
213231 }
214232 _ = pr .WriteChildren (p .Left .String (), p .Right .String ())
215233 return pr .String ()
@@ -219,16 +237,16 @@ func (p *InsertInto) validateValueCount(ctx *sql.Context) error {
219237 switch node := p .Right .(type ) {
220238 case * Values :
221239 for _ , exprTuple := range node .ExpressionTuples {
222- if len (exprTuple ) != len (p .Columns ) {
240+ if len (exprTuple ) != len (p .ColumnNames ) {
223241 return ErrInsertIntoMismatchValueCount .New ()
224242 }
225243 }
226244 case * ResolvedTable :
227- return p .assertSchemasMatch (node .Schema ())
245+ return p .assertColumnCountsMatch (node .Schema ())
228246 case * Project :
229- return p .assertSchemasMatch (node .Schema ())
247+ return p .assertColumnCountsMatch (node .Schema ())
230248 case * InnerJoin :
231- return p .assertSchemasMatch (node .Schema ())
249+ return p .assertColumnCountsMatch (node .Schema ())
232250 default :
233251 return ErrInsertIntoUnsupportedValues .New (node )
234252 }
@@ -241,7 +259,7 @@ func (p *InsertInto) validateColumns(dstSchema sql.Schema) error {
241259 dstColNames [dstCol .Name ] = struct {}{}
242260 }
243261 columnNames := make (map [string ]struct {})
244- for _ , columnName := range p .Columns {
262+ for _ , columnName := range p .ColumnNames {
245263 if _ , exists := dstColNames [columnName ]; ! exists {
246264 return ErrInsertIntoNonexistentColumn .New (columnName )
247265 }
@@ -263,9 +281,27 @@ func (p *InsertInto) validateNullability(dstSchema sql.Schema, row sql.Row) erro
263281 return nil
264282}
265283
266- func (p * InsertInto ) assertSchemasMatch (schema sql.Schema ) error {
267- if len (p .Columns ) != len (schema ) {
284+ func (p * InsertInto ) assertColumnCountsMatch (schema sql.Schema ) error {
285+ if len (p .ColumnNames ) != len (schema ) {
268286 return ErrInsertIntoMismatchValueCount .New ()
269287 }
270288 return nil
271289}
290+
291+ func assertCompatibleSchemas (projExprs []sql.Expression , schema sql.Schema ) error {
292+ for _ , expr := range projExprs {
293+ switch e := expr .(type ) {
294+ case * expression.Literal :
295+ continue
296+ case * expression.GetField :
297+ otherCol := schema [e .Index ()]
298+ _ , err := otherCol .Type .Convert (expr .Type ().Zero ())
299+ if err != nil {
300+ return ErrInsertIntoIncompatibleTypes .New (otherCol .Type .String (), expr .Type ().String ())
301+ }
302+ default :
303+ return ErrInsertIntoUnsupportedValues .New (expr )
304+ }
305+ }
306+ return nil
307+ }
0 commit comments