Skip to content

Commit 9be9deb

Browse files
committed
Better type checking / conversion / error messages for insert. Introduced a Zero method for sql.Type which lets us compare types for compatibility by giving a default value for conversion.
Signed-off-by: Zach Musgrave <zach@liquidata.co>
1 parent 17eca8e commit 9be9deb

File tree

3 files changed

+151
-33
lines changed

3 files changed

+151
-33
lines changed

engine_test.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2054,9 +2054,9 @@ func TestInsertInto(t *testing.T) {
20542054
[]sql.Row{{
20552055
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64),
20562056
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
2057-
float64(math.MaxFloat32), float64(math.MaxFloat64),
2057+
float32(math.MaxFloat32), float64(math.MaxFloat64),
20582058
timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"),
2059-
"random text", true, `{"key":"value"}`, "blobdata",
2059+
"random text", true, ([]byte)(`{"key":"value"}`), ([]byte)("blobdata"),
20602060
}},
20612061
},
20622062
{
@@ -2072,9 +2072,9 @@ func TestInsertInto(t *testing.T) {
20722072
[]sql.Row{{
20732073
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64),
20742074
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
2075-
float64(math.MaxFloat32), float64(math.MaxFloat64),
2075+
float32(math.MaxFloat32), float64(math.MaxFloat64),
20762076
timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"),
2077-
"random text", true, `{"key":"value"}`, "blobdata",
2077+
"random text", true, ([]byte)(`{"key":"value"}`), ([]byte)("blobdata"),
20782078
}},
20792079
},
20802080
{
@@ -2090,9 +2090,9 @@ func TestInsertInto(t *testing.T) {
20902090
[]sql.Row{{
20912091
int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
20922092
uint8(0), uint16(0), uint32(0), uint64(0),
2093-
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
2093+
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
20942094
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
2095-
"", false, ``, "",
2095+
"", false, ([]byte)(`""`), ([]byte)(""),
20962096
}},
20972097
},
20982098
{
@@ -2108,9 +2108,9 @@ func TestInsertInto(t *testing.T) {
21082108
[]sql.Row{{
21092109
int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
21102110
uint8(0), uint16(0), uint32(0), uint64(0),
2111-
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
2111+
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
21122112
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
2113-
"", false, ``, "",
2113+
"", false, ([]byte)(`""`), ([]byte)(""),
21142114
}},
21152115
},
21162116
{
@@ -3070,7 +3070,8 @@ func testQueryWithContext(ctx *sql.Context, t *testing.T, e *sqle.Engine, q stri
30703070
rows, err := sql.RowIterToRows(iter)
30713071
require.NoError(err)
30723072

3073-
if orderBy {
3073+
// .Equal gives better error messages than .ElementsMatch, so use it when possible
3074+
if orderBy || len(rows) == 1 {
30743075
require.Equal(expected, rows)
30753076
} else {
30763077
require.ElementsMatch(expected, rows)

sql/plan/insert.go

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@ var ErrInsertIntoDuplicateColumn = errors.NewKind("duplicate column name %v")
1717
var ErrInsertIntoNonexistentColumn = errors.NewKind("invalid column name %v")
1818
var ErrInsertIntoNonNullableDefaultNullColumn = errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null")
1919
var ErrInsertIntoNonNullableProvidedNull = errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null")
20+
var ErrInsertIntoIncompatibleTypes = errors.NewKind("cannot convert type %s to %s")
2021

2122
// InsertInto is a node describing the insertion into some table.
2223
type InsertInto struct {
2324
BinaryNode
24-
Columns []string
25-
IsReplace bool
25+
ColumnNames []string
26+
IsReplace bool
2627
}
2728

2829
// NewInsertInto creates an InsertInto node.
2930
func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string) *InsertInto {
3031
return &InsertInto{
31-
BinaryNode: BinaryNode{Left: dst, Right: src},
32-
Columns: cols,
33-
IsReplace: isReplace,
32+
BinaryNode: BinaryNode{Left: dst, Right: src},
33+
ColumnNames: cols,
34+
IsReplace: isReplace,
3435
}
3536
}
3637

@@ -83,13 +84,12 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
8384
}
8485

8586
dstSchema := p.Left.Schema()
86-
projExprs := make([]sql.Expression, len(dstSchema))
8787

8888
// If no columns are given, we assume the full schema in order
89-
if len(p.Columns) == 0 {
90-
p.Columns = make([]string, len(dstSchema))
89+
if len(p.ColumnNames) == 0 {
90+
p.ColumnNames = make([]string, len(dstSchema))
9191
for i, f := range dstSchema {
92-
p.Columns[i] = f.Name
92+
p.ColumnNames[i] = f.Name
9393
}
9494
} else {
9595
err = p.validateColumns(dstSchema)
@@ -103,9 +103,10 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
103103
return 0, err
104104
}
105105

106+
projExprs := make([]sql.Expression, len(dstSchema))
106107
for i, f := range dstSchema {
107108
found := false
108-
for j, col := range p.Columns {
109+
for j, col := range p.ColumnNames {
109110
if f.Name == col {
110111
projExprs[i] = expression.NewGetField(j, f.Type, f.Name, f.Nullable)
111112
found = true
@@ -121,9 +122,12 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
121122
}
122123
}
123124

124-
proj := NewProject(projExprs, p.Right)
125+
rowSource, err := p.rowSource(projExprs)
126+
if err != nil {
127+
return 0, err
128+
}
125129

126-
iter, err := proj.RowIter(ctx)
130+
iter, err := rowSource.RowIter(ctx)
127131
if err != nil {
128132
return 0, err
129133
}
@@ -145,11 +149,11 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
145149
return i, err
146150
}
147151

148-
// Convert integer values in row to specified type in schema
152+
// Convert values to the destination schema type
149153
for colIdx, oldValue := range row {
150154
dstColType := projExprs[colIdx].Type()
151155

152-
if sql.IsInteger(dstColType) && oldValue != nil {
156+
if oldValue != nil {
153157
newValue, err := dstColType.Convert(oldValue)
154158
if err != nil {
155159
return i, err
@@ -185,6 +189,20 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
185189
return i, nil
186190
}
187191

192+
func (p *InsertInto) rowSource(projExprs []sql.Expression) (sql.Node, error) {
193+
switch n := p.Right.(type) {
194+
case *Values:
195+
return NewProject(projExprs, n), nil
196+
case *ResolvedTable, *Project, *InnerJoin:
197+
if err := assertCompatibleSchemas(projExprs, n.Schema()); err != nil {
198+
return nil, err
199+
}
200+
return NewProject(projExprs, n), nil
201+
default:
202+
return nil, ErrInsertIntoUnsupportedValues.New(n)
203+
}
204+
}
205+
188206
// RowIter implements the Node interface.
189207
func (p *InsertInto) RowIter(ctx *sql.Context) (sql.RowIter, error) {
190208
n, err := p.Execute(ctx)
@@ -201,15 +219,15 @@ func (p *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) {
201219
return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2)
202220
}
203221

204-
return NewInsertInto(children[0], children[1], p.IsReplace, p.Columns), nil
222+
return NewInsertInto(children[0], children[1], p.IsReplace, p.ColumnNames), nil
205223
}
206224

207225
func (p InsertInto) String() string {
208226
pr := sql.NewTreePrinter()
209227
if p.IsReplace {
210-
_ = pr.WriteNode("Replace(%s)", strings.Join(p.Columns, ", "))
228+
_ = pr.WriteNode("Replace(%s)", strings.Join(p.ColumnNames, ", "))
211229
} else {
212-
_ = pr.WriteNode("Insert(%s)", strings.Join(p.Columns, ", "))
230+
_ = pr.WriteNode("Insert(%s)", strings.Join(p.ColumnNames, ", "))
213231
}
214232
_ = pr.WriteChildren(p.Left.String(), p.Right.String())
215233
return pr.String()
@@ -219,16 +237,16 @@ func (p *InsertInto) validateValueCount(ctx *sql.Context) error {
219237
switch node := p.Right.(type) {
220238
case *Values:
221239
for _, exprTuple := range node.ExpressionTuples {
222-
if len(exprTuple) != len(p.Columns) {
240+
if len(exprTuple) != len(p.ColumnNames) {
223241
return ErrInsertIntoMismatchValueCount.New()
224242
}
225243
}
226244
case *ResolvedTable:
227-
return p.assertSchemasMatch(node.Schema())
245+
return p.assertColumnCountsMatch(node.Schema())
228246
case *Project:
229-
return p.assertSchemasMatch(node.Schema())
247+
return p.assertColumnCountsMatch(node.Schema())
230248
case *InnerJoin:
231-
return p.assertSchemasMatch(node.Schema())
249+
return p.assertColumnCountsMatch(node.Schema())
232250
default:
233251
return ErrInsertIntoUnsupportedValues.New(node)
234252
}
@@ -241,7 +259,7 @@ func (p *InsertInto) validateColumns(dstSchema sql.Schema) error {
241259
dstColNames[dstCol.Name] = struct{}{}
242260
}
243261
columnNames := make(map[string]struct{})
244-
for _, columnName := range p.Columns {
262+
for _, columnName := range p.ColumnNames {
245263
if _, exists := dstColNames[columnName]; !exists {
246264
return ErrInsertIntoNonexistentColumn.New(columnName)
247265
}
@@ -263,9 +281,27 @@ func (p *InsertInto) validateNullability(dstSchema sql.Schema, row sql.Row) erro
263281
return nil
264282
}
265283

266-
func (p *InsertInto) assertSchemasMatch(schema sql.Schema) error {
267-
if len(p.Columns) != len(schema) {
284+
func (p *InsertInto) assertColumnCountsMatch(schema sql.Schema) error {
285+
if len(p.ColumnNames) != len(schema) {
268286
return ErrInsertIntoMismatchValueCount.New()
269287
}
270288
return nil
271289
}
290+
291+
func assertCompatibleSchemas(projExprs []sql.Expression, schema sql.Schema) error {
292+
for _, expr := range projExprs {
293+
switch e := expr.(type) {
294+
case *expression.Literal:
295+
continue
296+
case *expression.GetField:
297+
otherCol := schema[e.Index()]
298+
_, err := otherCol.Type.Convert(expr.Type().Zero())
299+
if err != nil {
300+
return ErrInsertIntoIncompatibleTypes.New(otherCol.Type.String(), expr.Type().String())
301+
}
302+
default:
303+
return ErrInsertIntoUnsupportedValues.New(expr)
304+
}
305+
}
306+
return nil
307+
}

0 commit comments

Comments
 (0)