# Provenance: commit 9ea6f971624f02580662c04f2d9bba1b9d1c9042 by Qiushi Bai
# (Wed, 15 Oct 2025 16:12:24 -0700), "Adding two basic test cases for dividing
# the work of implementing parse and format functions"
# (tests/test_query_parser.py, 125 insertions).

parser = QueryParser()


# Query text shared by test_basic_parse (as input) and test_basic_format
# (as expected output).
_BASIC_SQL = """
SELECT e.name, d.name as dept_name, COUNT(*) as emp_count
FROM employees e JOIN departments d ON e.department_id = d.id
WHERE e.salary > 40000 AND e.age < 60
GROUP BY d.id, d.name
HAVING COUNT(*) > 2
ORDER BY dept_name, emp_count DESC
LIMIT 10 OFFSET 5
"""


def _collapse_whitespace(text):
    """Return *text* with every run of whitespace replaced by one space."""
    return " ".join(text.split())


def _build_basic_ast():
    """Build the hand-written AST that the basic tests pair with ``_BASIC_SQL``.

    Returns:
        A ``QueryNode`` with SELECT, FROM, WHERE, GROUP BY, HAVING, ORDER BY,
        LIMIT and OFFSET clauses populated.

    NOTE(review): the ``ON e.department_id = d.id`` join condition has no
    counterpart in this tree -- ``FromNode`` only receives the two tables.
    TODO confirm how joins are meant to be modelled.
    NOTE(review): clause children are passed as set literals, ``COUNT(*)`` is
    built without the ``emp_count`` alias, and ``DESC`` is not represented
    anywhere -- verify against the node definitions whether item order, the
    alias and the sort direction are expected to round-trip.
    """
    # Tables
    emp_table = TableNode("employees", "e")
    dept_table = TableNode("departments", "d")
    # Columns
    emp_name = ColumnNode("name", _parent_alias="e")
    dept_name = ColumnNode("name", "dept_name", "d")
    emp_salary = ColumnNode("salary", _parent_alias="e")
    emp_age = ColumnNode("age", _parent_alias="e")
    dept_id = ColumnNode("id", _parent_alias="d")
    count_star = FunctionNode("COUNT", {ColumnNode("*")})
    count_alias = ColumnNode("emp_count")  # This would be the alias for COUNT(*)
    # SELECT clause
    select_clause = SelectNode({emp_name, dept_name, count_star})
    # FROM clause (the join is not modelled explicitly; see NOTE above)
    from_clause = FromNode({emp_table, dept_table})
    # WHERE clause
    salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000))
    age_condition = OperatorNode(emp_age, "<", LiteralNode(60))
    where_condition = OperatorNode(salary_condition, "AND", age_condition)
    where_clause = WhereNode({where_condition})
    # GROUP BY clause
    group_by_clause = GroupByNode({dept_id, dept_name})
    # HAVING clause
    having_condition = OperatorNode(count_star, ">", LiteralNode(2))
    having_clause = HavingNode({having_condition})
    # ORDER BY clause
    order_by_clause = OrderByNode({dept_name, count_alias})
    # LIMIT and OFFSET
    limit_clause = LimitNode(10)
    offset_clause = OffsetNode(5)
    # Complete query
    return QueryNode(
        _select=select_clause,
        _from=from_clause,
        _where=where_clause,
        _group_by=group_by_clause,
        _having=having_clause,
        _order_by=order_by_clause,
        _limit=limit_clause,
        _offset=offset_clause,
    )


def test_basic_parse():
    """Parsing ``_BASIC_SQL`` must produce the hand-built AST."""
    expected_ast = _build_basic_ast()

    ast = parser.parse(_BASIC_SQL)

    assert ast == expected_ast


def test_basic_format():
    """Formatting the hand-built AST must produce ``_BASIC_SQL``."""
    ast = _build_basic_ast()

    sql = parser.format(ast)

    # Compare modulo whitespace: the expected text is a multi-line literal,
    # and its exact line breaks/indentation are a layout choice of this test
    # file rather than part of the formatter's contract.
    assert _collapse_whitespace(sql) == _collapse_whitespace(_BASIC_SQL)


def test_parse_1():
    query = get_query(1)
    sql = query['pattern']