From ad615b938f01e3002533aa2581eef18e241e710b Mon Sep 17 00:00:00 2001 From: Kenny Flegal Date: Sun, 14 Aug 2022 11:52:51 -0700 Subject: [PATCH 1/5] allow nested compound statements wherever permitted --- src/arithmetic.rs | 7 +- src/common.rs | 91 ++++++++------ src/compound_select.rs | 203 ++++++++++++++++++++++-------- src/condition.rs | 49 ++++---- src/create.rs | 16 +-- src/delete.rs | 4 +- src/insert.rs | 4 +- src/join.rs | 8 +- src/order.rs | 8 +- src/parser.rs | 16 ++- src/select.rs | 145 +++++++++++++-------- tests/exists-queries.txt | 1 + tests/lib.rs | 10 +- tests/nested-compound-selects.txt | 5 + 14 files changed, 381 insertions(+), 186 deletions(-) create mode 100644 tests/nested-compound-selects.txt diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 470b3eb..90a3943 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -459,10 +459,13 @@ mod tests { } #[test] - fn arithmetic_scalar(){ + fn arithmetic_scalar() { let qs = "56"; let res = arithmetic(qs.as_bytes()); assert!(res.is_err()); - assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap()); + assert_eq!( + nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), + res.err().unwrap() + ); } } diff --git a/src/common.rs b/src/common.rs index 1d6cd4d..cdc52f6 100644 --- a/src/common.rs +++ b/src/common.rs @@ -388,7 +388,7 @@ where let (inp, _) = first.parse(inp)?; let (inp, o2) = second.parse(inp)?; third.parse(inp).map(|(i, _)| (i, o2)) - }, + } } } } @@ -641,7 +641,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> { // present. pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> { let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1))); - let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?; + let (remaining_input, (distinct, args)) = + tuple((distinct_parser, function_argument_parser))(i)?; Ok((remaining_input, (args, distinct.is_some()))) } @@ -695,12 +696,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> { FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep) }, ), - map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| { - let (name, _, _, arguments, _) = tuple; - FunctionExpression::Generic( - str::from_utf8(name).unwrap().to_string(), - FunctionArguments::from(arguments)) - }) + map( + tuple(( + sql_identifier, + multispace0, + tag("("), + separated_list0( + tag(","), + delimited(multispace0, function_argument_parser, multispace0), + ), + tag(")"), + )), + |tuple| { + let (name, _, _, arguments, _) = tuple; + FunctionExpression::Generic( + str::from_utf8(name).unwrap().to_string(), + FunctionArguments::from(arguments), + ) + }, + ), ))(i) } @@ -1021,22 +1035,23 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec> { // Parse a reference to a named schema.table, with an optional alias pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> { map( - tuple(( - opt(pair(sql_identifier, tag("."))), - sql_identifier, - opt(as_alias) - )), - |tup| Table { - name: String::from(str::from_utf8(tup.1).unwrap()), - alias: match tup.2 { - Some(a) => Some(String::from(a)), - None => None, - }, - schema: match tup.0 { - Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), - None => None, + tuple(( + opt(pair(sql_identifier, tag("."))), + sql_identifier, + opt(as_alias), + )), + |tup| Table { + name: String::from(str::from_utf8(tup.1).unwrap()), + alias: match tup.2 { + Some(a) => Some(String::from(a)), + None => None, + }, + schema: match tup.0 { + Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), + None => None, + }, }, - })(i) + )(i) } // Parse a reference to a named table, with an optional alias @@ -1047,7 +1062,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> { Some(a) => Some(String::from(a)), None => None, }, - schema: None, + schema: None, })(i) } @@ -1137,25 +1152,31 @@ mod tests { name: String::from("max(addr_id)"), alias: None, table: None, - function: Some(Box::new(FunctionExpression::Max( - FunctionArgument::Column(Column::from("addr_id")), - ))), + function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column( + Column::from("addr_id"), + )))), }; assert_eq!(res.unwrap().1, expected); } #[test] fn simple_generic_function() { - let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()]; + let qlist = [ + "coalesce(a,b,c)".as_bytes(), + "coalesce (a,b,c)".as_bytes(), + "coalesce(a ,b,c)".as_bytes(), + "coalesce(a, b,c)".as_bytes(), + ]; for q in qlist.iter() { let res = column_function(q); - let expected = FunctionExpression::Generic("coalesce".to_string(), - FunctionArguments::from( - vec!( - FunctionArgument::Column(Column::from("a")), - FunctionArgument::Column(Column::from("b")), - FunctionArgument::Column(Column::from("c")) - ))); + let expected = FunctionExpression::Generic( + "coalesce".to_string(), + FunctionArguments::from(vec![ + FunctionArgument::Column(Column::from("a")), + FunctionArgument::Column(Column::from("b")), + FunctionArgument::Column(Column::from("c")), + ]), + ); assert_eq!(res, Ok((&b""[..], expected))); } } diff --git a/src/compound_select.rs b/src/compound_select.rs index d3ece89..70a5679 100644 --- a/src/compound_select.rs +++ b/src/compound_select.rs @@ -7,10 +7,10 @@ use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; use nom::multi::many1; -use nom::sequence::{delimited, preceded, tuple}; +use nom::sequence::{preceded, tuple}; use nom::IResult; use order::{order_clause, OrderClause}; -use select::{limit_clause, nested_selection, LimitClause, SelectStatement}; +use select::{limit_clause, nested_simple_selection, LimitClause, Selection}; #[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)] pub enum CompoundSelectOperator { @@ -33,7 +33,7 @@ impl fmt::Display for CompoundSelectOperator { #[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)] pub struct CompoundSelectStatement { - pub selects: Vec<(Option, SelectStatement)>, + pub selects: Vec<(Option, Selection)>, pub order: Option, pub limit: Option, } @@ -89,43 +89,78 @@ fn compound_op(i: &[u8]) -> IResult<&[u8], CompoundSelectOperator> { ))(i) } -fn other_selects(i: &[u8]) -> IResult<&[u8], (Option, SelectStatement)> { - let (remaining_input, (_, op, _, select)) = tuple(( - multispace0, - compound_op, - multispace1, - opt_delimited( - tag("("), - delimited(multispace0, nested_selection, multispace0), - tag(")"), +// Parse terminated compound selection +pub fn compound_selection(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> { + let (remaining_input, (compound_selection, _, _)) = + tuple((nested_compound_selection, multispace0, statement_terminator))(i)?; + + Ok((remaining_input, compound_selection)) +} + +pub fn compound_selection_part(i: &[u8]) -> IResult<&[u8], Selection> { + alt(( + map(compound_selection_compound_part, |cs| cs.into()), + map( + opt_delimited(tag("("), nested_simple_selection, tag(")")), + |s| s.into(), ), + ))(i) +} + +pub fn compound_selection_compound_part(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> { + let (remaining_input, (_, lhs, op_rhs, _)) = tuple(( + tag("("), + opt_delimited(tag("("), nested_simple_selection, tag(")")), + many1(tuple((multispace1, compound_op_selection_part))), + tag(")"), ))(i)?; - Ok((remaining_input, (Some(op), select))) + let mut css = CompoundSelectStatement { + selects: vec![], + order: None, + limit: None, + }; + + css.selects.push((None, lhs.into())); + + for (_, (op, rhs)) in op_rhs { + css.selects.push((Some(op), rhs.into())) + } + + Ok((remaining_input, css)) } -// Parse compound selection -pub fn compound_selection(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> { - let (remaining_input, (first_select, other_selects, _, order, limit, _)) = tuple(( - opt_delimited(tag("("), nested_selection, tag(")")), - many1(other_selects), - multispace0, +pub fn compound_op_selection_part(i: &[u8]) -> IResult<&[u8], (CompoundSelectOperator, Selection)> { + let (remaining_input, (op, _, selection)) = + tuple((compound_op, multispace1, compound_selection_part))(i)?; + + Ok((remaining_input, (op, selection))) +} + +// Parse nested compound selection +pub fn nested_compound_selection(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> { + let (remaining_input, ((first, other_selects), order, limit)) = tuple(( + tuple(( + compound_selection_part, + many1(tuple((multispace1, compound_op_selection_part))), + )), opt(order_clause), opt(limit_clause), - statement_terminator, ))(i)?; - let mut selects = vec![(None, first_select)]; - selects.extend(other_selects); - - Ok(( - remaining_input, - CompoundSelectStatement { - selects, - order, - limit, - }, - )) + let mut css = CompoundSelectStatement { + selects: vec![], + order, + limit, + }; + + css.selects.push((None, first.into())); + + for os in other_selects { + css.selects.push((Some(os.1 .0), os.1 .1.into())); + } + + Ok((remaining_input, css)) } #[cfg(test)] @@ -133,14 +168,16 @@ mod tests { use super::*; use column::Column; use common::{FieldDefinitionExpression, FieldValueExpression, Literal}; + use select::selection; use table::Table; + use SelectStatement; #[test] fn union() { let qstr = "SELECT id, 1 FROM Vote UNION SELECT id, stars from Rating;"; let qstr2 = "(SELECT id, 1 FROM Vote) UNION (SELECT id, stars from Rating);"; - let res = compound_selection(qstr.as_bytes()); - let res2 = compound_selection(qstr2.as_bytes()); + let res = selection(qstr.as_bytes()); + let res2 = selection(qstr2.as_bytes()); let first_select = SelectStatement { tables: vec![Table::from("Vote")], @@ -162,15 +199,18 @@ mod tests { }; let expected = CompoundSelectStatement { selects: vec![ - (None, first_select), - (Some(CompoundSelectOperator::DistinctUnion), second_select), + (None, first_select.into()), + ( + Some(CompoundSelectOperator::DistinctUnion), + second_select.into(), + ), ], order: None, limit: None, }; - assert_eq!(res.unwrap().1, expected); - assert_eq!(res2.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.clone().into()); + assert_eq!(res2.unwrap().1, expected.into()); } #[test] @@ -185,29 +225,38 @@ mod tests { assert!(&res.is_err()); assert_eq!( res.unwrap_err(), - nom::Err::Error(nom::error::Error::new(");".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ");".as_bytes(), + nom::error::ErrorKind::MultiSpace + )) ); assert!(&res2.is_err()); assert_eq!( res2.unwrap_err(), - nom::Err::Error(nom::error::Error::new(";".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ";".as_bytes(), + nom::error::ErrorKind::Tag + )) ); assert!(&res3.is_err()); assert_eq!( res3.unwrap_err(), nom::Err::Error(nom::error::Error::new( ") UNION (SELECT id, stars from Rating;".as_bytes(), - nom::error::ErrorKind::Tag + nom::error::ErrorKind::MultiSpace )) ); } #[test] fn multi_union() { - let qstr = "SELECT id, 1 FROM Vote \ - UNION SELECT id, stars from Rating \ - UNION DISTINCT SELECT 42, 5 FROM Vote;"; - let res = compound_selection(qstr.as_bytes()); + let q = "SELECT id, 1 FROM Vote UNION SELECT id, stars from Rating UNION DISTINCT SELECT 42, 5 FROM Vote"; + let qstr0 = format!("{};", q); + let qstr1 = format!("({}) UNION ALL ({});", q, q); + let qstr2 = format!("{} UNION ALL {};", q, q); + let res0 = selection(qstr0.as_bytes()); + let res1 = selection(qstr1.as_bytes()); + let res2 = selection(qstr2.as_bytes()); let first_select = SelectStatement { tables: vec![Table::from("Vote")], @@ -240,23 +289,71 @@ mod tests { ..Default::default() }; - let expected = CompoundSelectStatement { + let expected0 = CompoundSelectStatement { + selects: vec![ + (None, first_select.clone().into()), + ( + Some(CompoundSelectOperator::DistinctUnion), + second_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::DistinctUnion), + third_select.clone().into(), + ), + ], + order: None, + limit: None, + }; + + let expected1 = CompoundSelectStatement { selects: vec![ - (None, first_select), - (Some(CompoundSelectOperator::DistinctUnion), second_select), - (Some(CompoundSelectOperator::DistinctUnion), third_select), + (None, expected0.clone().into()), + ( + Some(CompoundSelectOperator::Union), + expected0.clone().into(), + ), ], order: None, limit: None, }; - assert_eq!(res.unwrap().1, expected); + let expected2 = CompoundSelectStatement { + selects: vec![ + (None, first_select.clone().into()), + ( + Some(CompoundSelectOperator::DistinctUnion), + second_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::DistinctUnion), + third_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::Union), + first_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::DistinctUnion), + second_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::DistinctUnion), + third_select.into(), + ), + ], + order: None, + limit: None, + }; + + assert_eq!(res0.unwrap().1, expected0.into()); + assert_eq!(res1.unwrap().1, expected1.into()); + assert_eq!(res2.unwrap().1, expected2.into()); } #[test] fn union_all() { let qstr = "SELECT id, 1 FROM Vote UNION ALL SELECT id, stars from Rating;"; - let res = compound_selection(qstr.as_bytes()); + let res = selection(qstr.as_bytes()); let first_select = SelectStatement { tables: vec![Table::from("Vote")], @@ -278,13 +375,13 @@ mod tests { }; let expected = CompoundSelectStatement { selects: vec![ - (None, first_select), - (Some(CompoundSelectOperator::Union), second_select), + (None, first_select.into()), + (Some(CompoundSelectOperator::Union), second_select.into()), ], order: None, limit: None, }; - assert_eq!(res.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.into()); } } diff --git a/src/condition.rs b/src/condition.rs index a210b7a..865c989 100644 --- a/src/condition.rs +++ b/src/condition.rs @@ -14,7 +14,7 @@ use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple}; use nom::IResult; -use select::{nested_selection, SelectStatement}; +use select::{nested_selection, nested_simple_selection, SelectStatement, Selection}; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum ConditionBase { @@ -89,7 +89,7 @@ pub enum ConditionExpression { ComparisonOp(ConditionTree), LogicalOp(ConditionTree), NegationOp(Box), - ExistsOp(Box), + ExistsOp(Box), Base(ConditionBase), Arithmetic(Box), Bracketed(Box), @@ -227,9 +227,10 @@ fn in_operation(i: &[u8]) -> IResult<&[u8], (Operator, ConditionExpression)> { opt(terminated(tag_no_case("not"), multispace1)), terminated(tag_no_case("in"), multispace0), alt(( - map(delimited(tag("("), nested_selection, tag(")")), |s| { - ConditionBase::NestedSelect(Box::new(s)) - }), + map( + delimited(tag("("), nested_simple_selection, tag(")")), + |s| ConditionBase::NestedSelect(Box::new(s)), + ), map(delimited(tag("("), value_list, tag(")")), |vs| { ConditionBase::LiteralList(vs) }), @@ -291,10 +292,7 @@ fn predicate(i: &[u8]) -> IResult<&[u8], ConditionExpression> { }, ); - alt(( - simple_expr, - nested_exists, - ))(i) + alt((simple_expr, nested_exists))(i) } fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> { @@ -320,9 +318,10 @@ fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> { map(column_identifier, |f| { ConditionExpression::Base(ConditionBase::Field(f)) }), - map(delimited(tag("("), nested_selection, tag(")")), |s| { - ConditionExpression::Base(ConditionBase::NestedSelect(Box::new(s))) - }), + map( + delimited(tag("("), nested_simple_selection, tag(")")), + |s| ConditionExpression::Base(ConditionBase::NestedSelect(Box::new(s))), + ), ))(i) } @@ -745,11 +744,14 @@ mod tests { let res = condition_expr(cond.as_bytes()); - let nested_select = Box::new(SelectStatement { - tables: vec![Table::from("foo")], - fields: columns(&["col"]), - ..Default::default() - }); + let nested_select = Box::new( + SelectStatement { + tables: vec![Table::from("foo")], + fields: columns(&["col"]), + ..Default::default() + } + .into(), + ); let expected = ConditionExpression::ExistsOp(nested_select); @@ -766,11 +768,14 @@ mod tests { let res = condition_expr(cond.as_bytes()); - let nested_select = Box::new(SelectStatement { - tables: vec![Table::from("foo")], - fields: columns(&["col"]), - ..Default::default() - }); + let nested_select = Box::new( + SelectStatement { + tables: vec![Table::from("foo")], + fields: columns(&["col"]), + ..Default::default() + } + .into(), + ); let expected = ConditionExpression::NegationOp(Box::new(ConditionExpression::ExistsOp(nested_select))); diff --git a/src/create.rs b/src/create.rs index a1b5afc..f3f11cf 100644 --- a/src/create.rs +++ b/src/create.rs @@ -5,8 +5,8 @@ use std::str::FromStr; use column::{Column, ColumnConstraint, ColumnSpecification}; use common::{ - column_identifier_no_alias, parse_comment, sql_identifier, statement_terminator, - schema_table_reference, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, + column_identifier_no_alias, parse_comment, schema_table_reference, sql_identifier, + statement_terminator, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, }; use compound_select::{compound_selection, CompoundSelectStatement}; use create_table_options::table_options; @@ -18,7 +18,7 @@ use nom::multi::{many0, many1}; use nom::sequence::{delimited, preceded, terminated, tuple}; use nom::IResult; use order::{order_type, OrderType}; -use select::{nested_selection, SelectStatement}; +use select::{nested_simple_selection, SelectStatement}; use table::Table; #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] @@ -431,7 +431,7 @@ pub fn view_creation(i: &[u8]) -> IResult<&[u8], CreateViewStatement> { multispace1, alt(( map(compound_selection, |s| SelectSpecification::Compound(s)), - map(nested_selection, |s| SelectSpecification::Simple(s)), + map(nested_simple_selection, |s| SelectSpecification::Simple(s)), )), statement_terminator, ))(i)?; @@ -534,7 +534,7 @@ mod tests { assert_eq!( res.unwrap().1, CreateTableStatement { - table: Table::from(("db1","t")), + table: Table::from(("db1", "t")), fields: vec![ColumnSpecification::new( Column::from("t.x"), SqlType::Int(32) @@ -833,7 +833,8 @@ mod tests { tables: vec![Table::from("users")], fields: vec![FieldDefinitionExpression::All], ..Default::default() - }, + } + .into(), ), ( Some(CompoundSelectOperator::DistinctUnion), @@ -841,7 +842,8 @@ mod tests { tables: vec![Table::from("old_users")], fields: vec![FieldDefinitionExpression::All], ..Default::default() - }, + } + .into(), ), ], order: None, diff --git a/src/delete.rs b/src/delete.rs index 64ee1dc..40f3cd1 100644 --- a/src/delete.rs +++ b/src/delete.rs @@ -1,7 +1,7 @@ use nom::character::complete::multispace1; use std::{fmt, str}; -use common::{statement_terminator, schema_table_reference}; +use common::{schema_table_reference, statement_terminator}; use condition::ConditionExpression; use keywords::escape_if_keyword; use nom::bytes::complete::tag_no_case; @@ -77,7 +77,7 @@ mod tests { assert_eq!( res.unwrap().1, DeleteStatement { - table: Table::from(("db1","users")), + table: Table::from(("db1", "users")), ..Default::default() } ); diff --git a/src/insert.rs b/src/insert.rs index 9f55107..5c81427 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -4,7 +4,7 @@ use std::str; use column::Column; use common::{ - assignment_expr_list, field_list, statement_terminator, schema_table_reference, value_list, + assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list, ws_sep_comma, FieldValueExpression, Literal, }; use keywords::escape_if_keyword; @@ -145,7 +145,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from(("db1","users")), + table: Table::from(("db1", "users")), fields: None, data: vec![vec![42.into(), "test".into()]], ..Default::default() diff --git a/src/join.rs b/src/join.rs index b91f5dc..35e3638 100644 --- a/src/join.rs +++ b/src/join.rs @@ -7,7 +7,7 @@ use nom::branch::alt; use nom::bytes::complete::tag_no_case; use nom::combinator::map; use nom::IResult; -use select::{JoinClause, SelectStatement}; +use select::{JoinClause, Selection}; use table::Table; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] @@ -17,7 +17,7 @@ pub enum JoinRightSide { /// A comma-separated (and implicitly joined) sequence of tables. Tables(Vec), /// A nested selection, represented as (query, alias). - NestedSelect(Box, Option), + NestedSelect(Box, Option), /// A nested join clause. NestedJoin(Box), } @@ -111,14 +111,14 @@ mod tests { use condition::ConditionBase::*; use condition::ConditionExpression::{self, *}; use condition::ConditionTree; - use select::{selection, JoinClause, SelectStatement}; + use select::{simple_selection, JoinClause, SelectStatement}; #[test] fn inner_join() { let qstring = "SELECT tags.* FROM tags \ INNER JOIN taggings ON tags.id = taggings.tag_id"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let ct = ConditionTree { left: Box::new(Base(Field(Column::from("tags.id")))), diff --git a/src/order.rs b/src/order.rs index 3db9ddc..273908a 100644 --- a/src/order.rs +++ b/src/order.rs @@ -82,7 +82,7 @@ pub fn order_clause(i: &[u8]) -> IResult<&[u8], OrderClause> { #[cfg(test)] mod tests { use super::*; - use select::selection; + use select::simple_selection; #[test] fn order_clause() { @@ -103,9 +103,9 @@ mod tests { columns: vec![("name".into(), OrderType::OrderAscending)], }; - let res1 = selection(qstring1.as_bytes()); - let res2 = selection(qstring2.as_bytes()); - let res3 = selection(qstring3.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); + let res2 = simple_selection(qstring2.as_bytes()); + let res3 = simple_selection(qstring3.as_bytes()); assert_eq!(res1.unwrap().1.order, Some(expected_ord1)); assert_eq!(res2.unwrap().1.order, Some(expected_ord2)); assert_eq!(res3.unwrap().1.order, Some(expected_ord3)); diff --git a/src/parser.rs b/src/parser.rs index ae68592..a0dfc42 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,7 +1,7 @@ use std::fmt; use std::str; -use compound_select::{compound_selection, CompoundSelectStatement}; +use compound_select::CompoundSelectStatement; use create::{creation, view_creation, CreateTableStatement, CreateViewStatement}; use delete::{deletion, DeleteStatement}; use drop::{drop_table, DropTableStatement}; @@ -9,7 +9,7 @@ use insert::{insertion, InsertStatement}; use nom::branch::alt; use nom::combinator::map; use nom::IResult; -use select::{selection, SelectStatement}; +use select::{selection, SelectStatement, Selection}; use set::{set, SetStatement}; use update::{updating, UpdateStatement}; @@ -42,12 +42,20 @@ impl fmt::Display for SqlQuery { } } +impl From for SqlQuery { + fn from(s: Selection) -> Self { + match s { + Selection::Statement(ss) => SqlQuery::Select(ss), + Selection::Compound(css) => SqlQuery::CompoundSelect(css), + } + } +} + pub fn sql_query(i: &[u8]) -> IResult<&[u8], SqlQuery> { alt(( map(creation, |c| SqlQuery::CreateTable(c)), map(insertion, |i| SqlQuery::Insert(i)), - map(compound_selection, |cs| SqlQuery::CompoundSelect(cs)), - map(selection, |s| SqlQuery::Select(s)), + map(selection, |s| s.into()), map(deletion, |d| SqlQuery::Delete(d)), map(drop_table, |dt| SqlQuery::DropTable(dt)), map(updating, |u| SqlQuery::Update(u)), diff --git a/src/select.rs b/src/select.rs index 8735e95..cf53b61 100644 --- a/src/select.rs +++ b/src/select.rs @@ -1,5 +1,6 @@ use nom::character::complete::{multispace0, multispace1}; use std::fmt; +use std::fmt::{Display, Formatter}; use std::str; use column::Column; @@ -8,6 +9,7 @@ use common::{ as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference, unsigned_number, }; +use compound_select::nested_compound_selection; use condition::{condition_expr, ConditionExpression}; use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide}; use nom::branch::alt; @@ -18,6 +20,7 @@ use nom::sequence::{delimited, preceded, terminated, tuple}; use nom::IResult; use order::{order_clause, OrderClause}; use table::Table; +use CompoundSelectStatement; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct GroupByClause { @@ -76,6 +79,33 @@ impl fmt::Display for LimitClause { } } +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub enum Selection { + Statement(SelectStatement), + Compound(CompoundSelectStatement), +} + +impl Display for Selection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Statement(s) => write!(f, "{}", s), + Self::Compound(cs) => write!(f, "{}", cs), + } + } +} + +impl From for Selection { + fn from(ss: SelectStatement) -> Self { + Self::Statement(ss) + } +} + +impl From for Selection { + fn from(css: CompoundSelectStatement) -> Self { + Self::Compound(css) + } +} + #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct SelectStatement { pub tables: Vec
, @@ -268,12 +298,24 @@ pub fn where_clause(i: &[u8]) -> IResult<&[u8], ConditionExpression> { Ok((remaining_input, where_condition)) } -// Parse rule for a SQL selection query. -pub fn selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { - terminated(nested_selection, statement_terminator)(i) +pub fn selection(i: &[u8]) -> IResult<&[u8], Selection> { + terminated(nested_selection, opt(statement_terminator))(i) +} + +pub fn nested_selection(i: &[u8]) -> IResult<&[u8], Selection> { + alt(( + map(nested_compound_selection, |cs| Selection::Compound(cs)), + map(nested_simple_selection, |s| Selection::Statement(s)), + ))(i) +} + +#[cfg(test)] +// Parse rule for a simple SQL selection query, currently only used to simplify tests +pub fn simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { + terminated(nested_simple_selection, statement_terminator)(i) } -pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { +pub fn nested_simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { let ( remaining_input, (_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit), @@ -330,7 +372,7 @@ mod tests { fn simple_select() { let qstring = "SELECT id, name FROM users;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -345,7 +387,7 @@ mod tests { fn more_involved_select() { let qstring = "SELECT users.id, users.name FROM users;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -364,7 +406,7 @@ mod tests { // TODO: doesn't support selecting literals without a FROM clause, which is still valid SQL // let qstring = "SELECT NULL, 1, \"foo\";"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -392,7 +434,7 @@ mod tests { fn select_all() { let qstring = "SELECT * FROM users;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -407,7 +449,7 @@ mod tests { fn select_all_in_table() { let qstring = "SELECT users.* FROM users, votes;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -422,7 +464,7 @@ mod tests { fn spaces_optional() { let qstring = "SELECT id,name FROM users;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -439,8 +481,8 @@ mod tests { let qstring_uc = "SELECT id, name FROM users;"; assert_eq!( - selection(qstring_lc.as_bytes()).unwrap(), - selection(qstring_uc.as_bytes()).unwrap() + simple_selection(qstring_lc.as_bytes()).unwrap(), + simple_selection(qstring_uc.as_bytes()).unwrap() ); } @@ -450,9 +492,9 @@ mod tests { let qstring_nosem = "select id, name from users"; let qstring_linebreak = "select id, name from users\n"; - let r1 = selection(qstring_sem.as_bytes()).unwrap(); - let r2 = selection(qstring_nosem.as_bytes()).unwrap(); - let r3 = selection(qstring_linebreak.as_bytes()).unwrap(); + let r1 = simple_selection(qstring_sem.as_bytes()).unwrap(); + let r2 = simple_selection(qstring_nosem.as_bytes()).unwrap(); + let r3 = simple_selection(qstring_linebreak.as_bytes()).unwrap(); assert_eq!(r1, r2); assert_eq!(r2, r3); } @@ -482,7 +524,7 @@ mod tests { } fn where_clause_with_variable_placeholder(qstring: &str, literal: Literal) { - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_left = Base(Field(Column::from("email"))); let expected_where_cond = Some(ComparisonOp(ConditionTree { @@ -515,8 +557,8 @@ mod tests { offset: 10, }; - let res1 = selection(qstring1.as_bytes()); - let res2 = selection(qstring2.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); + let res2 = simple_selection(qstring2.as_bytes()); assert_eq!(res1.unwrap().1.limit, Some(expected_lim1)); assert_eq!(res2.unwrap().1.limit, Some(expected_lim2)); } @@ -526,14 +568,14 @@ mod tests { let qstring1 = "select * from PaperTag as t;"; // let qstring2 = "select * from PaperTag t;"; - let res1 = selection(qstring1.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); assert_eq!( res1.unwrap().1, SelectStatement { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: None, + schema: None, },], fields: vec![FieldDefinitionExpression::All], ..Default::default() @@ -547,14 +589,14 @@ mod tests { fn table_schema() { let qstring1 = "select * from db1.PaperTag as t;"; - let res1 = selection(qstring1.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); assert_eq!( res1.unwrap().1, SelectStatement { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: Some(String::from("db1")), + schema: Some(String::from("db1")), },], fields: vec![FieldDefinitionExpression::All], ..Default::default() @@ -569,7 +611,7 @@ mod tests { let qstring1 = "select name as TagName from PaperTag;"; let qstring2 = "select PaperTag.name as TagName from PaperTag;"; - let res1 = selection(qstring1.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); assert_eq!( res1.unwrap().1, SelectStatement { @@ -583,7 +625,7 @@ mod tests { ..Default::default() } ); - let res2 = selection(qstring2.as_bytes()); + let res2 = simple_selection(qstring2.as_bytes()); assert_eq!( res2.unwrap().1, SelectStatement { @@ -604,7 +646,7 @@ mod tests { let qstring1 = "select name TagName from PaperTag;"; let qstring2 = "select PaperTag.name TagName from PaperTag;"; - let res1 = selection(qstring1.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); assert_eq!( res1.unwrap().1, SelectStatement { @@ -618,7 +660,7 @@ mod tests { ..Default::default() } ); - let res2 = selection(qstring2.as_bytes()); + let res2 = simple_selection(qstring2.as_bytes()); assert_eq!( res2.unwrap().1, SelectStatement { @@ -638,7 +680,7 @@ mod tests { fn distinct() { let qstring = "select distinct tag from PaperTag where paperId=?;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_left = Base(Field(Column::from("paperId"))); let expected_where_cond = Some(ComparisonOp(ConditionTree { left: Box::new(expected_left), @@ -663,7 +705,7 @@ mod tests { fn simple_condition_expr() { let qstring = "select infoJson from PaperStorage where paperId=? and paperStorageId=?;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let left_ct = ConditionTree { left: Box::new(Base(Field(Column::from("paperId")))), @@ -700,7 +742,7 @@ mod tests { #[test] fn where_and_limit_clauses() { let qstring = "select * from users where id = ? limit 10\n"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_lim = Some(LimitClause { limit: 10, @@ -731,7 +773,7 @@ mod tests { fn aggregation_column() { let qstring = "SELECT max(addr_id) FROM address;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("addr_id"))); assert_eq!( res.unwrap().1, @@ -752,7 +794,7 @@ mod tests { fn aggregation_column_with_alias() { let qstring = "SELECT max(addr_id) AS max_addr FROM address;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("addr_id"))); let expected_stmt = SelectStatement { tables: vec![Table::from("address")], @@ -771,7 +813,7 @@ mod tests { fn count_all() { let qstring = "SELECT COUNT(*) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::CountStar; let expected_stmt = SelectStatement { tables: vec![Table::from("votes")], @@ -794,7 +836,7 @@ mod tests { fn count_distinct() { let qstring = "SELECT COUNT(DISTINCT vote_id) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Count(FunctionArgument::Column(Column::from("vote_id")), true); let expected_stmt = SelectStatement { @@ -818,7 +860,7 @@ mod tests { fn count_filter() { let qstring = "SELECT COUNT(CASE WHEN vote_id > 10 THEN vote_id END) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let filter_cond = ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("vote_id")))), @@ -854,7 +896,7 @@ mod tests { fn sum_filter() { let qstring = "SELECT SUM(CASE WHEN sign = 1 THEN vote_id END) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let filter_cond = ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("sign")))), @@ -891,7 +933,7 @@ mod tests { let qstring = "SELECT SUM(CASE WHEN sign = 1 THEN vote_id ELSE 6 END) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let filter_cond = ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("sign")))), @@ -930,7 +972,7 @@ mod tests { FROM votes GROUP BY votes.comment_id;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let filter_cond = LogicalOp(ConditionTree { left: Box::new(ComparisonOp(ConditionTree { @@ -974,7 +1016,7 @@ mod tests { fn generic_function_query() { let qstring = "SELECT coalesce(a, b,c) as x,d FROM sometable;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Generic( String::from("coalesce"), FunctionArguments { @@ -1026,7 +1068,7 @@ mod tests { let qstring = "SELECT * FROM item, author WHERE item.i_a_id = author.a_id AND \ item.i_subject = ? ORDER BY item.i_title limit 50;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_where_cond = Some(LogicalOp(ConditionTree { left: Box::new(ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("item.i_a_id")))), @@ -1064,7 +1106,7 @@ mod tests { fn simple_joins() { let qstring = "select paperId from PaperConflict join PCMember using (contactId);"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_stmt = SelectStatement { tables: vec![Table::from("PaperConflict")], fields: columns(&["paperId"]), @@ -1086,7 +1128,7 @@ mod tests { join PaperReview on (PCMember.contactId=PaperReview.contactId) \ order by contactId;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let ct = ConditionTree { left: Box::new(Base(Field(Column::from("PCMember.contactId")))), right: Box::new(Base(Field(Column::from("PaperReview.contactId")))), @@ -1113,7 +1155,7 @@ mod tests { from PCMember \ join PaperReview on PCMember.contactId=PaperReview.contactId \ order by contactId;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!(res.unwrap().1, expected); } @@ -1133,7 +1175,7 @@ mod tests { (contactId) left join ChairAssistant using (contactId) left join Chair \ using (contactId) where ContactInfo.contactId=?;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let ct = ConditionTree { left: Box::new(Base(Field(Column::from("ContactInfo.contactId")))), right: Box::new(Base(Literal(Literal::Placeholder( @@ -1177,7 +1219,7 @@ mod tests { WHERE orders.o_c_id IN (SELECT o_c_id FROM orders, order_line \ WHERE orders.o_id = order_line.ol_o_id);"; - let res = selection(qstr.as_bytes()); + let res = simple_selection(qstr.as_bytes()); let inner_where_clause = ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("orders.o_id")))), right: Box::new(Base(Field(Column::from("order_line.ol_o_id")))), @@ -1214,7 +1256,7 @@ mod tests { WHERE orders.o_id = order_line.ol_o_id \ AND orders.o_id > (SELECT MAX(o_id) FROM orders));"; - let res = selection(qstr.as_bytes()); + let res = simple_selection(qstr.as_bytes()); let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("o_id"))); let recursive_select = SelectStatement { @@ -1286,7 +1328,7 @@ mod tests { let qstr_with_alias = "SELECT o_id, ol_i_id FROM orders JOIN \ (SELECT ol_i_id FROM order_line) AS ids \ ON (orders.o_id = ids.ol_i_id);"; - let res = selection(qstr_with_alias.as_bytes()); + let res = simple_selection(qstr_with_alias.as_bytes()); // N.B.: Don't alias the inner select to `inner`, which is, well, a SQL keyword! let inner_select = SelectStatement { @@ -1300,7 +1342,10 @@ mod tests { fields: columns(&["o_id", "ol_i_id"]), join: vec![JoinClause { operator: JoinOperator::Join, - right: JoinRightSide::NestedSelect(Box::new(inner_select), Some("ids".into())), + right: JoinRightSide::NestedSelect( + Box::new(inner_select.into()), + Some("ids".into()), + ), constraint: JoinConstraint::On(ComparisonOp(ConditionTree { operator: Operator::Equal, left: Box::new(Base(Field(Column::from("orders.o_id")))), @@ -1340,7 +1385,7 @@ mod tests { ..Default::default() }; - assert_eq!(res.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.into()); } #[test] @@ -1370,7 +1415,7 @@ mod tests { ..Default::default() }; - assert_eq!(res.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.into()); } #[test] @@ -1407,6 +1452,6 @@ mod tests { ..Default::default() }; - assert_eq!(res.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.into()); } } diff --git a/tests/exists-queries.txt b/tests/exists-queries.txt index 2180af6..8513bd3 100644 --- a/tests/exists-queries.txt +++ b/tests/exists-queries.txt @@ -3,4 +3,5 @@ SELECT * FROM employees e WHERE exists(SELECT id FROM eotm_dyn d WHERE d.employe SELECT * FROM employees e WHERE not exists ( SELECT id FROM eotm_dyn d WHERE d.employeeID = e.id) SELECT * FROM employees e WHERE not (exists ( SELECT id FROM eotm_dyn d WHERE d.employeeID = e.id)) SELECT * FROM employees e WHERE x > 3 and not exists (SELECT id FROM eotm_dyn d WHERE d.employeeID = e.id ) and y < 3 +SELECT * FROM employees e WHERE x > 3 and not exists (SELECT id FROM eotm_dyn d WHERE d.employeeID = e.id UNION SELECT id FROM eotm_dyn d WHERE d.employeeID IS NULL ) and y < 3 diff --git a/tests/lib.rs b/tests/lib.rs index e0d7263..f22e4b4 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -131,7 +131,7 @@ fn exists_test_queries() { ); assert!(res.is_ok()); // There are 4 queries - assert_eq!(res.unwrap(), 4); + assert_eq!(res.unwrap(), 5); } #[test] @@ -201,6 +201,14 @@ fn parse_comments() { assert_eq!(fail, 0); } +#[test] +fn parse_nested_compound_selects() { + let (ok, fail) = parse_file("tests/nested-compound-selects.txt"); + + assert_eq!(ok, 4); + assert_eq!(fail, 0); +} + #[test] fn parse_autoincrement() { let (ok, fail) = parse_file("tests/autoincrement.txt"); diff --git a/tests/nested-compound-selects.txt b/tests/nested-compound-selects.txt new file mode 100644 index 0000000..888f2b6 --- /dev/null +++ b/tests/nested-compound-selects.txt @@ -0,0 +1,5 @@ +SELECT a, b FROM table1 JOIN (SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table 3 WHERE a = b); +SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b UNION SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b; +(SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b UNION SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b) UNION ALL (SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b UNION SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b); +SELECT a, b FROM table1 WHERE a IN (SELECT c FROM table2 WHERE c = d UNION SELECT b FROM table3 WHERE a > b); + From 42f11ce4f5b1713833bcec668632b14d76a86e4f Mon Sep 17 00:00:00 2001 From: Kenny Flegal Date: Wed, 17 Aug 2022 16:12:40 -0700 Subject: [PATCH 2/5] parse ROW() methods for INSERT VALUES --- src/insert.rs | 217 ++++++++++++++++++++++++++++++++++++-------------- src/parser.rs | 3 +- 2 files changed, 161 insertions(+), 59 deletions(-) diff --git a/src/insert.rs b/src/insert.rs index 5c81427..28688e6 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -1,5 +1,6 @@ use nom::character::complete::{multispace0, multispace1}; use std::fmt; +use std::fmt::{Display, Formatter}; use std::str; use column::Column; @@ -8,18 +9,19 @@ use common::{ ws_sep_comma, FieldValueExpression, Literal, }; use keywords::escape_if_keyword; +use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; -use nom::combinator::opt; +use nom::combinator::{map, opt}; use nom::multi::many1; use nom::sequence::{delimited, preceded, tuple}; -use nom::IResult; +use nom::{IResult}; use table::Table; #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct InsertStatement { pub table: Table, pub fields: Option>, - pub data: Vec>, + pub data: InsertData, pub ignore: bool, pub on_duplicate: Option>, } @@ -38,25 +40,51 @@ impl fmt::Display for InsertStatement { .join(", ") )?; } - write!( - f, - " VALUES {}", - self.data - .iter() - .map(|datas| format!( - "({})", - datas - .into_iter() - .map(|l| l.to_string()) - .collect::>() - .join(", ") - )) - .collect::>() - .join(", ") - ) + write!(f, " VALUES {}", self.data,) } } +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub enum InsertData { + RowValueList(Vec>), + ValueList(Vec>), +} + +impl Default for InsertData { + fn default() -> Self { + Self::ValueList(vec![vec![]]) + } +} + +impl Display for InsertData { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let fmt_values = |start, vvl: &Vec>| { + vvl.iter().fold(String::new(), |mut acc, vl| { + if acc.len() > 0 { + acc.push_str(", "); + } + acc.push_str(&format!("{}(", start)); + acc.push_str( + &vl.iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + ); + acc.push_str(")"); + acc + }) + }; + + match self { + Self::ValueList(vl) => { + write!(f, "{}", fmt_values("", vl)) + } + Self::RowValueList(vl) => { + write!(f, "{}", fmt_values("ROW", vl)) + } + } + } +} fn fields(i: &[u8]) -> IResult<&[u8], Vec> { delimited( preceded(tag("("), multispace0), @@ -65,8 +93,32 @@ fn fields(i: &[u8]) -> IResult<&[u8], Vec> { )(i) } -fn data(i: &[u8]) -> IResult<&[u8], Vec> { - delimited(tag("("), value_list, preceded(tag(")"), opt(ws_sep_comma)))(i) +fn data(i: &[u8]) -> IResult<&[u8], InsertData> { + alt(( + // The ROW() method can't be mixed and matched so we match one or the other sets + map( + tuple(( + multispace1, + many1(delimited( + tag("ROW("), + value_list, + preceded(tag(")"), opt(ws_sep_comma)), + )), + )), + |(_, vl)| InsertData::RowValueList(vl), + ), + map( + tuple(( + multispace0, + many1(delimited( + tag("("), + value_list, + preceded(tag(")"), opt(ws_sep_comma)), + )), + )), + |(_, vl)| InsertData::ValueList(vl), + ), + ))(i) } fn on_duplicate(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldValueExpression)>> { @@ -82,7 +134,7 @@ fn on_duplicate(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldValueExpression)>> // Parse rule for a SQL insert query. // TODO(malte): support REPLACE, nested selection, DEFAULT VALUES pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> { - let (remaining_input, (_, ignore_res, _, _, _, table, _, fields, _, _, data, on_duplicate, _)) = + let (remaining_input, (_, ignore_res, _, _, _, table, _, fields, _, data, on_duplicate, _)) = tuple(( tag_no_case("insert"), opt(preceded(multispace1, tag_no_case("ignore"))), @@ -93,8 +145,7 @@ pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> { multispace0, opt(fields), tag_no_case("values"), - multispace0, - many1(data), + data, opt(on_duplicate), statement_terminator, ))(i)?; @@ -123,15 +174,26 @@ mod tests { #[test] fn simple_insert() { - let qstring = "INSERT INTO users VALUES (42, \"test\");"; + let qstring0 = "INSERT INTO users VALUES (42, \"test\");"; + let qstring1 = "INSERT INTO users VALUES ROW(42, \"test\");"; - let res = insertion(qstring.as_bytes()); + let res0 = insertion(qstring0.as_bytes()); + let res1 = insertion(qstring1.as_bytes()); assert_eq!( - res.unwrap().1, + res0.unwrap().1, InsertStatement { table: Table::from("users"), fields: None, - data: vec![vec![42.into(), "test".into()]], + data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), + ..Default::default() + } + ); + assert_eq!( + res1.unwrap().1, + InsertStatement { + table: Table::from("users"), + fields: None, + data: InsertData::RowValueList(vec![vec![42.into(), "test".into()]]), ..Default::default() } ); @@ -147,7 +209,7 @@ mod tests { InsertStatement { table: Table::from(("db1", "users")), fields: None, - data: vec![vec![42.into(), "test".into()]], + data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), ..Default::default() } ); @@ -163,12 +225,12 @@ mod tests { InsertStatement { table: Table::from("users"), fields: None, - data: vec![vec![ + data: InsertData::ValueList(vec![vec![ 42.into(), "test".into(), "test".into(), Literal::CurrentTimestamp, - ],], + ],]), ..Default::default() } ); @@ -184,7 +246,7 @@ mod tests { InsertStatement { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), - data: vec![vec![42.into(), "test".into()]], + data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), ..Default::default() } ); @@ -193,37 +255,65 @@ mod tests { // Issue #3 #[test] fn insert_without_spaces() { - let qstring = "INSERT INTO users(id, name) VALUES(42, \"test\");"; + let qstring0 = "INSERT INTO users(id, name) VALUES(42, \"test\");"; + let qstring1 = "INSERT INTO users(id, name) VALUESROW(42, \"test\");"; - let res = insertion(qstring.as_bytes()); + let res0 = insertion(qstring0.as_bytes()); + let res1 = insertion(qstring1.as_bytes()); assert_eq!( - res.unwrap().1, + res0.unwrap().1, InsertStatement { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), - data: vec![vec![42.into(), "test".into()]], + data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), ..Default::default() } ); + assert!(res1.is_err()); } #[test] fn multi_insert() { - let qstring = "INSERT INTO users (id, name) VALUES (42, \"test\"),(21, \"test2\");"; + let qstring0 = "INSERT INTO users (id, name) VALUES (42, \"test\"),(21, \"test2\");"; + let qstring1 = "INSERT INTO users (id, name) VALUES (42, \"test\"), (21, \"test2\");"; + let qstring2 = "INSERT INTO users (id, name) VALUES ROW(42, \"test\"),ROW(21, \"test2\");"; + let qstring3 = "INSERT INTO users (id, name) VALUES ROW(42, \"test\"), ROW(21, \"test2\");"; + let qstring4 = "INSERT INTO users (id, name) VALUES ROW(42, \"test\"),(21, \"test2\");"; + let qstring5 = "INSERT INTO users (id, name) VALUES (42, \"test\"),ROW(21, \"test2\");"; - let res = insertion(qstring.as_bytes()); - assert_eq!( - res.unwrap().1, - InsertStatement { - table: Table::from("users"), - fields: Some(vec![Column::from("id"), Column::from("name")]), - data: vec![ - vec![42.into(), "test".into()], - vec![21.into(), "test2".into()], - ], - ..Default::default() - } - ); + let res0 = insertion(qstring0.as_bytes()); + let res1 = insertion(qstring1.as_bytes()); + let res2 = insertion(qstring2.as_bytes()); + let res3 = insertion(qstring3.as_bytes()); + let res4 = insertion(qstring4.as_bytes()); + let res5 = insertion(qstring5.as_bytes()); + + let expected0 = InsertStatement { + table: Table::from("users"), + fields: Some(vec![Column::from("id"), Column::from("name")]), + data: InsertData::ValueList(vec![ + vec![42.into(), "test".into()], + vec![21.into(), "test2".into()], + ]), + ..Default::default() + }; + let expected1 = expected0.clone(); + let expected2 = InsertStatement { + table: Table::from("users"), + fields: Some(vec![Column::from("id"), Column::from("name")]), + data: InsertData::RowValueList(vec![ + vec![42.into(), "test".into()], + vec![21.into(), "test2".into()], + ]), + ..Default::default() + }; + let expected3 = expected2.clone(); + assert_eq!(res0.unwrap().1, expected0,); + assert_eq!(res1.unwrap().1, expected1,); + assert_eq!(res2.unwrap().1, expected2,); + assert_eq!(res3.unwrap().1, expected3,); + assert!(res4.is_err()); + assert!(res5.is_err()); } #[test] @@ -236,10 +326,10 @@ mod tests { InsertStatement { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), - data: vec![vec![ + data: InsertData::ValueList(vec![vec![ Literal::Placeholder(ItemPlaceholder::QuestionMark), Literal::Placeholder(ItemPlaceholder::QuestionMark) - ]], + ]]), ..Default::default() } ); @@ -262,10 +352,10 @@ mod tests { InsertStatement { table: Table::from("keystores"), fields: Some(vec![Column::from("key"), Column::from("value")]), - data: vec![vec![ + data: InsertData::ValueList(vec![vec![ Literal::Placeholder(ItemPlaceholder::DollarNumber(1)), Literal::Placeholder(ItemPlaceholder::ColonNumber(2)) - ]], + ]]), on_duplicate: Some(vec![( Column::from("value"), FieldValueExpression::Arithmetic(expected_ae), @@ -277,15 +367,26 @@ mod tests { #[test] fn insert_with_leading_value_whitespace() { - let qstring = "INSERT INTO users (id, name) VALUES ( 42, \"test\");"; + let qstring0 = "INSERT INTO users (id, name) VALUES ( 42, \"test\");"; + let qstring1 = "INSERT INTO users (id, name) VALUES ROW( 42, \"test\");"; - let res = insertion(qstring.as_bytes()); + let res0 = insertion(qstring0.as_bytes()); + let res1 = insertion(qstring1.as_bytes()); assert_eq!( - res.unwrap().1, + res0.unwrap().1, + InsertStatement { + table: Table::from("users"), + fields: Some(vec![Column::from("id"), Column::from("name")]), + data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), + ..Default::default() + } + ); + assert_eq!( + res1.unwrap().1, InsertStatement { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), - data: vec![vec![42.into(), "test".into()]], + data: InsertData::RowValueList(vec![vec![42.into(), "test".into()]]), ..Default::default() } ); diff --git a/src/parser.rs b/src/parser.rs index a0dfc42..c24429a 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -84,6 +84,7 @@ where #[cfg(test)] mod tests { use super::*; + use insert::InsertData; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; @@ -98,7 +99,7 @@ mod tests { let expected = SqlQuery::Insert(InsertStatement { table: Table::from("users"), fields: None, - data: vec![vec![42.into(), "test".into()]], + data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), ..Default::default() }); let mut h0 = DefaultHasher::new(); From 7ad41ecd72ff4d8686fe38db6d7d697328b83f34 Mon Sep 17 00:00:00 2001 From: Kenny Flegal Date: Thu, 18 Aug 2022 12:14:25 -0700 Subject: [PATCH 3/5] Parse INSERT...SELECT style statements --- src/insert.rs | 78 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/src/insert.rs b/src/insert.rs index 28688e6..c235337 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -15,6 +15,7 @@ use nom::combinator::{map, opt}; use nom::multi::many1; use nom::sequence::{delimited, preceded, tuple}; use nom::{IResult}; +use select::{nested_selection, Selection}; use table::Table; #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] @@ -48,6 +49,7 @@ impl fmt::Display for InsertStatement { pub enum InsertData { RowValueList(Vec>), ValueList(Vec>), + Select(Selection), } impl Default for InsertData { @@ -82,6 +84,9 @@ impl Display for InsertData { Self::RowValueList(vl) => { write!(f, "{}", fmt_values("ROW", vl)) } + Self::Select(s) => { + write!(f, "{}", s) + } } } } @@ -134,7 +139,7 @@ fn on_duplicate(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldValueExpression)>> // Parse rule for a SQL insert query. // TODO(malte): support REPLACE, nested selection, DEFAULT VALUES pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> { - let (remaining_input, (_, ignore_res, _, _, _, table, _, fields, _, data, on_duplicate, _)) = + let (remaining_input, (_, ignore_res, _, _, _, table, _, fields, data, on_duplicate, _)) = tuple(( tag_no_case("insert"), opt(preceded(multispace1, tag_no_case("ignore"))), @@ -144,8 +149,7 @@ pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> { schema_table_reference, multispace0, opt(fields), - tag_no_case("values"), - data, + insertion_values, opt(on_duplicate), statement_terminator, ))(i)?; @@ -164,6 +168,13 @@ pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> { )) } +pub fn insertion_values(i: &[u8]) -> IResult<&[u8], InsertData> { + alt(( + map(nested_selection, |ns| InsertData::Select(ns)), + map(tuple((tag_no_case("values"), data)), |(_, d)| d), + ))(i) +} + #[cfg(test)] mod tests { use super::*; @@ -171,6 +182,8 @@ mod tests { use column::Column; use common::ItemPlaceholder; use table::Table; + use FieldDefinitionExpression::Col; + use {LiteralExpression, SelectStatement}; #[test] fn simple_insert() { @@ -391,4 +404,63 @@ mod tests { } ); } + + #[test] + fn insert_select() { + let qstring0 = "INSERT INTO users (id, name) SELECT id, name FROM dual;"; + let qstring1 = "INSERT INTO users (id, name) SELECT id, name FROM dual ON DUPLICATE KEY UPDATE name = 'dupe';"; + + let res0 = insertion(qstring0.as_bytes()); + let res1 = insertion(qstring1.as_bytes()); + + let expected0 = InsertStatement { + table: Table::from("users"), + fields: Some(vec![Column::from("id"), Column::from("name")]), + data: InsertData::Select(Selection::Statement(SelectStatement { + tables: vec![Table { + name: "dual".to_string(), + alias: None, + schema: None, + }], + distinct: false, + fields: vec![ + Col(Column { + name: "id".to_string(), + alias: None, + table: None, + function: None, + }), + Col(Column { + name: "name".to_string(), + alias: None, + table: None, + function: None, + }), + ], + join: vec![], + where_clause: None, + group_by: None, + order: None, + limit: None, + })), + ..Default::default() + }; + + let mut expected1 = expected0.clone(); + expected1.on_duplicate = Some(vec![( + Column { + name: "name".to_string(), + alias: None, + table: None, + function: None, + }, + FieldValueExpression::Literal(LiteralExpression { + value: Literal::String("dupe".to_string()), + alias: None, + }), + )]); + + assert_eq!(res0.unwrap().1, expected0); + assert_eq!(res1.unwrap().1, expected1); + } } From cc00bb6af796460d8a0fdc34b1b82f0fd599afe0 Mon Sep 17 00:00:00 2001 From: Kenny Flegal Date: Fri, 19 Aug 2022 12:52:09 -0700 Subject: [PATCH 4/5] Support column names in ON DUPLICATE of INSERT --- src/common.rs | 56 +++++++++++++++++++++++++++++++++++++++++---------- src/insert.rs | 44 ++++++++++++++++++++++++++++------------ src/update.rs | 44 +++++++++++++++++++++++++++++----------- 3 files changed, 108 insertions(+), 36 deletions(-) diff --git a/src/common.rs b/src/common.rs index cdc52f6..1f41000 100644 --- a/src/common.rs +++ b/src/common.rs @@ -2,7 +2,7 @@ use nom::branch::alt; use nom::character::complete::{alphanumeric1, digit1, line_ending, multispace0, multispace1}; use nom::character::is_alphanumeric; use nom::combinator::{map, not, peek}; -use nom::{IResult, InputLength, Parser}; +use nom::{IResult, InputLength, Parser, Err}; use std::fmt::{self, Display}; use std::str; use std::str::FromStr; @@ -13,7 +13,7 @@ use column::{Column, FunctionArgument, FunctionArguments, FunctionExpression}; use keywords::{escape_if_keyword, sql_keyword}; use nom::bytes::complete::{is_not, tag, tag_no_case, take, take_until, take_while1}; use nom::combinator::opt; -use nom::error::{ErrorKind, ParseError}; +use nom::error::{Error, ErrorKind, ParseError}; use nom::multi::{fold_many0, many0, many1, separated_list0}; use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple}; use table::Table; @@ -354,6 +354,33 @@ impl Display for FieldValueExpression { } } +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub enum FieldAssignmentValue { + Col(Column), + Expression(FieldValueExpression), +} + +impl Display for FieldAssignmentValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Self::Col(ref col) => write!(f, "{}", col), + Self::Expression(ref expr) => write!(f, "{}", expr), + } + } +} + +impl From for FieldAssignmentValue { + fn from(c: Column) -> Self { + Self::Col(c) + } +} + +impl From for FieldAssignmentValue { + fn from(fve: FieldValueExpression) -> Self { + Self::Expression(fve) + } +} + #[inline] pub fn is_sql_identifier(chr: u8) -> bool { is_alphanumeric(chr) || chr == '_' as u8 || chr == '@' as u8 @@ -778,11 +805,17 @@ pub fn column_identifier(i: &[u8]) -> IResult<&[u8], Column> { // Parses a SQL identifier (alphanumeric1 and "_"). pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> { - alt(( + let (i, si) = alt(( preceded(not(peek(sql_keyword)), take_while1(is_sql_identifier)), delimited(tag("`"), take_while1(is_sql_identifier), tag("`")), delimited(tag("["), take_while1(is_sql_identifier), tag("]")), - ))(i) + ))(i)?; + + if str::from_utf8(si).unwrap_or("0").parse::().is_ok() { + return Err(Err::Error(Error { input: i, code: ErrorKind::IsA })); + } + + Ok((i, si)) } // Parse an unsigned integer. @@ -836,21 +869,22 @@ pub fn as_alias(i: &[u8]) -> IResult<&[u8], &str> { )(i) } -fn field_value_expr(i: &[u8]) -> IResult<&[u8], FieldValueExpression> { +fn field_value_expr(i: &[u8]) -> IResult<&[u8], FieldAssignmentValue> { alt(( + map(arithmetic_expression, |ae| { + FieldValueExpression::Arithmetic(ae).into() + }), + map(column_identifier, |c| c.into()), map(literal, |l| { FieldValueExpression::Literal(LiteralExpression { value: l.into(), alias: None, - }) - }), - map(arithmetic_expression, |ae| { - FieldValueExpression::Arithmetic(ae) + }).into() }), ))(i) } -fn assignment_expr(i: &[u8]) -> IResult<&[u8], (Column, FieldValueExpression)> { +fn assignment_expr(i: &[u8]) -> IResult<&[u8], (Column, FieldAssignmentValue)> { separated_pair( column_identifier_no_alias, delimited(multispace0, tag("="), multispace0), @@ -872,7 +906,7 @@ where delimited(multispace0, tag("="), multispace0)(i) } -pub fn assignment_expr_list(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldValueExpression)>> { +pub fn assignment_expr_list(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldAssignmentValue)>> { many1(terminated(assignment_expr, opt(ws_sep_comma)))(i) } diff --git a/src/insert.rs b/src/insert.rs index c235337..f5ad9ae 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -4,10 +4,7 @@ use std::fmt::{Display, Formatter}; use std::str; use column::Column; -use common::{ - assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list, - ws_sep_comma, FieldValueExpression, Literal, -}; +use common::{assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list, ws_sep_comma, Literal, FieldAssignmentValue}; use keywords::escape_if_keyword; use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; @@ -24,7 +21,7 @@ pub struct InsertStatement { pub fields: Option>, pub data: InsertData, pub ignore: bool, - pub on_duplicate: Option>, + pub on_duplicate: Option>, } impl fmt::Display for InsertStatement { @@ -90,6 +87,7 @@ impl Display for InsertData { } } } + fn fields(i: &[u8]) -> IResult<&[u8], Vec> { delimited( preceded(tag("("), multispace0), @@ -126,7 +124,7 @@ fn data(i: &[u8]) -> IResult<&[u8], InsertData> { ))(i) } -fn on_duplicate(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldValueExpression)>> { +fn on_duplicate(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldAssignmentValue)>> { preceded( multispace0, preceded( @@ -184,6 +182,7 @@ mod tests { use table::Table; use FieldDefinitionExpression::Col; use {LiteralExpression, SelectStatement}; + use FieldValueExpression; #[test] fn simple_insert() { @@ -350,18 +349,21 @@ mod tests { #[test] fn insert_with_on_dup_update() { - let qstring = "INSERT INTO keystores (`key`, `value`) VALUES ($1, :2) \ - ON DUPLICATE KEY UPDATE `value` = `value` + 1"; + let qstring0 = "insert into keystores (`key`, `value`) values ($1, :2) \ + on duplicate key update `value` = `value` + 1"; + let qstring1 = "insert into keystores (`key`, `value`) values ($1, :2) \ + on duplicate key update value = value"; - let res = insertion(qstring.as_bytes()); - let expected_ae = ArithmeticExpression::new( + let res0 = insertion(qstring0.as_bytes()); + let res1 = insertion(qstring1.as_bytes()); + let expected_ae0 = ArithmeticExpression::new( ArithmeticOperator::Add, ArithmeticBase::Column(Column::from("value")), ArithmeticBase::Scalar(1.into()), None, ); assert_eq!( - res.unwrap().1, + res0.unwrap().1, InsertStatement { table: Table::from("keystores"), fields: Some(vec![Column::from("key"), Column::from("value")]), @@ -371,11 +373,27 @@ mod tests { ]]), on_duplicate: Some(vec![( Column::from("value"), - FieldValueExpression::Arithmetic(expected_ae), + FieldValueExpression::Arithmetic(expected_ae0).into(), ),]), ..Default::default() } ); + assert_eq!( + res1.unwrap().1, + InsertStatement { + table: Table::from("keystores"), + fields: Some(vec![Column::from("key"), Column::from("value")]), + data: InsertData::ValueList(vec![vec![ + Literal::Placeholder(ItemPlaceholder::DollarNumber(1)), + Literal::Placeholder(ItemPlaceholder::ColonNumber(2)) + ]]), + on_duplicate: Some(vec![( + Column::from("value"), + Column::from("value").into(), + ),]), + ..Default::default() + } + ); } #[test] @@ -457,7 +475,7 @@ mod tests { FieldValueExpression::Literal(LiteralExpression { value: Literal::String("dupe".to_string()), alias: None, - }), + }).into(), )]); assert_eq!(res0.unwrap().1, expected0); diff --git a/src/update.rs b/src/update.rs index 7366676..101bc0a 100644 --- a/src/update.rs +++ b/src/update.rs @@ -2,7 +2,7 @@ use nom::character::complete::{multispace0, multispace1}; use std::{fmt, str}; use column::Column; -use common::{assignment_expr_list, statement_terminator, table_reference, FieldValueExpression}; +use common::{assignment_expr_list, statement_terminator, table_reference, FieldAssignmentValue}; use condition::ConditionExpression; use keywords::escape_if_keyword; use nom::bytes::complete::tag_no_case; @@ -15,7 +15,7 @@ use table::Table; #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct UpdateStatement { pub table: Table, - pub fields: Vec<(Column, FieldValueExpression)>, + pub fields: Vec<(Column, FieldAssignmentValue)>, pub where_clause: Option, } @@ -72,27 +72,47 @@ mod tests { use condition::ConditionBase::*; use condition::ConditionExpression::*; use condition::ConditionTree; + use FieldValueExpression; use table::Table; #[test] fn simple_update() { - let qstring = "UPDATE users SET id = 42, name = 'test'"; + let qstring0 = "UPDATE users SET id = 42, name = 'test'"; + let qstring1 = "UPDATE users SET id = new_id, name = old_name"; - let res = updating(qstring.as_bytes()); + let res0 = updating(qstring0.as_bytes()); + let res1 = updating(qstring1.as_bytes()); assert_eq!( - res.unwrap().1, + res0.unwrap().1, UpdateStatement { table: Table::from("users"), fields: vec![ ( Column::from("id"), - FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))), + FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))).into(), ), ( Column::from("name"), FieldValueExpression::Literal(LiteralExpression::from(Literal::from( "test", - ))), + ))).into(), + ), + ], + ..Default::default() + } + ); + assert_eq!( + res1.unwrap().1, + UpdateStatement { + table: Table::from("users"), + fields: vec![ + ( + Column::from("id"), + Column::from("new_id").into(), + ), + ( + Column::from("name"), + Column::from("old_name").into(), ), ], ..Default::default() @@ -118,13 +138,13 @@ mod tests { fields: vec![ ( Column::from("id"), - FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))), + FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))).into(), ), ( Column::from("name"), FieldValueExpression::Literal(LiteralExpression::from(Literal::from( "test", - ))), + ))).into(), ), ], where_clause: expected_where_cond, @@ -165,7 +185,7 @@ mod tests { integral: -19216, fractional: 5479744, } - ),)), + ),)).into(), ),], where_clause: expected_where_cond, ..Default::default() @@ -197,7 +217,7 @@ mod tests { table: Table::from("users"), fields: vec![( Column::from("karma"), - FieldValueExpression::Arithmetic(expected_ae), + FieldValueExpression::Arithmetic(expected_ae).into(), ),], where_clause: expected_where_cond, ..Default::default() @@ -222,7 +242,7 @@ mod tests { table: Table::from("users"), fields: vec![( Column::from("karma"), - FieldValueExpression::Arithmetic(expected_ae), + FieldValueExpression::Arithmetic(expected_ae).into(), ),], ..Default::default() } From a052997c298d5dabbc17acbf21b7ee4ee8b10e34 Mon Sep 17 00:00:00 2001 From: Kenny Flegal Date: Fri, 19 Aug 2022 14:58:31 -0700 Subject: [PATCH 5/5] support for DEFAULT/DEFAULT(col) for INSERTs --- src/common.rs | 33 ++++++++- src/insert.rs | 193 +++++++++++++++++++++++++++++++++++++++++--------- src/parser.rs | 7 +- src/update.rs | 27 ++++--- 4 files changed, 207 insertions(+), 53 deletions(-) diff --git a/src/common.rs b/src/common.rs index 1f41000..6a499ae 100644 --- a/src/common.rs +++ b/src/common.rs @@ -2,7 +2,7 @@ use nom::branch::alt; use nom::character::complete::{alphanumeric1, digit1, line_ending, multispace0, multispace1}; use nom::character::is_alphanumeric; use nom::combinator::{map, not, peek}; -use nom::{IResult, InputLength, Parser, Err}; +use nom::{Err, IResult, InputLength, Parser}; use std::fmt::{self, Display}; use std::str; use std::str::FromStr; @@ -10,6 +10,7 @@ use std::str::FromStr; use arithmetic::{arithmetic_expression, ArithmeticExpression}; use case::case_when_column; use column::{Column, FunctionArgument, FunctionArguments, FunctionExpression}; +use insert::InsertDataValue; use keywords::{escape_if_keyword, sql_keyword}; use nom::bytes::complete::{is_not, tag, tag_no_case, take, take_until, take_while1}; use nom::combinator::opt; @@ -812,7 +813,10 @@ pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> { ))(i)?; if str::from_utf8(si).unwrap_or("0").parse::().is_ok() { - return Err(Err::Error(Error { input: i, code: ErrorKind::IsA })); + return Err(Err::Error(Error { + input: i, + code: ErrorKind::IsA, + })); } Ok((i, si)) @@ -879,7 +883,8 @@ fn field_value_expr(i: &[u8]) -> IResult<&[u8], FieldAssignmentValue> { FieldValueExpression::Literal(LiteralExpression { value: l.into(), alias: None, - }).into() + }) + .into() }), ))(i) } @@ -1066,6 +1071,28 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec> { many0(delimited(multispace0, literal, opt(ws_sep_comma)))(i) } +pub fn insert_data_value_list(i: &[u8]) -> IResult<&[u8], Vec> { + many0(delimited(multispace0, insert_data_value, opt(ws_sep_comma)))(i) +} + +pub fn insert_data_value(i: &[u8]) -> IResult<&[u8], InsertDataValue> { + alt(( + map( + tuple(( + tag_no_case("DEFAULT"), + tag("("), + multispace0, + column_identifier_no_alias, + multispace0, + tag(")"), + )), + |(_, _, _, c, _, _)| InsertDataValue::ColumnDefault(c), + ), + map(tag_no_case("DEFAULT"), |_| InsertDataValue::Default), + map(literal, |l| InsertDataValue::Literal(l)), + ))(i) +} + // Parse a reference to a named schema.table, with an optional alias pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> { map( diff --git a/src/insert.rs b/src/insert.rs index f5ad9ae..8c13132 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -4,14 +4,17 @@ use std::fmt::{Display, Formatter}; use std::str; use column::Column; -use common::{assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list, ws_sep_comma, Literal, FieldAssignmentValue}; +use common::{ + assignment_expr_list, field_list, insert_data_value_list, schema_table_reference, + statement_terminator, ws_sep_comma, FieldAssignmentValue, Literal, +}; use keywords::escape_if_keyword; use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; use nom::multi::many1; use nom::sequence::{delimited, preceded, tuple}; -use nom::{IResult}; +use nom::IResult; use select::{nested_selection, Selection}; use table::Table; @@ -44,8 +47,8 @@ impl fmt::Display for InsertStatement { #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum InsertData { - RowValueList(Vec>), - ValueList(Vec>), + RowValueList(Vec>), + ValueList(Vec>), Select(Selection), } @@ -57,7 +60,7 @@ impl Default for InsertData { impl Display for InsertData { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let fmt_values = |start, vvl: &Vec>| { + let fmt_values = |start, vvl: &Vec>| { vvl.iter().fold(String::new(), |mut acc, vl| { if acc.len() > 0 { acc.push_str(", "); @@ -88,6 +91,37 @@ impl Display for InsertData { } } +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub enum InsertDataValue { + Literal(Literal), + Default, + ColumnDefault(Column), +} + +impl Default for InsertDataValue { + fn default() -> Self { + Self::Default + } +} + +impl Display for InsertDataValue { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Literal(l) => { + write!(f, "{}", l.to_string()) + } + Self::Default => write!(f, "DEFAULT"), + Self::ColumnDefault(c) => write!(f, "DEFAULT({})", c), + } + } +} + +impl From for InsertDataValue { + fn from(l: Literal) -> Self { + Self::Literal(l) + } +} + fn fields(i: &[u8]) -> IResult<&[u8], Vec> { delimited( preceded(tag("("), multispace0), @@ -104,7 +138,7 @@ fn data(i: &[u8]) -> IResult<&[u8], InsertData> { multispace1, many1(delimited( tag("ROW("), - value_list, + insert_data_value_list, preceded(tag(")"), opt(ws_sep_comma)), )), )), @@ -115,7 +149,7 @@ fn data(i: &[u8]) -> IResult<&[u8], InsertData> { multispace0, many1(delimited( tag("("), - value_list, + insert_data_value_list, preceded(tag(")"), opt(ws_sep_comma)), )), )), @@ -181,22 +215,29 @@ mod tests { use common::ItemPlaceholder; use table::Table; use FieldDefinitionExpression::Col; - use {LiteralExpression, SelectStatement}; use FieldValueExpression; + use {LiteralExpression, SelectStatement}; #[test] fn simple_insert() { let qstring0 = "INSERT INTO users VALUES (42, \"test\");"; let qstring1 = "INSERT INTO users VALUES ROW(42, \"test\");"; + let qstring2 = "INSERT INTO users VALUES ROW(DEFAULT, \"test\");"; + let qstring3 = "INSERT INTO users VALUES ROW(DEFAULT(id), \"test\");"; let res0 = insertion(qstring0.as_bytes()); let res1 = insertion(qstring1.as_bytes()); + let res2 = insertion(qstring2.as_bytes()); + let res3 = insertion(qstring3.as_bytes()); assert_eq!( res0.unwrap().1, InsertStatement { table: Table::from("users"), fields: None, - data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), + data: InsertData::ValueList(vec![vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()) + ]]), ..Default::default() } ); @@ -205,7 +246,34 @@ mod tests { InsertStatement { table: Table::from("users"), fields: None, - data: InsertData::RowValueList(vec![vec![42.into(), "test".into()]]), + data: InsertData::RowValueList(vec![vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()) + ]]), + ..Default::default() + } + ); + assert_eq!( + res2.unwrap().1, + InsertStatement { + table: Table::from("users"), + fields: None, + data: InsertData::RowValueList(vec![vec![ + InsertDataValue::Default, + InsertDataValue::Literal("test".into()) + ]]), + ..Default::default() + } + ); + assert_eq!( + res3.unwrap().1, + InsertStatement { + table: Table::from("users"), + fields: None, + data: InsertData::RowValueList(vec![vec![ + InsertDataValue::ColumnDefault("id".into()), + InsertDataValue::Literal("test".into()) + ]]), ..Default::default() } ); @@ -221,7 +289,10 @@ mod tests { InsertStatement { table: Table::from(("db1", "users")), fields: None, - data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), + data: InsertData::ValueList(vec![vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()) + ]]), ..Default::default() } ); @@ -238,10 +309,10 @@ mod tests { table: Table::from("users"), fields: None, data: InsertData::ValueList(vec![vec![ - 42.into(), - "test".into(), - "test".into(), - Literal::CurrentTimestamp, + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()), + InsertDataValue::Literal("test".into()), + InsertDataValue::Literal(Literal::CurrentTimestamp), ],]), ..Default::default() } @@ -258,7 +329,10 @@ mod tests { InsertStatement { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), - data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), + data: InsertData::ValueList(vec![vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()) + ]]), ..Default::default() } ); @@ -269,19 +343,36 @@ mod tests { fn insert_without_spaces() { let qstring0 = "INSERT INTO users(id, name) VALUES(42, \"test\");"; let qstring1 = "INSERT INTO users(id, name) VALUESROW(42, \"test\");"; + let qstring2 = "INSERT INTO users(id, name) VALUES(DEFAULT(id), \"test\");"; let res0 = insertion(qstring0.as_bytes()); let res1 = insertion(qstring1.as_bytes()); + let res2 = insertion(qstring2.as_bytes()); assert_eq!( res0.unwrap().1, InsertStatement { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), - data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), + data: InsertData::ValueList(vec![vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()) + ]]), ..Default::default() } ); assert!(res1.is_err()); + assert_eq!( + res2.unwrap().1, + InsertStatement { + table: Table::from("users"), + fields: Some(vec![Column::from("id"), Column::from("name")]), + data: InsertData::ValueList(vec![vec![ + InsertDataValue::ColumnDefault("id".into()), + InsertDataValue::Literal("test".into()) + ]]), + ..Default::default() + } + ); } #[test] @@ -304,8 +395,14 @@ mod tests { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), data: InsertData::ValueList(vec![ - vec![42.into(), "test".into()], - vec![21.into(), "test2".into()], + vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()), + ], + vec![ + InsertDataValue::Literal(21.into()), + InsertDataValue::Literal("test2".into()), + ], ]), ..Default::default() }; @@ -314,8 +411,14 @@ mod tests { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), data: InsertData::RowValueList(vec![ - vec![42.into(), "test".into()], - vec![21.into(), "test2".into()], + vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()), + ], + vec![ + InsertDataValue::Literal(21.into()), + InsertDataValue::Literal("test2".into()), + ], ]), ..Default::default() }; @@ -339,8 +442,8 @@ mod tests { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), data: InsertData::ValueList(vec![vec![ - Literal::Placeholder(ItemPlaceholder::QuestionMark), - Literal::Placeholder(ItemPlaceholder::QuestionMark) + InsertDataValue::Literal(Literal::Placeholder(ItemPlaceholder::QuestionMark)), + InsertDataValue::Literal(Literal::Placeholder(ItemPlaceholder::QuestionMark)) ]]), ..Default::default() } @@ -368,8 +471,10 @@ mod tests { table: Table::from("keystores"), fields: Some(vec![Column::from("key"), Column::from("value")]), data: InsertData::ValueList(vec![vec![ - Literal::Placeholder(ItemPlaceholder::DollarNumber(1)), - Literal::Placeholder(ItemPlaceholder::ColonNumber(2)) + InsertDataValue::Literal(Literal::Placeholder(ItemPlaceholder::DollarNumber( + 1 + ))), + InsertDataValue::Literal(Literal::Placeholder(ItemPlaceholder::ColonNumber(2))) ]]), on_duplicate: Some(vec![( Column::from("value"), @@ -384,13 +489,12 @@ mod tests { table: Table::from("keystores"), fields: Some(vec![Column::from("key"), Column::from("value")]), data: InsertData::ValueList(vec![vec![ - Literal::Placeholder(ItemPlaceholder::DollarNumber(1)), - Literal::Placeholder(ItemPlaceholder::ColonNumber(2)) + InsertDataValue::Literal(Literal::Placeholder(ItemPlaceholder::DollarNumber( + 1 + ))), + InsertDataValue::Literal(Literal::Placeholder(ItemPlaceholder::ColonNumber(2))) ]]), - on_duplicate: Some(vec![( - Column::from("value"), - Column::from("value").into(), - ),]), + on_duplicate: Some(vec![(Column::from("value"), Column::from("value").into(),),]), ..Default::default() } ); @@ -400,15 +504,20 @@ mod tests { fn insert_with_leading_value_whitespace() { let qstring0 = "INSERT INTO users (id, name) VALUES ( 42, \"test\");"; let qstring1 = "INSERT INTO users (id, name) VALUES ROW( 42, \"test\");"; + let qstring2 = "INSERT INTO users (id, name) VALUES ROW( DEFAULT, \"test\");"; let res0 = insertion(qstring0.as_bytes()); let res1 = insertion(qstring1.as_bytes()); + let res2 = insertion(qstring2.as_bytes()); assert_eq!( res0.unwrap().1, InsertStatement { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), - data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), + data: InsertData::ValueList(vec![vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()) + ]]), ..Default::default() } ); @@ -417,7 +526,22 @@ mod tests { InsertStatement { table: Table::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), - data: InsertData::RowValueList(vec![vec![42.into(), "test".into()]]), + data: InsertData::RowValueList(vec![vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()) + ]]), + ..Default::default() + } + ); + assert_eq!( + res2.unwrap().1, + InsertStatement { + table: Table::from("users"), + fields: Some(vec![Column::from("id"), Column::from("name")]), + data: InsertData::RowValueList(vec![vec![ + InsertDataValue::Default, + InsertDataValue::Literal("test".into()) + ]]), ..Default::default() } ); @@ -475,7 +599,8 @@ mod tests { FieldValueExpression::Literal(LiteralExpression { value: Literal::String("dupe".to_string()), alias: None, - }).into(), + }) + .into(), )]); assert_eq!(res0.unwrap().1, expected0); diff --git a/src/parser.rs b/src/parser.rs index c24429a..fadcb73 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -84,7 +84,7 @@ where #[cfg(test)] mod tests { use super::*; - use insert::InsertData; + use insert::{InsertData, InsertDataValue}; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; @@ -99,7 +99,10 @@ mod tests { let expected = SqlQuery::Insert(InsertStatement { table: Table::from("users"), fields: None, - data: InsertData::ValueList(vec![vec![42.into(), "test".into()]]), + data: InsertData::ValueList(vec![vec![ + InsertDataValue::Literal(42.into()), + InsertDataValue::Literal("test".into()), + ]]), ..Default::default() }); let mut h0 = DefaultHasher::new(); diff --git a/src/update.rs b/src/update.rs index 101bc0a..1c77da5 100644 --- a/src/update.rs +++ b/src/update.rs @@ -72,8 +72,8 @@ mod tests { use condition::ConditionBase::*; use condition::ConditionExpression::*; use condition::ConditionTree; - use FieldValueExpression; use table::Table; + use FieldValueExpression; #[test] fn simple_update() { @@ -89,13 +89,15 @@ mod tests { fields: vec![ ( Column::from("id"), - FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))).into(), + FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))) + .into(), ), ( Column::from("name"), FieldValueExpression::Literal(LiteralExpression::from(Literal::from( "test", - ))).into(), + ))) + .into(), ), ], ..Default::default() @@ -106,14 +108,8 @@ mod tests { UpdateStatement { table: Table::from("users"), fields: vec![ - ( - Column::from("id"), - Column::from("new_id").into(), - ), - ( - Column::from("name"), - Column::from("old_name").into(), - ), + (Column::from("id"), Column::from("new_id").into(),), + (Column::from("name"), Column::from("old_name").into(),), ], ..Default::default() } @@ -138,13 +134,15 @@ mod tests { fields: vec![ ( Column::from("id"), - FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))).into(), + FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))) + .into(), ), ( Column::from("name"), FieldValueExpression::Literal(LiteralExpression::from(Literal::from( "test", - ))).into(), + ))) + .into(), ), ], where_clause: expected_where_cond, @@ -185,7 +183,8 @@ mod tests { integral: -19216, fractional: 5479744, } - ),)).into(), + ),)) + .into(), ),], where_clause: expected_where_cond, ..Default::default()