class QueryParser:
    """Parse SQL text into the project's AST.

    The SQL string is first handed to ``mo_sql_parsing`` (imported as
    ``mosql``), which returns a JSON-like structure of dicts, lists and
    scalars.  That structure is then converted clause by clause into AST
    nodes and assembled into a ``QueryNode``.
    """

    @staticmethod
    def normalize_to_list(value):
        """Normalize mo_sql_parsing output to a list.

        mo_sql_parsing returns:
        - list when multiple items
        - dict when single item with structure
        - str when single simple value

        Args:
            value: ``None``, a list, a dict or a str.

        Returns:
            ``[]`` for ``None``, the list itself (not a copy) for a list,
            and a one-element list for a dict or str.

        Raises:
            TypeError: if ``value`` is of any other type.
        """
        if value is None:
            return []
        if isinstance(value, list):
            return value
        if isinstance(value, (dict, str)):
            return [value]
        raise TypeError(
            f"normalize_to_list: Unexpected type {type(value).__name__} for value {value!r}. "
            "Expected None, list, dict, or str."
        )

    def parse(self, query: str) -> QueryNode:
        """Parse a SQL string into a ``QueryNode``.

        Args:
            query: SQL text.

        Returns:
            A ``QueryNode`` whose clause slots are ``None`` for every clause
            that is absent from the parsed query.
        """
        # [1] Call mo_sql_parser
        # str -> Any (JSON)
        mosql_ast = mosql.parse(query)

        # [2] Our new code
        # Any (JSON) -> AST (QueryNode)
        # Alias name -> node.  SELECT is parsed first so that GROUP BY,
        # HAVING and ORDER BY can resolve the aliases it registers.
        aliases = {}

        select_clause = None
        from_clause = None
        where_clause = None
        group_by_clause = None
        having_clause = None
        order_by_clause = None
        limit_clause = None
        offset_clause = None

        if 'select' in mosql_ast:
            select_clause = self.parse_select(self.normalize_to_list(mosql_ast['select']), aliases)
        if 'from' in mosql_ast:
            from_clause = self.parse_from(self.normalize_to_list(mosql_ast['from']), aliases)
        if 'where' in mosql_ast:
            where_clause = self.parse_where(mosql_ast['where'], aliases)
        if 'groupby' in mosql_ast:
            group_by_clause = self.parse_group_by(self.normalize_to_list(mosql_ast['groupby']), aliases)
        if 'having' in mosql_ast:
            having_clause = self.parse_having(mosql_ast['having'], aliases)
        if 'orderby' in mosql_ast:
            order_by_clause = self.parse_order_by(self.normalize_to_list(mosql_ast['orderby']), aliases)
        if 'limit' in mosql_ast:
            limit_clause = LimitNode(mosql_ast['limit'])
        if 'offset' in mosql_ast:
            offset_clause = OffsetNode(mosql_ast['offset'])

        return QueryNode(
            _select=select_clause,
            _from=from_clause,
            _where=where_clause,
            _group_by=group_by_clause,
            _having=having_clause,
            _order_by=order_by_clause,
            _limit=limit_clause,
            _offset=offset_clause,
        )

    def parse_select(self, select_list: list, aliases: dict) -> SelectNode:
        """Convert the SELECT list into a ``SelectNode``.

        Args:
            select_list: normalized list of select items; each is either a
                dict with a ``'value'`` (and optional ``'name'`` alias) or a
                bare expression.
            aliases: alias registry, updated in place with every aliased
                select expression.
        """
        items = []
        for item in select_list:
            if isinstance(item, dict) and 'value' in item:
                expression = self.parse_expression(item['value'])
                # Handle alias - set for any node that has alias attribute
                if 'name' in item:
                    alias = item['name']
                    if hasattr(expression, 'alias'):
                        expression.alias = alias
                    aliases[alias] = expression
            else:
                # Direct expression (string, int, etc.)
                expression = self.parse_expression(item)
            items.append(expression)

        return SelectNode(items)

    def _parse_table(self, ref, aliases: dict) -> TableNode:
        """Build a ``TableNode`` from a table reference and register its alias.

        ``ref`` is either a dict with ``'value'`` (table name) and optional
        ``'name'`` (alias), or the bare table name.
        """
        if isinstance(ref, dict):
            table_name = ref['value']
            alias = ref.get('name')
        else:
            table_name = ref
            alias = None

        table_node = TableNode(table_name, alias)
        # Track table alias
        if alias:
            aliases[alias] = table_node
        return table_node

    def parse_from(self, from_list: list, aliases: dict) -> FromNode:
        """Convert the FROM list into a ``FromNode``.

        Comma-separated items become separate sources, kept in the order
        they were written.  A JOIN item is attached to the source that
        immediately precedes it, so consecutive JOINs form a left-deep chain
        and ``FROM a, b JOIN c`` yields ``[a, JoinNode(b, c)]``.

        Args:
            from_list: normalized list of FROM items (table name strings,
                ``{'value': ..., 'name': ...}`` dicts, or dicts keyed by a
                join keyword with an optional ``'on'`` condition).
            aliases: alias registry, updated in place with table aliases.

        Raises:
            ValueError: if a JOIN item appears before any table.
        """
        sources = []
        current = None  # Source being built: a table or a chain of joins

        for item in from_list:
            join_key = None
            if isinstance(item, dict):
                # Check for JOIN first (before checking for 'value')
                join_key = next((k for k in item if 'join' in k.lower()), None)

            if join_key:
                if current is None:
                    raise ValueError(f"JOIN found without a left table. join_key={join_key}, item={item}")

                right_table = self._parse_table(item[join_key], aliases)
                on_condition = self.parse_expression(item['on']) if 'on' in item else None
                # The result of this JOIN becomes the left side of the next one
                current = JoinNode(current, right_table, self.parse_join_type(join_key), on_condition)
            elif isinstance(item, str) or (isinstance(item, dict) and 'value' in item):
                # A new comma-separated source starts here: emit the one we
                # were building so that source order is preserved.
                if current is not None:
                    sources.append(current)
                current = self._parse_table(item, aliases)
            # Any other item shape is ignored, as before.

        if current is not None:
            sources.append(current)

        return FromNode(sources)

    def parse_where(self, where_dict: dict, aliases: dict) -> WhereNode:
        """Convert the WHERE expression into a ``WhereNode`` with one predicate.

        ``aliases`` is accepted for signature parity with the other clause
        parsers; it is not consulted here.
        """
        return WhereNode([self.parse_expression(where_dict)])

    def parse_group_by(self, group_by_list: list, aliases: dict) -> GroupByNode:
        """Convert the GROUP BY list into a ``GroupByNode``, resolving aliases."""
        items = []
        for item in group_by_list:
            # Items are either {'value': expr} wrappers or direct expressions.
            raw = item['value'] if isinstance(item, dict) and 'value' in item else item
            items.append(self.resolve_aliases(self.parse_expression(raw), aliases))

        return GroupByNode(items)

    def parse_having(self, having_dict: dict, aliases: dict) -> HavingNode:
        """Convert the HAVING expression into a ``HavingNode`` with one predicate."""
        # Resolve references to aliased expressions from SELECT
        expr = self.resolve_aliases(self.parse_expression(having_dict), aliases)
        return HavingNode([expr])

    def parse_order_by(self, order_by_list: list, aliases: dict) -> OrderByNode:
        """Convert the ORDER BY list into an ``OrderByNode``.

        Each item is wrapped in an ``OrderByItemNode``; the sort order is
        ``SortOrder.ASC`` unless the item carries ``'sort'`` equal to
        ``'DESC'`` (case-insensitive).
        """
        items = []
        for item in order_by_list:
            sort_order = SortOrder.ASC  # default
            if isinstance(item, dict) and 'value' in item:
                value = item['value']
                if isinstance(value, str) and value in aliases:
                    # Alias reference: reuse the aliased node itself
                    column = aliases[value]
                else:
                    column = self.parse_expression(value)

                if 'sort' in item and item['sort'].upper() == 'DESC':
                    sort_order = SortOrder.DESC
            else:
                # Direct expression (string, int, etc.)
                column = self.parse_expression(item)

            items.append(OrderByItemNode(column, sort_order))

        return OrderByNode(items)

    def resolve_aliases(self, expr: Node, aliases: dict) -> Node:
        """Attach SELECT aliases to matching nodes inside ``expr``.

        Operator nodes are rebuilt from their resolved operands.  A function
        or column node without an alias that matches an aliased node of the
        same kind gets that alias assigned in place.  Other nodes are
        returned unchanged.

        Raises:
            ValueError: if an ``OperatorNode`` has no children.
        """
        if isinstance(expr, OperatorNode):
            child_count = len(expr.children)
            if child_count >= 2:
                left = self.resolve_aliases(expr.children[0], aliases)
                right = self.resolve_aliases(expr.children[1], aliases)
                return OperatorNode(left, expr.name, right)
            if child_count == 1:
                # Unary operator (e.g., NOT)
                operand = self.resolve_aliases(expr.children[0], aliases)
                return OperatorNode(operand, expr.name)
            raise ValueError(f"OperatorNode has {len(expr.children)} children, expected 2 for binary operators or 1 for unary operators")

        if isinstance(expr, FunctionNode):
            if expr.alias is None:
                for alias, aliased_expr in aliases.items():
                    if not isinstance(aliased_expr, FunctionNode):
                        continue
                    same_call = (
                        expr.name == aliased_expr.name
                        and len(expr.children) == len(aliased_expr.children)
                        and all(a == b for a, b in zip(expr.children, aliased_expr.children))
                    )
                    if same_call:
                        # This function matches an aliased one, use the alias
                        expr.alias = alias
                        break
            return expr

        if isinstance(expr, ColumnNode):
            if expr.alias is None:
                for alias, aliased_expr in aliases.items():
                    if (isinstance(aliased_expr, ColumnNode)
                            and expr.name == aliased_expr.name
                            and expr.parent_alias == aliased_expr.parent_alias):
                        # This column matches an aliased one, use the alias
                        expr.alias = alias
                        break
            return expr

        return expr

    def parse_expression(self, expr) -> Node:
        """Convert one mo_sql_parsing expression into an AST node.

        - str: column reference; ``"t.c"`` is split on the first dot into
          parent alias ``t`` and column ``c``.
        - int / float / bool: literal.
        - list: each element is parsed.
          NOTE(review): this branch returns a plain Python list, not a
          ``Node``, despite the annotation - confirm callers expect that.
        - dict: ``all_columns`` / ``literal`` special cases, otherwise the
          first key that is not a metadata key is treated as an operator or
          function name.
        - anything else: wrapped in a ``LiteralNode``.
        """
        if isinstance(expr, str):
            # Column reference
            if '.' in expr:
                parent, name = expr.split('.', 1)
                return ColumnNode(name, _parent_alias=parent)
            return ColumnNode(expr)

        if isinstance(expr, (int, float, bool)):
            return LiteralNode(expr)

        if isinstance(expr, list):
            # List literals (for IN clauses)
            return [self.parse_expression(item) for item in expr]

        if isinstance(expr, dict):
            # Special cases first
            if 'all_columns' in expr:
                return ColumnNode('*')
            if 'literal' in expr:
                return LiteralNode(expr['literal'])

            # Skip metadata keys
            skip_keys = {'value', 'name', 'on', 'sort'}

            # Only the first non-metadata key is used: every path below returns.
            for key, value in expr.items():
                if key in skip_keys:
                    continue

                op_name = self.normalize_operator_name(key)

                # Pattern 1: Binary/N-ary operator with list of operands
                # NOTE(review): any key with a multi-element list value is
                # chained as a binary operator, including multi-argument
                # function calls - confirm this is intended.
                if isinstance(value, list):
                    if not value:
                        return LiteralNode(None)
                    if len(value) == 1:
                        return self.parse_expression(value[0])

                    operands = [self.parse_expression(v) for v in value]

                    # Chain multiple operands with the same operator
                    result = operands[0]
                    for operand in operands[1:]:
                        result = OperatorNode(result, op_name, operand)
                    return result

                # Pattern 2: Unary operator
                if key == 'not':
                    return OperatorNode(self.parse_expression(value), 'NOT')

                # Pattern 3: Function call
                # Special case: COUNT(*), SUM(*), etc.
                if value == '*':
                    return FunctionNode(op_name, _args=[ColumnNode('*')])

                # Regular function
                return FunctionNode(op_name, _args=[self.parse_expression(value)])

            # No valid key found
            return LiteralNode(json.dumps(expr, sort_keys=True))

        # Other types
        return LiteralNode(expr)

    @staticmethod
    def normalize_operator_name(key: str) -> str:
        """Convert mo_sql_parsing operator keys to SQL operator names.

        Known comparison and boolean keys map to their SQL spelling
        (case-insensitively); any other key is returned upper-cased.
        """
        mapping = {
            'eq': '=', 'neq': '!=', 'ne': '!=',
            'gt': '>', 'gte': '>=',
            'lt': '<', 'lte': '<=',
            'and': 'AND', 'or': 'OR',
        }
        return mapping.get(key.lower(), key.upper())

    @staticmethod
    def parse_join_type(join_key: str) -> JoinType:
        """Extract JoinType from mo_sql_parsing join key.

        The key is matched by substring, in the order inner, left, right,
        full, cross; a key containing none of these is an INNER join.
        """
        key_lower = join_key.lower().replace(' ', '_')

        if 'inner' in key_lower:
            return JoinType.INNER
        if 'left' in key_lower:
            return JoinType.LEFT
        if 'right' in key_lower:
            return JoinType.RIGHT
        if 'full' in key_lower:
            return JoinType.FULL
        if 'cross' in key_lower:
            return JoinType.CROSS

        return JoinType.INNER  # By default
