4747 unlockTablesRegex = regexp .MustCompile (`^unlock\s+tables$` )
4848 lockTablesRegex = regexp .MustCompile (`^lock\s+tables\s` )
4949 setRegex = regexp .MustCompile (`^set\s+` )
50+ createViewRegex = regexp .MustCompile (`^create\s+view\s+` )
51+ )
52+
53+ // These constants aren't exported from vitess for some reason. This could be removed if we changed this.
54+ const (
55+ colKeyNone sqlparser.ColumnKeyOption = iota
56+ colKeyPrimary
57+ colKeySpatialKey
58+ colKeyUnique
59+ colKeyUniqueKey
60+ colKey
5061)
5162
5263// Parse parses the given SQL sentence and returns the corresponding node.
@@ -93,6 +104,9 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) {
93104 return parseLockTables (ctx , s )
94105 case setRegex .MatchString (lowerQuery ):
95106 s = fixSetQuery (s )
107+ case createViewRegex .MatchString (lowerQuery ):
108+ // CREATE VIEW parses as a CREATE DDL statement with an empty table spec
109+ return nil , ErrUnsupportedFeature .New ("CREATE VIEW" )
96110 }
97111
98112 stmt , err := sqlparser .Parse (s )
@@ -144,7 +158,12 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node
144158 case * sqlparser.Insert :
145159 return convertInsert (ctx , n )
146160 case * sqlparser.DDL :
147- return convertDDL (n )
161+ // unlike other statements, DDL statements have loose parsing by default
162+ ddl , err := sqlparser .ParseStrictDDL (query )
163+ if err != nil {
164+ return nil , err
165+ }
166+ return convertDDL (ddl .(* sqlparser.DDL ))
148167 case * sqlparser.Set :
149168 return convertSet (ctx , n )
150169 case * sqlparser.Use :
@@ -354,13 +373,23 @@ func convertDDL(c *sqlparser.DDL) (sql.Node, error) {
354373 switch c .Action {
355374 case sqlparser .CreateStr :
356375 return convertCreateTable (c )
376+ case sqlparser .DropStr :
377+ return convertDropTable (c )
357378 default :
358379 return nil , ErrUnsupportedSyntax .New (c )
359380 }
360381}
361382
383+ func convertDropTable (c * sqlparser.DDL ) (sql.Node , error ) {
384+ tableNames := make ([]string , len (c .FromTables ))
385+ for i , t := range c .FromTables {
386+ tableNames [i ] = t .Name .String ()
387+ }
388+ return plan .NewDropTable (sql .UnresolvedDatabase ("" ), c .IfExists , tableNames ... ), nil
389+ }
390+
362391func convertCreateTable (c * sqlparser.DDL ) (sql.Node , error ) {
363- schema , err := columnDefinitionToSchema (c .TableSpec . Columns )
392+ schema , err := tableSpecToSchema (c .TableSpec )
364393 if err != nil {
365394 return nil , err
366395 }
@@ -462,6 +491,7 @@ func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) {
462491 if err != nil {
463492 return nil , err
464493 }
494+
465495 }
466496
467497 if d .Limit != nil {
@@ -474,27 +504,56 @@ func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) {
474504 return plan .NewUpdate (node , updateExprs ), nil
475505}
476506
477- func columnDefinitionToSchema ( colDef [] * sqlparser.ColumnDefinition ) (sql.Schema , error ) {
507+ func tableSpecToSchema ( tableSpec * sqlparser.TableSpec ) (sql.Schema , error ) {
478508 var schema sql.Schema
479- for _ , cd := range colDef {
480- typ := cd .Type
481- internalTyp , err := sql .MysqlTypeToType (typ .SQLType ())
509+ for _ , cd := range tableSpec .Columns {
510+ column , err := getColumn (cd , tableSpec .Indexes )
482511 if err != nil {
483512 return nil , err
484513 }
485514
486- schema = append (schema , & sql.Column {
487- Nullable : ! bool (typ .NotNull ),
488- Type : internalTyp ,
489- Name : cd .Name .String (),
490- // TODO
491- Default : nil ,
492- })
515+ schema = append (schema , column )
493516 }
494517
495518 return schema , nil
496519}
497520
521+ // getColumn returns the sql.Column for the column definition given, as part of a create table statement.
522+ func getColumn (cd * sqlparser.ColumnDefinition , indexes []* sqlparser.IndexDefinition ) (* sql.Column , error ) {
523+ typ := cd .Type
524+ internalTyp , err := sql .MysqlTypeToType (typ .SQLType ())
525+ if err != nil {
526+ return nil , err
527+ }
528+
529+ // Primary key info can either be specified in the column's type info (for in-line declarations), or in a slice of
530+ // indexes attached to the table def. We have to check both places to find if a column is part of the primary key
531+ isPkey := cd .Type .KeyOpt == colKeyPrimary
532+
533+ if ! isPkey {
534+ OuterLoop:
535+ for _ , index := range indexes {
536+ if index .Info .Primary {
537+ for _ , indexCol := range index .Columns {
538+ if indexCol .Column .Equal (cd .Name ) {
539+ isPkey = true
540+ break OuterLoop
541+ }
542+ }
543+ }
544+ }
545+ }
546+
547+ return & sql.Column {
548+ Nullable : ! bool (typ .NotNull ),
549+ Type : internalTyp ,
550+ Name : cd .Name .String (),
551+ PrimaryKey : isPkey ,
552+ // TODO
553+ Default : nil ,
554+ }, nil
555+ }
556+
498557func columnsToStrings (cols sqlparser.Columns ) []string {
499558 res := make ([]string , len (cols ))
500559 for i , c := range cols {
0 commit comments