diff --git a/_example/main.go b/_example/main.go index 1244c93239..0319f310fa 100644 --- a/_example/main.go +++ b/_example/main.go @@ -89,6 +89,7 @@ func createTestDatabase() *memory.DbProvider { pro := memory.NewDBProvider(db) session := memory.NewSession(sql.NewBaseSession(), pro) ctx := sql.NewContext(context.Background(), sql.WithSession(session)) + ctx.Session = session table := memory.NewTable(db, tableName, sql.NewPrimaryKeySchema(sql.Schema{ {Name: "name", Type: types.Text, Nullable: false, Source: tableName, PrimaryKey: true}, diff --git a/driver/driver.go b/driver/driver.go index 05ac3d093c..a3232c7c3c 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -162,7 +162,7 @@ func (d *Driver) OpenConnector(dsn string) (driver.Connector, error) { d.mu.Lock() db, ok := d.dbs[serverName] if !ok { - anlz := analyzer.NewDefaultWithVersion(pro) + anlz := analyzer.NewDefault(pro) engine := sqle.New(anlz, nil) db = &dbConn{engine: engine} d.dbs[serverName] = db diff --git a/engine.go b/engine.go index 1b9895e9f6..508f7021d5 100644 --- a/engine.go +++ b/engine.go @@ -183,6 +183,10 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine { } a.Catalog.RegisterFunction(emptyCtx, function.GetLockingFuncs(ls)...) + parser := sql.DefaultMySQLParser + if a.Overrides.Builder.Parser != nil { + parser = a.Overrides.Builder.Parser + } ret := &Engine{ Analyzer: a, @@ -194,7 +198,7 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine { PreparedDataCache: NewPreparedDataCache(), mu: &sync.Mutex{}, EventScheduler: nil, - Parser: sql.GlobalParser, + Parser: parser, } ret.ReadOnly.Store(cfg.IsReadOnly) a.Runner = ret @@ -203,8 +207,7 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine { // NewDefault creates a new default Engine. func NewDefault(pro sql.DatabaseProvider) *Engine { - a := analyzer.NewDefaultWithVersion(pro) - return New(a, nil) + return New(analyzer.NewDefault(pro), nil) } // AnalyzeQuery analyzes a query and returns its sql.Node @@ -212,7 +215,7 @@ func (e *Engine) AnalyzeQuery( ctx *sql.Context, query string, ) (sql.Node, error) { - binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.EventScheduler, e.Parser) + binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.EventScheduler) parsed, _, _, qFlags, err := binder.Parse(query, nil, false) if err != nil { return nil, err @@ -243,7 +246,7 @@ func (e *Engine) PrepareParsedQuery( // Make sure there is an active transaction if one hasn't been started yet e.beginTransaction(ctx) - binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.EventScheduler, e.Parser) + binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.EventScheduler) node, _, err := binder.BindOnly(stmt, query, nil) if err != nil { @@ -526,7 +529,7 @@ func (e *Engine) BoundQueryPlan(ctx *sql.Context, query string, parsed sqlparser query = sql.RemoveSpaceAndDelimiter(query, ';') - binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.EventScheduler, e.Parser) + binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.EventScheduler) binder.SetBindings(bindings) // Begin a transaction if necessary (no-op if one is in flight) @@ -580,7 +583,7 @@ func (e *Engine) preparedStatement(ctx *sql.Context, query string, parsed sqlpar preparedAst, preparedDataFound = e.PreparedDataCache.GetCachedStmt(ctx.Session.ID(), query) } - binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.EventScheduler, e.Parser) + binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.EventScheduler) if preparedDataFound { parsed = preparedAst binder.SetBindings(bindings) diff --git a/engine_test.go b/engine_test.go index 
4621bceb83..bb52a09697 100755 --- a/engine_test.go +++ b/engine_test.go @@ -231,7 +231,7 @@ func TestTrackProcess(t *testing.T) { _, ok = rhs.Table.(*plan.ProcessTable) require.True(ok) - iter, err := rowexec.DefaultBuilder.Build(ctx, result, nil) + iter, err := rowexec.NewBuilder(nil, sql.EngineOverrides{}).Build(ctx, result, nil) require.NoError(err) iter, _, err = rowexec.FinalizeIters(ctx, result, nil, iter) require.NoError(err) diff --git a/enginetest/engine_only_test.go b/enginetest/engine_only_test.go index 86aa287be1..99ad35dc53 100644 --- a/enginetest/engine_only_test.go +++ b/enginetest/engine_only_test.go @@ -272,7 +272,7 @@ func TestShowProcessList(t *testing.T) { n := plan.NewShowProcessList() - iter, err := rowexec.DefaultBuilder.Build(ctx, n, nil) + iter, err := rowexec.NewBuilder(nil, sql.EngineOverrides{}).Build(ctx, n, nil) require.NoError(err) rows, err := sql.RowIterToRows(ctx, iter) require.NoError(err) @@ -327,9 +327,9 @@ func TestLockTables(t *testing.T) { {plan.NewResolvedTable(t1, nil, nil), true}, {plan.NewResolvedTable(t2, nil, nil), false}, }) - node.Catalog = analyzer.NewCatalog(sql.NewDatabaseProvider()) + node.Catalog = analyzer.NewCatalog(sql.NewDatabaseProvider(), sql.EngineOverrides{}) - _, err := rowexec.DefaultBuilder.Build(sql.NewEmptyContext(), node, nil) + _, err := rowexec.NewBuilder(nil, sql.EngineOverrides{}).Build(sql.NewEmptyContext(), node, nil) require.NoError(err) @@ -350,9 +350,9 @@ func TestUnlockTables(t *testing.T) { db.AddTable("bar", t2) db.AddTable("baz", t3) - catalog := analyzer.NewCatalog(sql.NewDatabaseProvider(db)) + catalog := analyzer.NewCatalog(sql.NewDatabaseProvider(db), sql.EngineOverrides{}) - ctx := sql.NewContext(context.Background()) + ctx := sql.NewEmptyContext() ctx.SetCurrentDatabase("db") catalog.LockTable(ctx, "foo") catalog.LockTable(ctx, "bar") @@ -438,7 +438,7 @@ func TestAnalyzer_Exp(t *testing.T) { require.NoError(t, err) ctx := enginetest.NewContext(harness) - b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, e.EngineEventScheduler(), nil) + b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, e.EngineEventScheduler()) parsed, _, _, _, err := b.Parse(tt.query, nil, false) require.NoError(t, err) @@ -591,7 +591,7 @@ func TestTableFunctions(t *testing.T) { engine := enginetest.NewEngineWithProvider(t, harness, testDatabaseProvider) harness = harness.WithProvider(engine.Analyzer.Catalog.DbProvider) - engine.EngineAnalyzer().ExecBuilder = rowexec.DefaultBuilder + engine.EngineAnalyzer().ExecBuilder = rowexec.NewBuilder(nil, sql.EngineOverrides{}) engine, err := enginetest.RunSetupScripts(harness.NewContext(), engine, setup.MydbData, true) require.NoError(t, err) diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 50c7aa3f5c..689bf30d78 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -741,7 +741,8 @@ func TestQueryPlan(t *testing.T, harness Harness, e QueryEngine, tt queries.Quer func TestQueryPlanWithName(t *testing.T, name string, harness Harness, e QueryEngine, query, expectedPlan string, options sql.DescribeOptions) { t.Run(name, func(t *testing.T) { ctx := NewContext(harness) - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, query) + builder := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, nil) + parsed, _, _, qFlags, err := builder.Parse(query, nil, false) require.NoError(t, err) node, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) @@ -768,7 +769,8 @@ func TestQueryPlanWithName(t *testing.T, name string, harness 
Harness, e QueryEn func TestQueryPlanWithEngine(t *testing.T, harness Harness, e QueryEngine, tt queries.QueryPlanTest, verbose bool) { t.Run(tt.Query, func(t *testing.T) { ctx := NewContext(harness) - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, tt.Query) + builder := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, nil) + parsed, _, _, qFlags, err := builder.Parse(tt.Query, nil, false) require.NoError(t, err) node, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) @@ -1482,6 +1484,7 @@ func TestTruncate(t *testing.T, harness Harness) { e := mustNewEngine(t, harness) defer e.Close() ctx := NewContext(harness) + builder := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, nil) t.Run("Standard TRUNCATE", func(t *testing.T) { RunQueryWithContext(t, e, harness, ctx, "CREATE TABLE t1 (pk BIGINT PRIMARY KEY, v1 BIGINT, INDEX(v1))") @@ -1530,7 +1533,7 @@ func TestTruncate(t *testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t5 ORDER BY 1", []sql.Row{{int64(1), int64(1)}, {int64(2), int64(2)}}, nil, nil, nil) deleteStr := "DELETE FROM t5" - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr) + parsed, _, _, qFlags, err := builder.Parse(deleteStr, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) require.NoError(t, err) @@ -1559,7 +1562,8 @@ func TestTruncate(t *testing.T, harness Harness) { RunQueryWithContext(t, e, harness, ctx, "INSERT INTO t6parent VALUES (1,1), (2,2)") RunQueryWithContext(t, e, harness, ctx, "INSERT INTO t6child VALUES (1,1), (2,2)") - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, "DELETE FROM t6parent") + deleteStr := "DELETE FROM t6parent" + parsed, _, _, qFlags, err := builder.Parse(deleteStr, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) require.NoError(t, err) @@ -1587,7 +1591,7 @@ func TestTruncate(t *testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t7i ORDER BY 1", []sql.Row{{int64(3), int64(3)}}, nil, nil, nil) deleteStr := "DELETE FROM t7" - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr) + parsed, _, _, qFlags, err := builder.Parse(deleteStr, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) require.NoError(t, err) @@ -1615,7 +1619,7 @@ func TestTruncate(t *testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t8 ORDER BY 1", []sql.Row{{int64(1), int64(4)}, {int64(2), int64(5)}}, nil, nil, nil) deleteStr := "DELETE FROM t8" - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr) + parsed, _, _, qFlags, err := builder.Parse(deleteStr, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) require.NoError(t, err) @@ -1644,7 +1648,7 @@ func TestTruncate(t *testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t9 ORDER BY 1", []sql.Row{{int64(7), int64(7)}, {int64(8), int64(8)}}, nil, nil, nil) deleteStr := "DELETE FROM t9 WHERE pk > 0" - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr) + parsed, _, _, qFlags, err := builder.Parse(deleteStr, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) require.NoError(t, err) @@ -1671,7 +1675,7 @@ func TestTruncate(t 
*testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t10 ORDER BY 1", []sql.Row{{int64(8), int64(8)}, {int64(9), int64(9)}}, nil, nil, nil) deleteStr := "DELETE FROM t10 LIMIT 1000" - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr) + parsed, _, _, qFlags, err := builder.Parse(deleteStr, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) require.NoError(t, err) @@ -1698,7 +1702,7 @@ func TestTruncate(t *testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t11 ORDER BY 1", []sql.Row{{int64(1), int64(1)}, {int64(9), int64(9)}}, nil, nil, nil) deleteStr := "DELETE FROM t11 ORDER BY 1" - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr) + parsed, _, _, qFlags, err := builder.Parse(deleteStr, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) require.NoError(t, err) @@ -1729,7 +1733,7 @@ func TestTruncate(t *testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t12b ORDER BY 1", []sql.Row{{int64(1), int64(1)}, {int64(2), int64(2)}}, nil, nil, nil) deleteStr := "DELETE t12a, t12b FROM t12a INNER JOIN t12b WHERE t12a.pk=t12b.pk" - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr) + parsed, _, _, qFlags, err := builder.Parse(deleteStr, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags) require.NoError(t, err) @@ -1956,7 +1960,7 @@ func TestUserPrivileges(t *testing.T, harness ClientHarness) { defer engine.Close() ctx := NewContext(harness) - ctx.NewCtxWithClient(sql.Client{ + ctx.WithClient(sql.Client{ User: "root", Address: "localhost", }) @@ -2055,7 +2059,7 @@ func TestUserPrivileges(t *testing.T, harness ClientHarness) { t.Skipf("Skipping query %s", lastQuery) } } - ctx := rootCtx.NewCtxWithClient(sql.Client{ + ctx := rootCtx.WithClient(sql.Client{ User: "tester", Address: "localhost", }) diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index 022da6524e..e7c394b465 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -627,7 +627,7 @@ func injectBindVarsAndPrepare( } } - b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, e.EngineEventScheduler(), nil) + b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, e.EngineEventScheduler()) b.SetParserOptions(sql.LoadSqlMode(ctx).ParserOptions()) resPlan, err := e.PrepareQuery(ctx, q) if err != nil { diff --git a/enginetest/histogram_test.go b/enginetest/histogram_test.go index 49c42797a9..257c9b95da 100644 --- a/enginetest/histogram_test.go +++ b/enginetest/histogram_test.go @@ -315,7 +315,7 @@ func testHistogram(ctx *sql.Context, table *plan.ResolvedTable, fields []int, bu return nil, fmt.Errorf("found zero row count for table") } - i, err := rowexec.DefaultBuilder.Build(ctx, table, nil) + i, err := rowexec.NewBuilder(nil, sql.EngineOverrides{}).Build(ctx, table, nil) rows, err := sql.RowIterToRows(ctx, i) if err != nil { return nil, err @@ -416,7 +416,7 @@ func expectedResultSize(ctx *sql.Context, t1, t2 *plan.ResolvedTable, filters [] if debug { fmt.Println(sql.DebugString(j)) } - i, err := rowexec.DefaultBuilder.Build(ctx, j, nil) + i, err := rowexec.NewBuilder(nil, sql.EngineOverrides{}).Build(ctx, j, nil) if err != nil { return 0, err } diff --git a/enginetest/join_planning_tests.go b/enginetest/join_planning_tests.go index af01bc58a5..aeed077f98 
100644 --- a/enginetest/join_planning_tests.go +++ b/enginetest/join_planning_tests.go @@ -1901,7 +1901,8 @@ func evalJoinTypeTest(t *testing.T, harness Harness, e QueryEngine, query string } func analyzeQuery(ctx *sql.Context, e QueryEngine, query string) (sql.Node, error) { - parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, query) + builder := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, nil) + parsed, _, _, qFlags, err := builder.Parse(query, nil, false) if err != nil { return nil, err } diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index f1ec7b45d0..a3a3543043 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,20 +203,27 @@ func TestSingleScript(t *testing.T) { t.Skip() var scripts = []queries.ScriptTest{ { - Name: "AS OF propagates to nested CALLs", - SetUpScript: []string{}, + Name: "Parse table name as column", + SetUpScript: []string{ + `CREATE TABLE test (pk INT PRIMARY KEY, v1 VARCHAR(255));`, + `INSERT INTO test VALUES (1, 'a'), (2, 'b');`, + }, Assertions: []queries.ScriptTestAssertion{ { - Query: "create procedure create_proc() create table t (i int primary key, j int);", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, + Query: "SELECT temporarytesting(t) FROM test AS t;", + Expected: []sql.Row{}, }, { - Query: "call create_proc()", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, + Query: "SELECT temporarytesting(test) FROM test;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT temporarytesting(pk, test) FROM test;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT temporarytesting(v1, test, pk) FROM test;", + Expected: []sql.Row{}, }, }, }, diff --git a/enginetest/mysqlshim/table.go b/enginetest/mysqlshim/table.go index 17bec8909e..f28277822a 100644 --- a/enginetest/mysqlshim/table.go +++ b/enginetest/mysqlshim/table.go @@ -381,7 +381,8 @@ func (t Table) getCreateTable() (*plan.CreateTable, error) { return nil, sql.ErrTableNotFound.New(t.name) } // TODO add catalog - createTableNode, _, err := planbuilder.Parse(sql.NewEmptyContext(), sql.MapCatalog{Tables: map[string]sql.Table{t.name: t}}, rows[0][1].(string)) + builder := planbuilder.New(sql.NewEmptyContext(), sql.MapCatalog{Tables: map[string]sql.Table{t.name: t}}, nil) + createTableNode, _, _, _, err := builder.Parse(rows[0][1].(string), nil, false) if err != nil { return nil, err } diff --git a/enginetest/plangen/cmd/plangen/main.go b/enginetest/plangen/cmd/plangen/main.go index 9dadbc78df..d154d1ad72 100644 --- a/enginetest/plangen/cmd/plangen/main.go +++ b/enginetest/plangen/cmd/plangen/main.go @@ -154,7 +154,7 @@ func writePlanString(w *bytes.Buffer, planString string) { } func analyzeQuery(ctx *sql.Context, engine enginetest.QueryEngine, query string) sql.Node { - binder := planbuilder.New(ctx, engine.EngineAnalyzer().Catalog, engine.EngineEventScheduler(), nil) + binder := planbuilder.New(ctx, engine.EngineAnalyzer().Catalog, engine.EngineEventScheduler()) parsed, _, _, qFlags, err := binder.Parse(query, nil, false) if err != nil { exit(fmt.Errorf("%w\nfailed to parse query: %s", err, query)) diff --git a/enginetest/server_engine.go b/enginetest/server_engine.go index 679c65e662..0a2d3284f6 100644 --- a/enginetest/server_engine.go +++ b/enginetest/server_engine.go @@ -705,8 +705,8 @@ func convertGoSqlType(columnType *gosql.ColumnType) (sql.Type, error) { // is in random order. The function expects binding variables starting with `:v1` and do not skip number. 
// It cannot sort user-defined binding variables (e.g. :var, :foo) func prepareBindingArgs(ctx *sql.Context, bindings map[string]sqlparser.Expr) ([]any, error) { - // NOTE: using binder with nil catalog and parser since we're only using it to convert SQLVal. - binder := planbuilder.New(ctx, nil, nil, nil) + // NOTE: using binder with nil catalog since we're only using it to convert SQLVal. + binder := planbuilder.New(ctx, nil, nil) numBindVars := len(bindings) args := make([]any, numBindVars) for i := 0; i < numBindVars; i++ { diff --git a/memory/table.go b/memory/table.go index a7c6a0584f..feaf44fdbd 100644 --- a/memory/table.go +++ b/memory/table.go @@ -128,7 +128,7 @@ func stripTblNames(e sql.Expression) (sql.Expression, transform.TreeIdentity, er case *expression.GetField: // strip table names ne := expression.NewGetField(e.Index(), e.Type(), e.Name(), e.IsNullable()) - ne = ne.WithQuotedNames(sql.GlobalSchemaFormatter, e.IsQuotedIdentifier()) + ne = ne.WithQuotedNames(sql.DefaultMySQLSchemaFormatter, e.IsQuotedIdentifier()) return ne, transform.NewTree, nil default: } diff --git a/server/context.go b/server/context.go index b4c137bcaf..abfd209e57 100644 --- a/server/context.go +++ b/server/context.go @@ -279,7 +279,7 @@ func (s *SessionManager) NewContextWithQuery(ctx context.Context, conn *mysql.Co ctx, span := s.tracer.Start(ctx, "query") - context := s.ctxFactory( + createdCtx := s.ctxFactory( ctx, sql.WithSession(sess), sql.WithTracer(s.tracer), @@ -294,7 +294,7 @@ func (s *SessionManager) NewContextWithQuery(ctx context.Context, conn *mysql.Co }), ) - return context, nil + return createdCtx, nil } // Exposed through (*sql.Context).Services.KillConnection. Calls Close on the diff --git a/server/handler.go b/server/handler.go index 477476a939..14b40582d8 100644 --- a/server/handler.go +++ b/server/handler.go @@ -387,7 +387,7 @@ func (h *Handler) ComRegisterReplica(c *mysql.Conn, replicaHost string, replicaP return nil } - newCtx := sql.NewContext(context.Background()) + newCtx := sql.NewEmptyContext() primaryController := h.e.Analyzer.Catalog.GetBinlogPrimaryController() return primaryController.RegisterReplica(newCtx, c, replicaHost, replicaPort) } @@ -403,7 +403,7 @@ func (h *Handler) ComBinlogDumpGTID(c *mysql.Conn, logFile string, logPos uint64 Debug("Handling COM_BINLOG_DUMP_GTID") // TODO: is logfile and logpos ever actually needed for COM_BINLOG_DUMP_GTID? - newCtx := sql.NewContext(context.Background()) + newCtx := sql.NewEmptyContext() primaryController := h.e.Analyzer.Catalog.GetBinlogPrimaryController() return primaryController.BinlogDumpGtid(newCtx, c, gtidSet) } diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index 0e710c26e3..c053c5e35c 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -76,6 +76,7 @@ type Builder struct { validationRules []Rule afterAllRules []Rule debug bool + overrides sql.EngineOverrides } // NewBuilder creates a new Builder from a specific catalog. @@ -129,6 +130,12 @@ func (ab *Builder) AddPostValidationRule(id RuleId, fn RuleFunc) *Builder { return ab } +// AddOverrides adds the given overrides to the builder. 
+func (ab *Builder) AddOverrides(overrides sql.EngineOverrides) *Builder { + ab.overrides = overrides + return ab +} + func duplicateRulesWithout(rules []Rule, excludedRuleId RuleId) []Rule { newRules := make([]Rule, 0, len(rules)) @@ -264,6 +271,14 @@ func (ab *Builder) Build() *Analyzer { Rules: ab.afterAllRules, }, } + parser := sql.DefaultMySQLParser + if ab.overrides.Builder.Parser != nil { + parser = ab.overrides.Builder.Parser + } + schemaFormatter := sql.DefaultMySQLSchemaFormatter + if ab.overrides.SchemaFormatter != nil { + schemaFormatter = ab.overrides.SchemaFormatter + } return &Analyzer{ Debug: debug || ab.debug, @@ -271,11 +286,12 @@ func (ab *Builder) Build() *Analyzer { Trace: trace, contextStack: make([]string, 0), Batches: batches, - Catalog: NewCatalog(ab.provider), + Catalog: NewCatalog(ab.provider, ab.overrides), + Overrides: ab.overrides, Coster: memo.NewDefaultCoster(), - ExecBuilder: rowexec.DefaultBuilder, - Parser: sql.GlobalParser, - SchemaFormatter: sql.GlobalSchemaFormatter, + ExecBuilder: rowexec.NewBuilder(nil, ab.overrides), + Parser: parser, + SchemaFormatter: schemaFormatter, } } @@ -294,6 +310,8 @@ type Analyzer struct { SchemaFormatter sql.SchemaFormatter // Catalog of databases and registered functions. Catalog *Catalog + // Overrides contains the overrides for the engine. + Overrides sql.EngineOverrides // A stack of debugger context. See PushDebugContext, PopDebugContext contextStack []string // Batches of Rules to apply. @@ -312,12 +330,6 @@ func NewDefault(provider sql.DatabaseProvider) *Analyzer { return NewBuilder(provider).Build() } -// NewDefaultWithVersion creates a default Analyzer instance either -// experimental or -func NewDefaultWithVersion(provider sql.DatabaseProvider) *Analyzer { - return NewBuilder(provider).Build() -} - // Log prints an INFO message to stdout with the given message and args // if the analyzer is in debug mode. func (a *Analyzer) Log(msg string, args ...interface{}) { diff --git a/sql/analyzer/apply_foreign_keys.go b/sql/analyzer/apply_foreign_keys.go index 5fddbd11fe..acb12f887e 100644 --- a/sql/analyzer/apply_foreign_keys.go +++ b/sql/analyzer/apply_foreign_keys.go @@ -480,7 +480,7 @@ func getForeignKeyHandlerFromUpdateTarget(ctx *sql.Context, a *Analyzer, updateT // be evaluated in the context of a single table. 
func resolveSchemaDefaults(ctx *sql.Context, catalog *Catalog, table sql.Table) (sql.Schema, error) { // Resolve any column default expressions in tblSch - builder := planbuilder.New(ctx, catalog, nil, sql.GlobalParser) + builder := planbuilder.New(ctx, catalog, nil) childTblSch := builder.ResolveSchemaDefaults(ctx.GetCurrentDatabase(), table.Name(), table.Schema()) // Field Indexes are off by one initially and don't fixed by assignExecIndexes because it doesn't traverse through diff --git a/sql/analyzer/catalog.go b/sql/analyzer/catalog.go index cd5c08b896..5e393289b0 100644 --- a/sql/analyzer/catalog.go +++ b/sql/analyzer/catalog.go @@ -45,6 +45,7 @@ type Catalog struct { MySQLDb *mysql_db.MySQLDb builtInFunctions function.Registry + overrides sql.EngineOverrides locks sessionLocks mu sync.RWMutex @@ -66,12 +67,13 @@ type dbLocks map[string]tableLocks type sessionLocks map[uint32]dbLocks // NewCatalog returns a new empty Catalog with the given provider -func NewCatalog(provider sql.DatabaseProvider) *Catalog { +func NewCatalog(provider sql.DatabaseProvider, overrides sql.EngineOverrides) *Catalog { c := &Catalog{ MySQLDb: mysql_db.CreateEmptyMySQLDb(), InfoSchema: information_schema.NewInformationSchemaDatabase(), DbProvider: provider, builtInFunctions: function.NewRegistry(), + overrides: overrides, StatsProvider: memory.NewStatsProv(), locks: make(sessionLocks), } @@ -415,6 +417,11 @@ func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunctio return nil, false } +// Overrides implements the sql.Catalog interface +func (c *Catalog) Overrides() sql.EngineOverrides { + return c.overrides +} + func (c *Catalog) AnalyzeTable(ctx *sql.Context, table sql.Table, db string) error { return c.StatsProvider.AnalyzeTable(ctx, table, db) } diff --git a/sql/analyzer/catalog_locks_test.go b/sql/analyzer/catalog_locks_test.go index e2bc4c48e9..166a711ba8 100644 --- a/sql/analyzer/catalog_locks_test.go +++ b/sql/analyzer/catalog_locks_test.go @@ -15,7 +15,6 @@ package analyzer import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -25,11 +24,11 @@ import ( func TestCatalogLockTable(t *testing.T) { require := require.New(t) - c := NewCatalog(NewDatabaseProvider()) + c := NewCatalog(NewDatabaseProvider(), sql.EngineOverrides{}) - ctx1 := sql.NewContext(context.Background()) + ctx1 := sql.NewEmptyContext() ctx1.SetCurrentDatabase("db1") - ctx2 := sql.NewContext(context.Background()) + ctx2 := sql.NewEmptyContext() ctx2.SetCurrentDatabase("db1") c.LockTable(ctx1, "foo") diff --git a/sql/analyzer/catalog_test.go b/sql/analyzer/catalog_test.go index 426496e75e..0cf72e7ea3 100644 --- a/sql/analyzer/catalog_test.go +++ b/sql/analyzer/catalog_test.go @@ -32,7 +32,7 @@ func TestAllDatabases(t *testing.T) { memory.NewDatabase("c"), } - c := NewCatalog(sql.NewDatabaseProvider(dbs...)) + c := NewCatalog(sql.NewDatabaseProvider(dbs...), sql.EngineOverrides{}) databases := c.AllDatabases(sql.NewEmptyContext()) require.Equal(4, len(databases)) @@ -44,7 +44,7 @@ func TestCatalogDatabase(t *testing.T) { require := require.New(t) mydb := memory.NewDatabase("foo") - c := NewCatalog(sql.NewDatabaseProvider(mydb)) + c := NewCatalog(sql.NewDatabaseProvider(mydb), sql.EngineOverrides{}) db, err := c.Database(sql.NewEmptyContext(), "flo") require.EqualError(err, "database not found: flo, maybe you mean foo?") @@ -61,7 +61,7 @@ func TestCatalogTable(t *testing.T) { db := memory.NewDatabase("foo") pro := memory.NewDBProvider(db) ctx := newContext(pro) - c := NewCatalog(pro) + c := 
NewCatalog(pro, sql.EngineOverrides{}) table, _, err := c.Table(ctx, "foo", "bar") require.EqualError(err, "table not found: bar") @@ -87,7 +87,7 @@ func TestCatalogUnlockTables(t *testing.T) { require := require.New(t) db := memory.NewDatabase("db") pro := memory.NewDBProvider(db) - c := NewCatalog(pro) + c := NewCatalog(pro, sql.EngineOverrides{}) ctx := newContext(pro) t1 := newLockableTable(memory.NewTable(db, "t1", sql.PrimaryKeySchema{}, db.GetForeignKeyCollection())) diff --git a/sql/analyzer/engine_overrides.go b/sql/analyzer/engine_overrides.go new file mode 100644 index 0000000000..c5ad4a942a --- /dev/null +++ b/sql/analyzer/engine_overrides.go @@ -0,0 +1,48 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package analyzer + +import ( + "github.com/dolthub/go-mysql-server/sql/transform" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" +) + +// engineOverrides handles adding the engine overrides to any nodes or expressions that implement the corresponding +// interface. +func engineOverrides(_ *sql.Context, a *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector, _ *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { + newNode, sameNode, err := transform.NodeWithOpaque(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + overriding, ok := n.(sql.NodeOverriding) + if !ok { + return n, transform.SameTree, nil + } + return overriding.WithOverrides(a.Overrides), transform.NewTree, nil + }) + if err != nil { + return nil, transform.SameTree, err + } + newNode, sameExpr, err := transform.NodeExprsWithOpaque(newNode, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + overriding, ok := e.(sql.ExpressionOverriding) + if !ok { + return e, transform.SameTree, nil + } + return overriding.WithOverrides(a.Overrides), transform.NewTree, nil + }) + if err != nil { + return nil, transform.SameTree, err + } + return newNode, sameNode && sameExpr, nil +} diff --git a/sql/analyzer/load_triggers.go b/sql/analyzer/load_triggers.go index 32cef54438..d95600d997 100644 --- a/sql/analyzer/load_triggers.go +++ b/sql/analyzer/load_triggers.go @@ -106,7 +106,9 @@ func loadTriggersFromDb(ctx *sql.Context, a *Analyzer, db sql.Database, ignorePa var parsedTrigger sql.Node sqlMode := sql.NewSqlModeFromString(trigger.SqlMode) // TODO: should perhaps add the auth query handler to the analyzer? does this even use auth? 
- parsedTrigger, _, err = planbuilder.ParseWithOptions(ctx, a.Catalog, trigger.CreateStatement, sqlMode.ParserOptions()) + builder := planbuilder.New(ctx, a.Catalog, nil) + builder.SetParserOptions(sqlMode.ParserOptions()) + parsedTrigger, _, _, _, err = builder.Parse(trigger.CreateStatement, nil, false) if err != nil { // We want to be able to drop invalid triggers, so ignore any parser errors and return the name of the trigger if !ignoreParseErrors { diff --git a/sql/analyzer/optimization_rules_test.go b/sql/analyzer/optimization_rules_test.go index dc256f330e..f78f2dea1f 100644 --- a/sql/analyzer/optimization_rules_test.go +++ b/sql/analyzer/optimization_rules_test.go @@ -218,7 +218,7 @@ func TestPushNotFilters(t *testing.T) { ctx := sql.NewContext(context.Background(), sql.WithSession(sess)) ctx.SetCurrentDatabase("mydb") - b := planbuilder.New(ctx, cat, nil, nil) + b := planbuilder.New(ctx, cat, nil) for _, tt := range tests { t.Run(tt.in, func(t *testing.T) { diff --git a/sql/analyzer/rule_ids.go b/sql/analyzer/rule_ids.go index 56bb1dc75e..d68f1ab21a 100644 --- a/sql/analyzer/rule_ids.go +++ b/sql/analyzer/rule_ids.go @@ -86,4 +86,7 @@ const ( cacheSubqueryAliasesInJoinsId // cacheSubqueryAliasesInJoins QuoteDefaultColumnValueNamesId // quoteDefaultColumnValueNames TrackProcessId // trackProcess + + // extra that needs to be added to once before + engineOverridesId // engineOverrides ) diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 54ae3a644c..36838affd9 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -31,6 +31,7 @@ func init() { // OnceBeforeDefault contains the rules to be applied just once before the // DefaultRules. var OnceBeforeDefault = []Rule{ + {Id: engineOverridesId, Apply: engineOverrides}, {Id: applyDefaultSelectLimitId, Apply: applyDefaultSelectLimit}, {Id: replaceCountStarId, Apply: replaceCountStar}, {Id: validateOffsetAndLimitId, Apply: validateOffsetAndLimit}, diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index cdbc877223..8f6a3f3025 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -210,7 +210,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, return nil, transform.SameTree, err } - b := planbuilder.New(ctx, a.Catalog, nil, nil) + b := planbuilder.New(ctx, a.Catalog, nil) b.DisableAuth() prevActive := b.TriggerCtx().Active b.TriggerCtx().Active = true diff --git a/sql/analyzer/vector_index_test.go b/sql/analyzer/vector_index_test.go index 76f51e75db..f2e8944ff8 100644 --- a/sql/analyzer/vector_index_test.go +++ b/sql/analyzer/vector_index_test.go @@ -125,7 +125,7 @@ func TestVectorIndex(t *testing.T) { "expected:\n%s,\nfound:\n%s\n", testCase.expectedPlan, res.String()) } - iter, err := rowexec.DefaultBuilder.Build(ctx, res, nil) + iter, err := rowexec.NewBuilder(nil, sql.EngineOverrides{}).Build(ctx, res, nil) require.NoError(t, err) rows, err = sql.RowIterToRows(ctx, iter) require.NoError(t, err) @@ -157,7 +157,7 @@ func TestShowCreateTableWithVectorIndex(t *testing.T) { &vectorIndex, } - rowIter, _ := rowexec.DefaultBuilder.Build(ctx, showCreateTable, nil) + rowIter, _ := rowexec.NewBuilder(nil, sql.EngineOverrides{}).Build(ctx, showCreateTable, nil) row, err := rowIter.Next(ctx) diff --git a/sql/catalog.go b/sql/catalog.go index b3d078e653..edb99ce860 100644 --- a/sql/catalog.go +++ b/sql/catalog.go @@ -47,6 +47,9 @@ type Catalog interface { // AuthorizationHandler returns the AuthorizationHandler that is used by the catalog. 
AuthorizationHandler() AuthorizationHandler + + // Overrides returns the overrides that replace various functionality within the engine. + Overrides() EngineOverrides } // CatalogTable is a Table that depends on a Catalog. diff --git a/sql/catalog_map.go b/sql/catalog_map.go index 4191d48f1f..566fc44202 100644 --- a/sql/catalog_map.go +++ b/sql/catalog_map.go @@ -152,3 +152,7 @@ func (t MapCatalog) DropDbStats(ctx *Context, db string, flush bool) error { func (t MapCatalog) AuthorizationHandler() AuthorizationHandler { return GetAuthorizationHandlerFactory().CreateHandler(t) } + +func (MapCatalog) Overrides() EngineOverrides { + return EngineOverrides{} +} diff --git a/sql/functions.go b/sql/functions.go index b0a804f3c4..f9bbec7d41 100644 --- a/sql/functions.go +++ b/sql/functions.go @@ -108,6 +108,7 @@ func NewFunction0(name string, fn func() Expression) Function0 { } } +// NewInstance implements the interface Function. func (fn Function0) NewInstance(args []Expression) (Expression, error) { if len(args) != 0 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 0, len(args)) @@ -116,6 +117,7 @@ func (fn Function0) NewInstance(args []Expression) (Expression, error) { return fn.Fn(), nil } +// NewInstance implements the interface Function. func (fn Function1) NewInstance(args []Expression) (Expression, error) { if len(args) != 1 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 1, len(args)) @@ -124,6 +126,7 @@ func (fn Function1) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0]), nil } +// NewInstance implements the interface Function. func (fn Function2) NewInstance(args []Expression) (Expression, error) { if len(args) != 2 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 2, len(args)) @@ -132,6 +135,7 @@ func (fn Function2) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1]), nil } +// NewInstance implements the interface Function. func (fn Function3) NewInstance(args []Expression) (Expression, error) { if len(args) != 3 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 3, len(args)) @@ -140,6 +144,7 @@ func (fn Function3) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2]), nil } +// NewInstance implements the interface Function. func (fn Function4) NewInstance(args []Expression) (Expression, error) { if len(args) != 4 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 4, len(args)) @@ -148,6 +153,7 @@ func (fn Function4) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2], args[3]), nil } +// NewInstance implements the interface Function. func (fn Function5) NewInstance(args []Expression) (Expression, error) { if len(args) != 5 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 5, len(args)) @@ -156,6 +162,7 @@ func (fn Function5) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2], args[3], args[4]), nil } +// NewInstance implements the interface Function. func (fn Function6) NewInstance(args []Expression) (Expression, error) { if len(args) != 6 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 6, len(args)) @@ -164,6 +171,7 @@ func (fn Function6) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil } +// NewInstance implements the interface Function. 
func (fn Function7) NewInstance(args []Expression) (Expression, error) { if len(args) != 7 { return nil, ErrInvalidArgumentNumber.New(fn.Name, 7, len(args)) @@ -172,6 +180,7 @@ func (fn Function7) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil } +// NewInstance implements the interface Function. func (fn FunctionN) NewInstance(args []Expression) (Expression, error) { return fn.Fn(args...) } diff --git a/sql/index_builder_test.go b/sql/index_builder_test.go index c055768fb3..8246b59293 100644 --- a/sql/index_builder_test.go +++ b/sql/index_builder_test.go @@ -15,7 +15,6 @@ package sql_test import ( - "context" "fmt" "testing" @@ -26,7 +25,7 @@ import ( ) func TestIndexBuilderRanges(t *testing.T) { - ctx := sql.NewContext(context.Background()) + ctx := sql.NewEmptyContext() t.Run("None=[NULL,Inf)", func(t *testing.T) { builder := sql.NewMySQLIndexBuilder(testIndex{1}) diff --git a/sql/information_schema/columns_table.go b/sql/information_schema/columns_table.go index 05901ce819..ac367db7cc 100644 --- a/sql/information_schema/columns_table.go +++ b/sql/information_schema/columns_table.go @@ -374,7 +374,8 @@ func getRowsFromViews(ctx *sql.Context, catalog sql.Catalog, db DbWithNames, pri privSetDb := privSet.Database(db.Database.Name()) for _, view := range views { // TODO: figure out how auth works in this case - node, _, err := planbuilder.Parse(ctx, catalog, view.CreateViewStatement) + builder := planbuilder.New(ctx, catalog, nil) + node, _, _, _, err := builder.Parse(view.CreateViewStatement, nil, false) if err != nil { continue // sometimes views contains views from other databases } diff --git a/sql/information_schema/information_schema.go b/sql/information_schema/information_schema.go index e2fd7a262d..adb837bb39 100644 --- a/sql/information_schema/information_schema.go +++ b/sql/information_schema/information_schema.go @@ -1951,7 +1951,9 @@ func triggersRowIter(ctx *Context, c Catalog) (RowIter, error) { ctx.SetCurrentDatabase(db.Database.Name()) triggerSqlMode := NewSqlModeFromString(trigger.SqlMode) // TODO: figure out how auth works in this case - parsedTrigger, _, err := planbuilder.ParseWithOptions(ctx, c, trigger.CreateStatement, triggerSqlMode.ParserOptions()) + builder := planbuilder.New(ctx, c, nil) + builder.SetParserOptions(triggerSqlMode.ParserOptions()) + parsedTrigger, _, _, _, err := builder.Parse(trigger.CreateStatement, nil, false) if err != nil { return nil, err } diff --git a/sql/information_schema/routines_table.go b/sql/information_schema/routines_table.go index 0d03bf4137..24bfe960a9 100644 --- a/sql/information_schema/routines_table.go +++ b/sql/information_schema/routines_table.go @@ -155,7 +155,8 @@ func routinesRowIter(ctx *Context, c Catalog, p map[string][]*plan.Procedure) (R // todo shortcircuit routineDef->procedure.CreateProcedureString? 
// TODO: figure out how auth works in this case - parsedProcedure, _, err := planbuilder.Parse(ctx, c, procedure.CreateProcedureString) + builder := planbuilder.New(ctx, c, nil) + parsedProcedure, _, _, _, err := builder.Parse(procedure.CreateProcedureString, nil, false) if err != nil { continue } diff --git a/sql/information_schema/views_table.go b/sql/information_schema/views_table.go index a9ac29661f..1b9409b340 100644 --- a/sql/information_schema/views_table.go +++ b/sql/information_schema/views_table.go @@ -83,7 +83,9 @@ func viewsRowIter(ctx *Context, catalog Catalog) (RowIter, error) { continue } // TODO: figure out how auth works in this case - parsedView, _, err := planbuilder.ParseWithOptions(ctx, catalog, view.CreateViewStatement, NewSqlModeFromString(view.SqlMode).ParserOptions()) + builder := planbuilder.New(ctx, catalog, nil) + builder.SetParserOptions(NewSqlModeFromString(view.SqlMode).ParserOptions()) + parsedView, _, _, _, err := builder.Parse(view.CreateViewStatement, nil, false) if err != nil { continue } diff --git a/sql/overrides.go b/sql/overrides.go new file mode 100644 index 0000000000..7da42cde1e --- /dev/null +++ b/sql/overrides.go @@ -0,0 +1,120 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +// NodeOverriding is a Node that makes use of functionality that may be overridden. +type NodeOverriding interface { + Node + WithOverrides(overrides EngineOverrides) Node +} + +// ExpressionOverriding is an Expression that makes use of functionality that may be overridden. +type ExpressionOverriding interface { + Expression + WithOverrides(overrides EngineOverrides) Expression +} + +// EngineOverrides contains functions and variables that can replace, supplement, or override functionality within the +// various engine phases (such as the analysis, node execution, etc.). The empty struct is valid, which will not +// override any functionality (uses the default MySQL functionality for all applicable situations). +type EngineOverrides struct { + // Builder contains functions and variables that can replace, supplement, or override functionality within the builder. + Builder BuilderOverrides + // SchemaFormatter is the formatter for schema string creation. If nil, this will format in MySQL's style. + SchemaFormatter SchemaFormatter + // Hooks contain various hooks that are called within a statement's lifecycle. + Hooks ExecutionHooks +} + +// BuilderOverrides contains functions and variables that can replace, supplement, or override functionality within the +// builder. +type BuilderOverrides struct { + // When this is non-nil, then this allows for table names to be used in the same context as column names. When a + // table name creates a match, then this function is called to create an expression. The return value of the created + // expression will be used in place of the `GetField` expression used for columns. The input `fields` contains the + // `GetField` expressions for all of the table's columns.
For standard MySQL compatibility, this should be nil. + ParseTableAsColumn func(fields []Expression) Expression + // Represents the parser to use. If this is nil, then the MySQL parser will be used. + Parser Parser +} + +// ExecutionHooks contain various hooks that are called within a statement's lifecycle. Each inner struct represents a +// specific statement. +type ExecutionHooks struct { + CreateTable CreateTable // CreateTable contains hooks related to CREATE TABLE statements. + RenameTable RenameTable // RenameTable contains hooks related to RENAME TABLE statements. + DropTable DropTable // DropTable contains hooks related to DROP TABLE statements. + TableAddColumn TableAddColumn // TableAddColumn contains hooks related to ALTER TABLE ... ADD COLUMN statements. + TableRenameColumn TableRenameColumn // TableRenameColumn contains hooks related to ALTER TABLE ... RENAME COLUMN statements. + TableModifyColumn TableModifyColumn // TableModifyColumn contains hooks related to ALTER TABLE ... MODIFY COLUMN statements. + TableDropColumn TableDropColumn // TableDropColumn contains hooks related to ALTER TABLE ... DROP COLUMN statements. +} + +// CreateTable contains hooks related to CREATE TABLE statements. These will take a *plan.CreateTable. +type CreateTable struct { + // PreSQLExecution is called before the final step of statement execution, after analysis. + PreSQLExecution func(*Context, Node) (Node, error) + // PostSQLExecution is called after the final step of statement execution, after analysis. + PostSQLExecution func(*Context, Node) error +} + +// RenameTable contains hooks related to RENAME TABLE statements. These will take a *plan.RenameTable. +type RenameTable struct { + // PreSQLExecution is called before the final step of statement execution, after analysis. + PreSQLExecution func(*Context, Node) (Node, error) + // PostSQLExecution is called after the final step of statement execution, after analysis. + PostSQLExecution func(*Context, Node) error +} + +// DropTable contains hooks related to DROP TABLE statements. These will take a *plan.DropTable. +type DropTable struct { + // PreSQLExecution is called before the final step of statement execution, after analysis. + PreSQLExecution func(*Context, Node) (Node, error) + // PostSQLExecution is called after the final step of statement execution, after analysis. + PostSQLExecution func(*Context, Node) error +} + +// TableAddColumn contains hooks related to ALTER TABLE ... ADD COLUMN statements. These will take a *plan.AddColumn. +type TableAddColumn struct { + // PreSQLExecution is called before the final step of statement execution, after analysis. + PreSQLExecution func(*Context, Node) (Node, error) + // PostSQLExecution is called after the final step of statement execution, after analysis. + PostSQLExecution func(*Context, Node) error +} + +// TableRenameColumn contains hooks related to ALTER TABLE ... RENAME COLUMN statements. These will take a *plan.RenameColumn. +type TableRenameColumn struct { + // PreSQLExecution is called before the final step of statement execution, after analysis. + PreSQLExecution func(*Context, Node) (Node, error) + // PostSQLExecution is called after the final step of statement execution, after analysis. + PostSQLExecution func(*Context, Node) error +} + +// TableModifyColumn contains hooks related to ALTER TABLE ... MODIFY COLUMN statements. These will take a +// *plan.ModifyColumn. 
+type TableModifyColumn struct { + // PreSQLExecution is called before the final step of statement execution, after analysis. + PreSQLExecution func(*Context, Node) (Node, error) + // PostSQLExecution is called after the final step of statement execution, after analysis. + PostSQLExecution func(*Context, Node) error +} + +// TableDropColumn contains hooks related to ALTER TABLE ... DROP COLUMN statements. These will take a *plan.DropColumn. +type TableDropColumn struct { + // PreSQLExecution is called before the final step of statement execution, after analysis. + PreSQLExecution func(*Context, Node) (Node, error) + // PostSQLExecution is called after the final step of statement execution, after analysis. + PostSQLExecution func(*Context, Node) error +} diff --git a/sql/parser.go b/sql/parser.go index f23ae02dc1..539b381b48 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -24,13 +24,11 @@ import ( ast "github.com/dolthub/vitess/go/vt/sqlparser" ) -// GlobalParser is a temporary variable to expose Doltgres parser. -// It defaults to MysqlParser. -var GlobalParser Parser = NewMysqlParser() +// DefaultMySQLParser is the default MySQL parser. +var DefaultMySQLParser Parser = NewMysqlParser() -// GlobalSchemaFormatter is a temporary variable to expose Doltgres schema formatter. -// It defaults to MySqlSchemaFormatter. -var GlobalSchemaFormatter SchemaFormatter = &MySqlSchemaFormatter{} +// DefaultMySQLSchemaFormatter is the default MySQL schema formatter. +var DefaultMySQLSchemaFormatter SchemaFormatter = &MySqlSchemaFormatter{} // Parser knows how to transform a SQL statement into an AST type Parser interface { diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index d32b3cf0c1..9ec0d1e558 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -150,7 +150,7 @@ func (d DropCheck) String() string { return pr.String() } -func NewCheckDefinition(ctx *sql.Context, check *sql.CheckConstraint) (*sql.CheckDefinition, error) { +func NewCheckDefinition(ctx *sql.Context, check *sql.CheckConstraint, schemaFormatter sql.SchemaFormatter) (*sql.CheckDefinition, error) { // When transforming an analyzed CheckConstraint into a CheckDefinition (for storage), we strip off any table // qualifiers that got resolved during analysis. This is to naively match the MySQL behavior, which doesn't print // any table qualifiers in check expressions. 
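NewCheckDefinition now takes the schema formatter explicitly instead of reading the removed sql.GlobalSchemaFormatter package variable. A minimal sketch of an updated call site, assuming the analyzer's SchemaFormatter field is the value callers thread through; the buildCheckDefinitions helper below is illustrative only (imports of the sql, analyzer, and plan packages assumed) and is not part of this patch:

// buildCheckDefinitions shows the new call shape: the formatter is an explicit argument
// rather than a package-level global.
func buildCheckDefinitions(ctx *sql.Context, a *analyzer.Analyzer, checks []*sql.CheckConstraint) ([]*sql.CheckDefinition, error) {
	defs := make([]*sql.CheckDefinition, 0, len(checks))
	for _, ch := range checks {
		def, err := plan.NewCheckDefinition(ctx, ch, a.SchemaFormatter)
		if err != nil {
			return nil, err
		}
		defs = append(defs, def)
	}
	return defs, nil
}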
@@ -158,7 +158,7 @@ func NewCheckDefinition(ctx *sql.Context, check *sql.CheckConstraint) (*sql.Chec gf, ok := e.(*expression.GetField) if ok { newGf := expression.NewGetField(gf.Index(), gf.Type(), gf.Name(), gf.IsNullable()) - newGf = newGf.WithQuotedNames(sql.GlobalSchemaFormatter, true) + newGf = newGf.WithQuotedNames(schemaFormatter, true) return newGf, transform.NewTree, nil } return e, transform.SameTree, nil diff --git a/sql/plan/alter_table.go b/sql/plan/alter_table.go index 2f45203ccd..aafcf6e1fa 100644 --- a/sql/plan/alter_table.go +++ b/sql/plan/alter_table.go @@ -59,30 +59,6 @@ func (r *RenameTable) IsReadOnly() bool { return false } -func (r *RenameTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - renamer, _ := r.Db.(sql.TableRenamer) - viewDb, _ := r.Db.(sql.ViewDatabase) - viewRegistry := ctx.GetViewRegistry() - - for i, oldName := range r.OldNames { - if tbl, exists := r.tableExists(ctx, oldName); exists { - err := r.renameTable(ctx, renamer, tbl, oldName, r.NewNames[i]) - if err != nil { - return nil, err - } - } else { - success, err := r.renameView(ctx, viewDb, viewRegistry, oldName, r.NewNames[i]) - if err != nil { - return nil, err - } else if !success { - return nil, sql.ErrTableNotFound.New(oldName) - } - } - } - - return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(0))), nil -} - func (r *RenameTable) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(r, children...) } @@ -92,7 +68,7 @@ func (*RenameTable) CollationCoercibility(ctx *sql.Context) (collation sql.Colla return sql.Collation_binary, 7 } -func (r *RenameTable) tableExists(ctx *sql.Context, name string) (sql.Table, bool) { +func (r *RenameTable) TableExists(ctx *sql.Context, name string) (sql.Table, bool) { tbl, ok, err := r.Db.GetTableInsensitive(ctx, name) if err != nil || !ok { return nil, false @@ -100,7 +76,7 @@ func (r *RenameTable) tableExists(ctx *sql.Context, name string) (sql.Table, boo return tbl, true } -func (r *RenameTable) renameTable(ctx *sql.Context, renamer sql.TableRenamer, tbl sql.Table, oldName, newName string) error { +func (r *RenameTable) RenameTable(ctx *sql.Context, renamer sql.TableRenamer, tbl sql.Table, oldName, newName string) error { if renamer == nil { return sql.ErrRenameTableNotSupported.New(r.Db.Name()) } @@ -160,7 +136,7 @@ func (r *RenameTable) renameTable(ctx *sql.Context, renamer sql.TableRenamer, tb return nil } -func (r *RenameTable) renameView(ctx *sql.Context, viewDb sql.ViewDatabase, vr *sql.ViewRegistry, oldName, newName string) (bool, error) { +func (r *RenameTable) RenameView(ctx *sql.Context, viewDb sql.ViewDatabase, vr *sql.ViewRegistry, oldName, newName string) (bool, error) { if viewDb != nil { oldView, exists, err := viewDb.GetViewDefinition(ctx, oldName) if err != nil { diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index f95fad7a05..37e7021632 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -406,14 +406,14 @@ func (c *CreateTable) WithParentForeignKeyTables(refTbls []sql.ForeignKeyTable) } // CreateChecks creates the check constraints on the table. 
-func (c *CreateTable) CreateChecks(ctx *sql.Context, tableNode sql.Table) error { +func (c *CreateTable) CreateChecks(ctx *sql.Context, tableNode sql.Table, schemaFormatter sql.SchemaFormatter) error { chAlterable, ok := tableNode.(sql.CheckAlterableTable) if !ok { return ErrNoCheckConstraintSupport.New(c.name) } for _, ch := range c.checks { - check, err := NewCheckDefinition(ctx, ch) + check, err := NewCheckDefinition(ctx, ch, schemaFormatter) if err != nil { return err } diff --git a/sql/planbuilder/builder.go b/sql/planbuilder/builder.go index f3b744a3ea..1fd0d26735 100644 --- a/sql/planbuilder/builder.go +++ b/sql/planbuilder/builder.go @@ -17,7 +17,6 @@ package planbuilder import ( "fmt" "strings" - "sync" ast "github.com/dolthub/vitess/go/vt/sqlparser" @@ -28,10 +27,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/transform" ) -var BinderFactory = &sync.Pool{New: func() interface{} { - return &Builder{f: &factory{}} -}} - type Builder struct { // EventScheduler is used to communicate with the event scheduler // for any EVENT related statements. It can be nil if EventScheduler is not defined. @@ -58,6 +53,7 @@ type Builder struct { multiDDL bool insertActive bool parserOpts ast.ParserOptions + overrides sql.BuilderOverrides } // BindvarContext holds bind variable replacement literals. @@ -112,15 +108,18 @@ type ProcContext struct { DbName string } -// New takes ctx, catalog, event scheduler, and parser. If the parser is nil, then default parser is mysql parser. -func New(ctx *sql.Context, cat sql.Catalog, es sql.EventScheduler, p sql.Parser) *Builder { - if p == nil { - p = sql.GlobalParser - } - +// New takes ctx, catalog, event scheduler, and parser. If the parser is nil, then the default parser is used (which +// will be the MySQL parser unless modified). +func New(ctx *sql.Context, cat sql.Catalog, es sql.EventScheduler) *Builder { var state sql.AuthorizationQueryState + var overrides sql.BuilderOverrides + var p = sql.DefaultMySQLParser if cat != nil { state = cat.AuthorizationHandler().NewQueryState(ctx) + overrides = cat.Overrides().Builder + if overrides.Parser != nil { + p = overrides.Parser + } } return &Builder{ ctx: ctx, @@ -132,6 +131,7 @@ func New(ctx *sql.Context, cat sql.Catalog, es sql.EventScheduler, p sql.Parser) qFlags: &sql.QueryFlags{}, authEnabled: true, authQueryState: state, + overrides: overrides, } } diff --git a/sql/planbuilder/parse.go b/sql/planbuilder/parse.go index b964cc1911..99e221e5a9 100644 --- a/sql/planbuilder/parse.go +++ b/sql/planbuilder/parse.go @@ -32,19 +32,6 @@ const maxAnalysisIterations = 8 // ErrMaxAnalysisIters is thrown when the analysis iterations are exceeded var ErrMaxAnalysisIters = errors.NewKind("exceeded max analysis iterations (%d)") -// Parse parses the given SQL |query| using the default parsing settings and returns the corresponding node. 
-func Parse(ctx *sql.Context, cat sql.Catalog, query string) (sql.Node, *sql.QueryFlags, error) { - return ParseWithOptions(ctx, cat, query, sql.LoadSqlMode(ctx).ParserOptions()) -} - -func ParseWithOptions(ctx *sql.Context, cat sql.Catalog, query string, options ast.ParserOptions) (sql.Node, *sql.QueryFlags, error) { - // TODO: need correct parser - b := New(ctx, cat, nil, nil) - b.SetParserOptions(options) - node, _, _, qFlags, err := b.Parse(query, nil, false) - return node, qFlags, err -} - func (b *Builder) Parse(query string, qFlags *sql.QueryFlags, multi bool) (ret sql.Node, parsed, remainder string, qProps *sql.QueryFlags, err error) { defer trace.StartRegion(b.ctx, "ParseOnly").End() b.nesting++ diff --git a/sql/planbuilder/parse_column_default.go b/sql/planbuilder/parse_column_default.go index e095b54e9e..672686e5a2 100644 --- a/sql/planbuilder/parse_column_default.go +++ b/sql/planbuilder/parse_column_default.go @@ -43,7 +43,9 @@ func StringToColumnDefaultValue(ctx *sql.Context, exprStr string) (*sql.ColumnDe if !ok { return nil, fmt.Errorf("DefaultStringToExpression expected *sqlparser.AliasedExpr but received %T", parserSelect.SelectExprs[0]) } - proj, _, err := Parse(ctx, nil, fmt.Sprintf("SELECT %s", aliasedExpr.Expr)) + // TODO: this needs to take a catalog + builder := New(ctx, nil, nil) + proj, _, _, _, err := builder.Parse(fmt.Sprintf("SELECT %s", aliasedExpr.Expr), nil, false) if err != nil { return nil, err } diff --git a/sql/planbuilder/parse_test.go b/sql/planbuilder/parse_test.go index 4159f5aeb0..8e6a1f9bc2 100644 --- a/sql/planbuilder/parse_test.go +++ b/sql/planbuilder/parse_test.go @@ -2646,7 +2646,7 @@ Project ctx := sql.NewContext(context.Background(), sql.WithSession(sess)) ctx.SetCurrentDatabase("mydb") - b := New(ctx, cat, nil, nil) + b := New(ctx, cat, nil) for _, tt := range tests { t.Run(tt.Query, func(t *testing.T) { @@ -3041,7 +3041,7 @@ func TestPlanBuilderErr(t *testing.T) { ctx := sql.NewContext(context.Background(), sql.WithSession(sess)) ctx.SetCurrentDatabase("mydb") - b := New(ctx, cat, nil, nil) + b := New(ctx, cat, nil) for _, tt := range tests { t.Run(tt.Query, func(t *testing.T) { diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index c2257fc421..2f3ba0fa65 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -238,7 +238,7 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, } } }() - b := New(ctx, cat, nil, nil) + b := New(ctx, cat, nil) b.DisableAuth() b.SetParserOptions(sql.NewSqlModeFromString(procDetails.SqlMode).ParserOptions()) if asOf != nil { diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index e352643161..ce35f8fef6 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -136,19 +136,42 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { return sysVar } } - var err error if scope == ast.SetScope_User || scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly { - err = sql.ErrUnknownUserVariable.New(colName) + err := sql.ErrUnknownUserVariable.New(colName) + b.handleErr(err) } else if scope == ast.SetScope_Global || scope == ast.SetScope_Session { - err = sql.ErrUnknownSystemVariable.New(colName) + err := sql.ErrUnknownSystemVariable.New(colName) + b.handleErr(err) } else if tblName != "" && !inScope.hasTable(tblName) { - err = sql.ErrTableNotFound.New(tblName) + err := sql.ErrTableNotFound.New(tblName) + b.handleErr(err) } else if tblName != "" { - err = sql.ErrTableColumnNotFound.New(tblName, 
colName) + err := sql.ErrTableColumnNotFound.New(tblName, colName) + b.handleErr(err) + } else if b.overrides.ParseTableAsColumn != nil && inScope.hasTable(colName) { + scopeTableCols := inScope.resolveColumnAsTable(dbName, colName) + if len(scopeTableCols) == 0 { + err := sql.ErrColumnNotFound.New(v) + b.handleErr(err) + } + astQualifier := ast.TableName{ + Name: ast.NewTableIdent(colName), // This must be `colName` due to table aliases + DbQualifier: ast.NewTableIdent(scopeTableCols[0].db), + } + fieldArgs := make([]sql.Expression, len(scopeTableCols)) + for i := range scopeTableCols { + astArg := ast.ColName{ + StoredProcVal: nil, + Qualifier: astQualifier, + Name: ast.NewColIdent(scopeTableCols[i].col), + } + fieldArgs[i] = b.buildScalar(inScope, &astArg) + } + return b.overrides.ParseTableAsColumn(fieldArgs) } else { - err = sql.ErrColumnNotFound.New(v) + err := sql.ErrColumnNotFound.New(v) + b.handleErr(err) } - b.handleErr(err) } origTbl := b.getOrigTblName(inScope.node, c.table) diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index 75efead05d..cf97c0dfaa 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -15,6 +15,7 @@ package planbuilder import ( + "sort" "strings" ast "github.com/dolthub/vitess/go/vt/sqlparser" @@ -153,6 +154,26 @@ func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bo return c, true } +// resolveColumnAsTable resolves a column as though it were a table, by searching the table space. This then returns all +// columns (in their index order) of the table. +func (s *scope) resolveColumnAsTable(db, table string) []scopeColumn { + var tableCols []scopeColumn + tabId := s.getTable(table) + for _, col := range s.cols { + if col.tableId != tabId || (db != "" && !strings.EqualFold(col.db, db)) { + continue + } + tableCols = append(tableCols, col) + } + if len(tableCols) == 0 && s.parent != nil { + return s.parent.resolveColumnAsTable(db, table) + } + sort.Slice(tableCols, func(i, j int) bool { + return tableCols[i].id < tableCols[j].id + }) + return tableCols +} + // getCol gets a scopeColumn based on a columnId func (s *scope) getCol(colId sql.ColumnId) (scopeColumn, bool) { if s.colset.Contains(colId) { @@ -176,6 +197,18 @@ func (s *scope) hasTable(table string) bool { return false } +// getTable returns the table ID matching the given name. +func (s *scope) getTable(table string) sql.TableId { + id, ok := s.tables[strings.ToLower(table)] + if ok { + return id + } + if s.parent != nil { + return s.parent.getTable(table) + } + return 0 +} + // triggerCol is used to hallucinate a new column during trigger DDL // when we fail a resolveColumn. func (s *scope) triggerCol(table, col string) (scopeColumn, bool) { diff --git a/sql/rowexec/builder.go b/sql/rowexec/builder.go index 9133b1c443..a76e9c06ae 100644 --- a/sql/rowexec/builder.go +++ b/sql/rowexec/builder.go @@ -21,29 +21,37 @@ import ( "github.com/dolthub/go-mysql-server/sql/plan" ) -var DefaultBuilder = &BaseBuilder{} - -var _ sql.NodeExecBuilder = (*BaseBuilder)(nil) - -type ExecBuilderFunc func(ctx *sql.Context, n sql.Node, r sql.Row) (sql.RowIter, error) - // BaseBuilder converts a plan tree into a RowIter tree. All relational nodes // have a build statement. Custom source nodes that provide rows that implement // sql.ExecSourceRel are also built into the tree. 
type BaseBuilder struct { - // if override is provided, we try to build executor with this first - override sql.NodeExecBuilder + PriorityBuilder sql.NodeExecBuilder + EngineOverrides sql.EngineOverrides + schemaFormatter sql.SchemaFormatter +} + +var _ sql.NodeExecBuilder = (*BaseBuilder)(nil) + +// NewBuilder creates a new builder. If a priority builder is given, then it is tried first, and only uses the internal +// builder logic if the given one does not return a result (and does not error). +func NewBuilder(priority sql.NodeExecBuilder, overrides sql.EngineOverrides) sql.NodeExecBuilder { + schemaFormatter := sql.DefaultMySQLSchemaFormatter + if overrides.SchemaFormatter != nil { + schemaFormatter = overrides.SchemaFormatter + } + return &BaseBuilder{ + PriorityBuilder: priority, + EngineOverrides: overrides, + schemaFormatter: schemaFormatter, + } } +// Build implements the interface sql.NodeExecBuilder. func (b *BaseBuilder) Build(ctx *sql.Context, n sql.Node, r sql.Row) (sql.RowIter, error) { defer trace.StartRegion(ctx, "ExecBuilder.Build").End() return b.buildNodeExec(ctx, n, r) } -func NewOverrideBuilder(override sql.NodeExecBuilder) sql.NodeExecBuilder { - return &BaseBuilder{override: override} -} - // FinalizeIters applies the final transformations on sql.RowIter before execution. func FinalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) (sql.RowIter, sql.Schema, error) { var sch sql.Schema diff --git a/sql/rowexec/common_test.go b/sql/rowexec/common_test.go index 667b1a9ba0..a59fba6d62 100644 --- a/sql/rowexec/common_test.go +++ b/sql/rowexec/common_test.go @@ -29,6 +29,8 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) +var DefaultBuilder = NewBuilder(nil, sql.EngineOverrides{}).(*BaseBuilder) + func newContext(provider *memory.DbProvider) *sql.Context { return sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), provider))) } diff --git a/sql/rowexec/create_view_test.go b/sql/rowexec/create_view_test.go index c65d3afcab..538ffc0239 100644 --- a/sql/rowexec/create_view_test.go +++ b/sql/rowexec/create_view_test.go @@ -56,7 +56,7 @@ func TestCreateViewWithRegistry(t *testing.T) { createView := newCreateView(memory.NewViewlessDatabase("mydb"), false, false) - ctx := sql.NewContext(context.Background()) + ctx := sql.NewEmptyContext() _, err := DefaultBuilder.buildNodeExec(ctx, createView, nil) require.NoError(err) @@ -71,16 +71,16 @@ func TestCreateExistingViewNative(t *testing.T) { createView := newCreateView(memory.NewDatabase("mydb"), false, false) createExistingView := newCreateView(memory.NewDatabase("mydb"), true, false) - ctx := sql.NewContext(context.Background()) + ctx := sql.NewEmptyContext() _, err := DefaultBuilder.buildNodeExec(ctx, createView, nil) require.NoError(t, err) - ctx = sql.NewContext(context.Background()) + ctx = sql.NewEmptyContext() _, err = DefaultBuilder.buildNodeExec(ctx, createView, nil) require.Error(t, err) require.True(t, sql.ErrExistingView.Is(err)) - ctx = sql.NewContext(context.Background()) + ctx = sql.NewEmptyContext() _, err = DefaultBuilder.buildNodeExec(ctx, createExistingView, nil) require.NoError(t, err) } @@ -91,7 +91,7 @@ func TestReplaceExistingViewNative(t *testing.T) { db := memory.NewDatabase("mydb") createView := newCreateView(db, false, false) - ctx := sql.NewContext(context.Background()) + ctx := sql.NewEmptyContext() _, err := DefaultBuilder.buildNodeExec(ctx, createView, nil) require.NoError(t, err) @@ -131,7 +131,7 @@ func 
TestCreateViewNative(t *testing.T) { db := memory.NewDatabase("mydb") createView := newCreateView(db, false, false) - ctx := sql.NewContext(context.Background()) + ctx := sql.NewEmptyContext() _, err := DefaultBuilder.buildNodeExec(ctx, createView, nil) require.NoError(t, err) diff --git a/sql/rowexec/ddl.go b/sql/rowexec/ddl.go index 483defa5e4..e3d94558be 100644 --- a/sql/rowexec/ddl.go +++ b/sql/rowexec/ddl.go @@ -240,10 +240,51 @@ func (b *BaseBuilder) buildDropCheck(ctx *sql.Context, n *plan.DropCheck, row sq } func (b *BaseBuilder) buildRenameTable(ctx *sql.Context, n *plan.RenameTable, row sql.Row) (sql.RowIter, error) { - return n.RowIter(ctx, row) + if b.EngineOverrides.Hooks.RenameTable.PreSQLExecution != nil { + nn, err := b.EngineOverrides.Hooks.RenameTable.PreSQLExecution(ctx, n) + if err != nil { + return nil, err + } + n = nn.(*plan.RenameTable) + } + + renamer, _ := n.Db.(sql.TableRenamer) + viewDb, _ := n.Db.(sql.ViewDatabase) + viewRegistry := ctx.GetViewRegistry() + + for i, oldName := range n.OldNames { + if tbl, exists := n.TableExists(ctx, oldName); exists { + err := n.RenameTable(ctx, renamer, tbl, oldName, n.NewNames[i]) + if err != nil { + return nil, err + } + } else { + success, err := n.RenameView(ctx, viewDb, viewRegistry, oldName, n.NewNames[i]) + if err != nil { + return nil, err + } else if !success { + return nil, sql.ErrTableNotFound.New(oldName) + } + } + } + if b.EngineOverrides.Hooks.RenameTable.PostSQLExecution != nil { + if err := b.EngineOverrides.Hooks.RenameTable.PostSQLExecution(ctx, n); err != nil { + return nil, err + } + } + + return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(0))), nil } func (b *BaseBuilder) buildModifyColumn(ctx *sql.Context, n *plan.ModifyColumn, row sql.Row) (sql.RowIter, error) { + if b.EngineOverrides.Hooks.TableModifyColumn.PreSQLExecution != nil { + nn, err := b.EngineOverrides.Hooks.TableModifyColumn.PreSQLExecution(ctx, n) + if err != nil { + return nil, err + } + n = nn.(*plan.ModifyColumn) + } + tbl, err := getTableFromDatabase(ctx, n.Database(), n.Table) if err != nil { return nil, err @@ -279,6 +320,7 @@ func (b *BaseBuilder) buildModifyColumn(ctx *sql.Context, n *plan.ModifyColumn, return &modifyColumnIter{ m: n, alterable: alterable, + overrides: b.EngineOverrides, }, nil } @@ -841,6 +883,14 @@ func (b *BaseBuilder) buildDropDB(ctx *sql.Context, n *plan.DropDB, row sql.Row) } func (b *BaseBuilder) buildRenameColumn(ctx *sql.Context, n *plan.RenameColumn, row sql.Row) (sql.RowIter, error) { + if b.EngineOverrides.Hooks.TableRenameColumn.PreSQLExecution != nil { + nn, err := b.EngineOverrides.Hooks.TableRenameColumn.PreSQLExecution(ctx, n) + if err != nil { + return nil, err + } + n = nn.(*plan.RenameColumn) + } + tbl, err := getTableFromDatabase(ctx, n.Database(), n.Table) if err != nil { return nil, err @@ -881,11 +931,26 @@ func (b *BaseBuilder) buildRenameColumn(ctx *sql.Context, n *plan.RenameColumn, } } } + if err = alterable.ModifyColumn(ctx, n.ColumnName, col, nil); err != nil { + return nil, err + } + if b.EngineOverrides.Hooks.TableRenameColumn.PostSQLExecution != nil { + if err = b.EngineOverrides.Hooks.TableRenameColumn.PostSQLExecution(ctx, n); err != nil { + return nil, err + } + } - return rowIterWithOkResultWithZeroRowsAffected(), alterable.ModifyColumn(ctx, n.ColumnName, col, nil) + return rowIterWithOkResultWithZeroRowsAffected(), nil } func (b *BaseBuilder) buildAddColumn(ctx *sql.Context, n *plan.AddColumn, row sql.Row) (sql.RowIter, error) { + if 
b.EngineOverrides.Hooks.TableAddColumn.PreSQLExecution != nil { + nn, err := b.EngineOverrides.Hooks.TableAddColumn.PreSQLExecution(ctx, n) + if err != nil { + return nil, err + } + n = nn.(*plan.AddColumn) + } table, err := getTableFromDatabase(ctx, n.Database(), n.Table) if err != nil { return nil, err @@ -963,6 +1028,13 @@ func (b *BaseBuilder) buildAlterDB(ctx *sql.Context, n *plan.AlterDB, row sql.Ro func (b *BaseBuilder) buildCreateTable(ctx *sql.Context, n *plan.CreateTable, row sql.Row) (sql.RowIter, error) { var err error + if b.EngineOverrides.Hooks.CreateTable.PreSQLExecution != nil { + nn, err := b.EngineOverrides.Hooks.CreateTable.PreSQLExecution(ctx, n) + if err != nil { + return sql.RowsToRowIter(), err + } + n = nn.(*plan.CreateTable) + } // If it's set to Invalid, then no collation has been explicitly defined if n.Collation == sql.Collation_Unspecified { @@ -1116,12 +1188,18 @@ func (b *BaseBuilder) buildCreateTable(ctx *sql.Context, n *plan.CreateTable, ro } if len(n.Checks()) > 0 { - err = n.CreateChecks(ctx, tableNode) + err = n.CreateChecks(ctx, tableNode, b.schemaFormatter) if err != nil { return sql.RowsToRowIter(), err } } + if b.EngineOverrides.Hooks.CreateTable.PostSQLExecution != nil { + if err = b.EngineOverrides.Hooks.CreateTable.PostSQLExecution(ctx, n); err != nil { + return nil, err + } + } + return rowIterWithOkResultWithZeroRowsAffected(), nil } @@ -1199,6 +1277,13 @@ func (b *BaseBuilder) buildCreateTrigger(ctx *sql.Context, n *plan.CreateTrigger } func (b *BaseBuilder) buildDropColumn(ctx *sql.Context, n *plan.DropColumn, row sql.Row) (sql.RowIter, error) { + if b.EngineOverrides.Hooks.TableDropColumn.PreSQLExecution != nil { + nn, err := b.EngineOverrides.Hooks.TableDropColumn.PreSQLExecution(ctx, n) + if err != nil { + return nil, err + } + n = nn.(*plan.DropColumn) + } tbl, err := getTableFromDatabase(ctx, n.Database(), n.Table) if err != nil { return nil, err @@ -1217,6 +1302,7 @@ func (b *BaseBuilder) buildDropColumn(ctx *sql.Context, n *plan.DropColumn, row return &dropColumnIter{ d: n, alterable: alterable, + overrides: b.EngineOverrides, }, nil } diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index 3a58743c7c..9cd63c19dd 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -317,6 +317,7 @@ func (l *loadDataIter) parseFields(ctx *sql.Context, line string) (exprs []sql.E type modifyColumnIter struct { m *plan.ModifyColumn alterable sql.AlterableTable + overrides sql.EngineOverrides runOnce bool } @@ -422,6 +423,11 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } if rewritten { + if i.overrides.Hooks.TableModifyColumn.PostSQLExecution != nil { + if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.m); err != nil { + return nil, err + } + } return sql.NewRow(types.NewOkResult(0)), nil } } @@ -441,6 +447,11 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } } + if i.overrides.Hooks.TableModifyColumn.PostSQLExecution != nil { + if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.m); err != nil { + return nil, err + } + } return sql.NewRow(types.NewOkResult(0)), nil } @@ -1332,6 +1343,11 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } if rewritten { + if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil { + if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.a); err != nil { + return nil, err + } + } return 
sql.NewRow(types.NewOkResult(0)), nil } } @@ -1349,6 +1365,11 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) { // We only need to update all table rows if the new column is non-nil if i.a.Column().Nullable && i.a.Column().Default == nil { + if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil { + if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.a); err != nil { + return nil, err + } + } return sql.NewRow(types.NewOkResult(0)), nil } @@ -1357,6 +1378,11 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } + if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil { + if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.a); err != nil { + return nil, err + } + } return sql.NewRow(types.NewOkResult(0)), nil } @@ -1709,6 +1735,7 @@ func (c *createTriggerIter) Close(*sql.Context) error { type dropColumnIter struct { d *plan.DropColumn alterable sql.AlterableTable + overrides sql.EngineOverrides runOnce bool } @@ -1735,6 +1762,11 @@ func (i *dropColumnIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } if rewritten { + if i.overrides.Hooks.TableDropColumn.PostSQLExecution != nil { + if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.d); err != nil { + return nil, err + } + } return sql.NewRow(types.NewOkResult(0)), nil } } @@ -1757,6 +1789,11 @@ func (i *dropColumnIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } } + if i.overrides.Hooks.TableDropColumn.PostSQLExecution != nil { + if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.d); err != nil { + return nil, err + } + } return sql.NewRow(types.NewOkResult(0)), nil } @@ -1916,7 +1953,7 @@ func (b *BaseBuilder) executeCreateCheck(ctx *sql.Context, c *plan.CreateCheck) } } - check, err := plan.NewCheckDefinition(ctx, c.Check) + check, err := plan.NewCheckDefinition(ctx, c.Check, b.schemaFormatter) if err != nil { return err } diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index c1d6fa9407..86d05ed4ce 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -200,6 +200,13 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, _ sql. var err error var curdb sql.Database + if b.EngineOverrides.Hooks.DropTable.PreSQLExecution != nil { + nn, err := b.EngineOverrides.Hooks.DropTable.PreSQLExecution(ctx, n) + if err != nil { + return nil, err + } + n = nn.(*plan.DropTable) + } sortedTables, err := sortTablesByFKDependencies(ctx, n.Tables) if err != nil { return nil, err } @@ -266,6 +273,12 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, _ sql.
} } + if b.EngineOverrides.Hooks.DropTable.PostSQLExecution != nil { + if err = b.EngineOverrides.Hooks.DropTable.PostSQLExecution(ctx, n); err != nil { + return nil, err + } + } + return rowIterWithOkResultWithZeroRowsAffected(), nil } diff --git a/sql/rowexec/drop_view_test.go b/sql/rowexec/drop_view_test.go index 12c0308b4a..9b8bf10d1b 100644 --- a/sql/rowexec/drop_view_test.go +++ b/sql/rowexec/drop_view_test.go @@ -15,7 +15,6 @@ package rowexec import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -49,7 +48,7 @@ func setupView(t *testing.T, db memory.MemoryDatabase) (*sql.Context, *sql.View) createView := plan.NewCreateView(db, subqueryAlias.Name(), subqueryAlias, false, false, "CREATE VIEW myview AS SELECT i FROM mytable", "", "", "") - ctx := sql.NewContext(context.Background()) + ctx := sql.NewEmptyContext() _, err := DefaultBuilder.Build(ctx, createView, nil) require.NoError(t, err) diff --git a/sql/rowexec/node_builder.gen.go b/sql/rowexec/node_builder.gen.go index 10b0e81b99..465f8e0d29 100644 --- a/sql/rowexec/node_builder.gen.go +++ b/sql/rowexec/node_builder.gen.go @@ -26,8 +26,8 @@ import ( func (b *BaseBuilder) buildNodeExec(ctx *sql.Context, n sql.Node, row sql.Row) (sql.RowIter, error) { var iter sql.RowIter var err error - if b.override != nil { - iter, err = b.override.Build(ctx, n, row) + if b.PriorityBuilder != nil { + iter, err = b.PriorityBuilder.Build(ctx, n, row) } if err != nil { return nil, err diff --git a/sql/rowexec/show.go b/sql/rowexec/show.go index 351cbc828a..996d9e0cdf 100644 --- a/sql/rowexec/show.go +++ b/sql/rowexec/show.go @@ -725,12 +725,13 @@ func (b *BaseBuilder) buildShowIndexes(ctx *sql.Context, n *plan.ShowIndexes, ro func (b *BaseBuilder) buildShowCreateTable(ctx *sql.Context, n *plan.ShowCreateTable, row sql.Row) (sql.RowIter, error) { return &showCreateTablesIter{ - table: n.Child, - isView: n.IsView, - indexes: n.Indexes, - checks: n.Checks(), - schema: n.TargetSchema(), - pkSchema: n.PrimaryKeySchema, + table: n.Child, + isView: n.IsView, + indexes: n.Indexes, + checks: n.Checks(), + schema: n.TargetSchema(), + pkSchema: n.PrimaryKeySchema, + formatter: b.schemaFormatter, }, nil } diff --git a/sql/rowexec/show_iters.go b/sql/rowexec/show_iters.go index d6135d7368..0014537738 100644 --- a/sql/rowexec/show_iters.go +++ b/sql/rowexec/show_iters.go @@ -350,6 +350,7 @@ type showCreateTablesIter struct { checks sql.CheckConstraints didIteration bool isView bool + formatter sql.SchemaFormatter } func (i *showCreateTablesIter) Next(ctx *sql.Context) (sql.Row, error) { @@ -432,7 +433,7 @@ func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, tab // Statement creation parts for each column tableCollation := table.Collation() - for i, col := range schema { + for idx, col := range schema { var colDefaultStr string var err error if col.Default != nil && col.Generated == nil { @@ -452,18 +453,18 @@ func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, tab } if col.PrimaryKey && len(pkSchema.Schema) == 0 { - pkOrdinals = append(pkOrdinals, i) + pkOrdinals = append(pkOrdinals, idx) } - colStmts[i] = sql.GenerateCreateTableColumnDefinition(col, colDefaultStr, onUpdateStr, tableCollation) + colStmts[idx] = i.formatter.GenerateCreateTableColumnDefinition(col, colDefaultStr, onUpdateStr, tableCollation) } - for _, i := range pkOrdinals { - primaryKeyCols = append(primaryKeyCols, schema[i].Name) + for _, idx := range pkOrdinals { + primaryKeyCols = append(primaryKeyCols, schema[idx].Name) } if 
len(primaryKeyCols) > 0 { - colStmts = append(colStmts, sql.GenerateCreateTablePrimaryKeyDefinition(primaryKeyCols)) + colStmts = append(colStmts, i.formatter.GenerateCreateTablePrimaryKeyDefinition(primaryKeyCols)) } for _, index := range i.indexes { @@ -474,18 +475,18 @@ func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, tab prefixLengths := index.PrefixLengths() var indexCols []string - for i, expr := range index.Expressions() { + for idx, expr := range index.Expressions() { col := plan.GetColumnFromIndexExpr(expr, table) if col != nil { - indexDef := sql.QuoteIdentifier(col.Name) - if len(prefixLengths) > i && prefixLengths[i] != 0 { - indexDef += fmt.Sprintf("(%v)", prefixLengths[i]) + indexDef := i.formatter.QuoteIdentifier(col.Name) + if len(prefixLengths) > idx && prefixLengths[idx] != 0 { + indexDef += fmt.Sprintf("(%v)", prefixLengths[idx]) } indexCols = append(indexCols, indexDef) } } - indexDefn, shouldInclude := sql.GenerateCreateTableIndexDefinition(index.IsUnique(), index.IsSpatial(), + indexDefn, shouldInclude := i.formatter.GenerateCreateTableIndexDefinition(index.IsUnique(), index.IsSpatial(), index.IsFullText(), index.IsVector(), index.ID(), indexCols, index.Comment()) if shouldInclude { colStmts = append(colStmts, indexDefn) @@ -507,13 +508,13 @@ func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, tab if len(fk.OnUpdate) > 0 && fk.OnUpdate != sql.ForeignKeyReferentialAction_DefaultAction { onUpdate = string(fk.OnUpdate) } - colStmts = append(colStmts, sql.GenerateCreateTableForiegnKeyDefinition(fk.Name, fk.Columns, fk.ParentTable, fk.ParentColumns, onDelete, onUpdate)) + colStmts = append(colStmts, i.formatter.GenerateCreateTableForiegnKeyDefinition(fk.Name, fk.Columns, fk.ParentTable, fk.ParentColumns, onDelete, onUpdate)) } } if i.checks != nil { for _, check := range i.checks { - colStmts = append(colStmts, sql.GenerateCreateTableCheckConstraintClause(check.Name, check.Expr.String(), check.Enforced)) + colStmts = append(colStmts, i.formatter.GenerateCreateTableCheckConstraintClause(check.Name, check.Expr.String(), check.Enforced)) } } @@ -539,7 +540,7 @@ func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, tab temp = " TEMPORARY" } - return sql.GenerateCreateTableStatement(table.Name(), colStmts, temp, autoInc, table.Collation().CharacterSet().Name(), table.Collation().Name(), comment), nil + return i.formatter.GenerateCreateTableStatement(table.Name(), colStmts, temp, autoInc, table.Collation().CharacterSet().Name(), table.Collation().Name(), comment), nil } func produceCreateViewStatement(view *plan.SubqueryAlias) string { diff --git a/sql/session.go b/sql/session.go index ae850e8749..ef8a4b7ed5 100644 --- a/sql/session.go +++ b/sql/session.go @@ -420,8 +420,11 @@ func (c *Context) ApplyOpts(opts ...ContextOption) { } } -// NewEmptyContext returns a default context with default values. -func NewEmptyContext() *Context { return NewContext(context.TODO()) } +// NewEmptyContext returns a default context with default values. When an existing context is available, it is preferred +// to call ctx.NewContext to ensure that integrator-specific overrides are retained in the new context. +func NewEmptyContext() *Context { + return NewContext(context.TODO()) +} // IsInterpreted returns `true` when this is being called from within RunInterpreted. 
In such cases, GMS will choose to // handle logic differently, as running from within an interpreted function requires different considerations than @@ -514,6 +517,18 @@ func (c *Context) WithContext(ctx context.Context) *Context { return &nc } +// WithClient returns a new Context with the given client. +func (c *Context) WithClient(client Client) *Context { + if c == nil { + return nil + } + + nc := *c + nc.Session.SetClient(client) + nc.Session.SetPrivilegeSet(nil, 0) + return &nc +} + // RootSpan returns the root span, if any. func (c *Context) RootSpan() trace.Span { if c == nil { @@ -581,18 +596,6 @@ func (c *Context) NewErrgroup() (*errgroup.Group, *Context) { return eg, c.WithContext(egCtx) } -// NewCtxWithClient returns a new Context with the given [client] -func (c *Context) NewCtxWithClient(client Client) *Context { - if c == nil { - return nil - } - - nc := *c - nc.Session.SetClient(client) - nc.Session.SetPrivilegeSet(nil, 0) - return &nc -} - // Services are handles to optional or plugin functionality that can be // used by the SQL implementation in certain situations. An integrator can set // methods on Services for a given *Context and different parts of go-mysql-server diff --git a/sql/sqlfmt.go b/sql/sqlfmt.go index dfc9f665e4..90fcd20ae3 100644 --- a/sql/sqlfmt.go +++ b/sql/sqlfmt.go @@ -22,51 +22,47 @@ package sql // GenerateCreateTableStatement returns 'CREATE TABLE' statement with given table names // and column definition statements in order and the collation and character set names for the table func GenerateCreateTableStatement(tblName string, colStmts []string, temp, autoInc, tblCharsetName, tblCollName, comment string) string { - return GlobalSchemaFormatter.GenerateCreateTableStatement(tblName, colStmts, temp, autoInc, tblCharsetName, tblCollName, comment) + return "" } // GenerateCreateTableColumnDefinition returns column definition string for 'CREATE TABLE' statement for given column. // This part comes first in the 'CREATE TABLE' statement. func GenerateCreateTableColumnDefinition(col *Column, colDefault, onUpdate string, tableCollation CollationID) string { - return GlobalSchemaFormatter.GenerateCreateTableColumnDefinition(col, colDefault, onUpdate, tableCollation) + return "" } // GenerateCreateTablePrimaryKeyDefinition returns primary key definition string for 'CREATE TABLE' statement // for given column(s). This part comes after each column definitions. func GenerateCreateTablePrimaryKeyDefinition(pkCols []string) string { - return GlobalSchemaFormatter.GenerateCreateTablePrimaryKeyDefinition(pkCols) + return "" } // GenerateCreateTableIndexDefinition returns index definition string for 'CREATE TABLE' statement // for given index. This part comes after primary key definition if there is any. func GenerateCreateTableIndexDefinition(isUnique, isSpatial, isFullText, isVector bool, indexID string, indexCols []string, comment string) (string, bool) { - return GlobalSchemaFormatter.GenerateCreateTableIndexDefinition(isUnique, isSpatial, isFullText, isVector, indexID, indexCols, comment) + return "", false } // GenerateCreateTableForiegnKeyDefinition returns foreign key constraint definition string for 'CREATE TABLE' statement // for given foreign key. This part comes after index definitions if there are any. 
func GenerateCreateTableForiegnKeyDefinition(fkName string, fkCols []string, parentTbl string, parentCols []string, onDelete, onUpdate string) string { - return GlobalSchemaFormatter.GenerateCreateTableForiegnKeyDefinition(fkName, fkCols, parentTbl, parentCols, onDelete, onUpdate) + return "" } // GenerateCreateTableCheckConstraintClause returns check constraint clause string for 'CREATE TABLE' statement // for given check constraint. This part comes the last and after foreign key definitions if there are any. func GenerateCreateTableCheckConstraintClause(checkName, checkExpr string, enforced bool) string { - return GlobalSchemaFormatter.GenerateCreateTableCheckConstraintClause(checkName, checkExpr, enforced) + return "" } // QuoteIdentifier wraps the specified identifier in backticks and escapes all occurrences of backticks in the // identifier by replacing them with double backticks. func QuoteIdentifier(id string) string { - return GlobalSchemaFormatter.QuoteIdentifier(id) + return "" } // QuoteIdentifiers wraps each of the specified identifiers in backticks, escapes all occurrences of backticks in // the identifier, and returns a slice of the quoted identifiers. func QuoteIdentifiers(ids []string) []string { - quoted := make([]string, len(ids)) - for i, id := range ids { - quoted[i] = GlobalSchemaFormatter.QuoteIdentifier(id) - } - return quoted + return nil } diff --git a/test/test_catalog.go b/test/test_catalog.go index e7251d06ca..2fe877034a 100644 --- a/test/test_catalog.go +++ b/test/test_catalog.go @@ -217,3 +217,7 @@ func (c *Catalog) DropDbStats(ctx *sql.Context, db string, flush bool) error { func (c *Catalog) AuthorizationHandler() sql.AuthorizationHandler { return sql.GetAuthorizationHandlerFactory().CreateHandler(c) } + +func (*Catalog) Overrides() sql.EngineOverrides { + return sql.EngineOverrides{} +}
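
For reference, below is a minimal, hypothetical integrator-side sketch of how the extension points added in this diff might be wired together. It is not part of the patch: the hook function signatures are inferred from the rowexec call sites above (PreSQLExecution appears to take the plan node and return a possibly rewritten node plus an error; PostSQLExecution appears to return only an error), and the exact declarations of sql.EngineOverrides and its Hooks fields live in the sql package and may differ.

// Hypothetical usage sketch, not part of this patch. Hook signatures are
// inferred from the rowexec call sites in this diff and may not match the
// real declarations in the sql package.
package example

import (
	"log"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/rowexec"
)

// newExecBuilder constructs a row-execution builder with a couple of DDL
// hooks installed via EngineOverrides.
func newExecBuilder() sql.NodeExecBuilder {
	var overrides sql.EngineOverrides

	// Rewrite (or just validate) DROP TABLE plans before they execute.
	overrides.Hooks.DropTable.PreSQLExecution = func(ctx *sql.Context, n sql.Node) (sql.Node, error) {
		// Returning the node unchanged makes this a no-op.
		return n, nil
	}

	// Run after a successful CREATE TABLE, e.g. to sync external metadata.
	overrides.Hooks.CreateTable.PostSQLExecution = func(ctx *sql.Context, n sql.Node) error {
		log.Printf("created table: %s", n.String())
		return nil
	}

	// overrides.SchemaFormatter could also be set here to replace the
	// default formatter used for SHOW CREATE TABLE and check definitions.

	// No priority builder is supplied, so the internal builder logic handles
	// every node, as with the DefaultBuilder used in the tests above.
	return rowexec.NewBuilder(nil, overrides)
}

When overrides.SchemaFormatter is left nil, NewBuilder falls back to sql.DefaultMySQLSchemaFormatter, per the constructor in sql/rowexec/builder.go above.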