+ """ + + # Construct input query text + sql = """ + SELECT e.name, d.name as dept_name, COUNT(*) as emp_count + FROM employees e JOIN departments d ON e.department_id = d.id + WHERE e.salary > 40000 AND e.age < 60 + GROUP BY d.id, d.name + HAVING COUNT(*) > 2 + ORDER BY dept_name, emp_count DESC + LIMIT 10 OFFSET 5 + """ + + # Construct expected AST + # Tables + emp_table = TableNode("employees", "e") + dept_table = TableNode("departments", "d") + # Columns + emp_name = ColumnNode("name", _parent_alias="e") + emp_salary = ColumnNode("salary", _parent_alias="e") + emp_age = ColumnNode("age", _parent_alias="e") + emp_dept_id = ColumnNode("department_id", _parent_alias="e") + + dept_name = ColumnNode("name", _alias="dept_name", _parent_alias="d") + dept_id = ColumnNode("id", _parent_alias="d") + + count_star = FunctionNode("COUNT", _alias="emp_count", _args=[ColumnNode("*")]) + + # SELECT clause + select_clause = SelectNode([emp_name, dept_name, count_star]) + # FROM clause with JOIN + join_condition = OperatorNode(emp_dept_id, "=", dept_id) + join_node = JoinNode(emp_table, dept_table, JoinType.INNER, join_condition) + from_clause = FromNode([join_node]) + # WHERE clause + salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) + age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) + where_condition = OperatorNode(salary_condition, "AND", age_condition) + where_clause = WhereNode([where_condition]) + # GROUP BY clause + group_by_clause = GroupByNode([dept_id, dept_name]) + # HAVING clause + having_condition = OperatorNode(count_star, ">", LiteralNode(2)) + having_clause = HavingNode([having_condition]) + # ORDER BY clause + order_by_item1 = OrderByItemNode(dept_name, SortOrder.ASC) + order_by_item2 = OrderByItemNode(count_star, SortOrder.DESC) + order_by_clause = OrderByNode([order_by_item1, order_by_item2]) + # LIMIT and OFFSET + limit_clause = LimitNode(10) + offset_clause = OffsetNode(5) + # Complete query + expected_ast = QueryNode( + 
_select=select_clause, + _from=from_clause, + _where=where_clause, + _group_by=group_by_clause, + _having=having_clause, + _order_by=order_by_clause, + _limit=limit_clause, + _offset=offset_clause + ) + + ast = parser.parse(sql) + + assert ast == expected_ast + + def test_parse_1(): query = get_query(1) sql = query['pattern']