Skip to content

Commit 29ea7cc

Browse files
committed
Added support for inner joins and a test
Signed-off-by: Zach Musgrave <zach@liquidata.co>
1 parent dcf8633 commit 29ea7cc

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

engine_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,6 +2192,19 @@ func TestInsertInto(t *testing.T) {
21922192
{int64(3), "third row"},
21932193
},
21942194
},
2195+
{
2196+
"INSERT INTO mytable (s,i) SELECT concat(m.s, o.s2), m.i from othertable o join mytable m on m.i=o.i2",
2197+
[]sql.Row{{int64(3)}},
2198+
"SELECT * FROM mytable order by i,s",
2199+
[]sql.Row{
2200+
{int64(1), "first row"},
2201+
{int64(1), "first rowthird"},
2202+
{int64(2), "second row"},
2203+
{int64(2), "second rowsecond"},
2204+
{int64(3), "third row"},
2205+
{int64(3), "third rowfirst"},
2206+
},
2207+
},
21952208
{
21962209
"INSERT INTO mytable (i,s) SELECT (i + 10.0) / 10.0 + 10, concat(s, ' new') from mytable",
21972210
[]sql.Row{{int64(3)}},
@@ -2272,6 +2285,10 @@ func TestInsertIntoErrors(t *testing.T) {
22722285
"column count mismatch in select",
22732286
"INSERT INTO mytable select s from othertable",
22742287
},
2288+
{
2289+
"column count mismatch in join select",
2290+
"INSERT INTO mytable (s,i) SELECT * from othertable o join mytable m on m.i=o.i2",
2291+
},
22752292
}
22762293

22772294
for _, expectedFailure := range expectedFailures {

sql/plan/insert.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
9292
p.Columns[i] = f.Name
9393
}
9494
} else {
95-
err = p.validateColumns(ctx, dstSchema)
95+
err = p.validateColumns(dstSchema)
9696
if err != nil {
9797
return 0, err
9898
}
@@ -139,7 +139,7 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
139139
return i, err
140140
}
141141

142-
err = p.validateNullability(ctx, dstSchema, row)
142+
err = p.validateNullability(dstSchema, row)
143143
if err != nil {
144144
_ = iter.Close()
145145
return i, err
@@ -227,13 +227,15 @@ func (p *InsertInto) validateValueCount(ctx *sql.Context) error {
227227
return p.assertSchemasMatch(node.Schema())
228228
case *Project:
229229
return p.assertSchemasMatch(node.Schema())
230+
case *InnerJoin:
231+
return p.assertSchemasMatch(node.Schema())
230232
default:
231233
return ErrInsertIntoUnsupportedValues.New(node)
232234
}
233235
return nil
234236
}
235237

236-
func (p *InsertInto) validateColumns(ctx *sql.Context, dstSchema sql.Schema) error {
238+
func (p *InsertInto) validateColumns(dstSchema sql.Schema) error {
237239
dstColNames := make(map[string]struct{})
238240
for _, dstCol := range dstSchema {
239241
dstColNames[dstCol.Name] = struct{}{}
@@ -252,7 +254,7 @@ func (p *InsertInto) validateColumns(ctx *sql.Context, dstSchema sql.Schema) err
252254
return nil
253255
}
254256

255-
func (p *InsertInto) validateNullability(ctx *sql.Context, dstSchema sql.Schema, row sql.Row) error {
257+
func (p *InsertInto) validateNullability(dstSchema sql.Schema, row sql.Row) error {
256258
for i, col := range dstSchema {
257259
if !col.Nullable && row[i] == nil {
258260
return ErrInsertIntoNonNullableProvidedNull.New(col.Name)
@@ -262,7 +264,7 @@ func (p *InsertInto) validateNullability(ctx *sql.Context, dstSchema sql.Schema,
262264
}
263265

264266
func (p *InsertInto) assertSchemasMatch(schema sql.Schema) error {
265-
if len(p.Schema()) != len(p.Schema()) {
267+
if len(p.Columns) != len(schema) {
266268
return ErrInsertIntoMismatchValueCount.New()
267269
}
268270
return nil

0 commit comments

Comments
 (0)