From 43bf081ecf5e87683e9f0253e44b9de82ae77222 Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Mon, 10 Nov 2025 10:35:34 -0500 Subject: [PATCH 1/8] parser complete --- core/query_parser.py | 244 ++++++++++++++++++++++++++++++++++++- tests/test_query_parser.py | 69 ++++++++++- 2 files changed, 308 insertions(+), 5 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index 1ac3796..3793d2a 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -1,16 +1,252 @@ -from core.ast.node import QueryNode +from core.ast.node import ( + Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, + OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode +) +import mo_sql_parsing as mosql class QueryParser: def parse(self, query: str) -> QueryNode: - # Implement parsing logic using self.rules - pass - # [1] Call mo_sql_parser # str -> Any (JSON) + mosql_ast = mosql.parse(query) # [2] Our new code # Any (JSON) -> AST (QueryNode) + self.aliases = {} + + select_clause = None + from_clause = None + where_clause = None + group_by_clause = None + having_clause = None + order_by_clause = None + limit_clause = None + offset_clause = None + + if 'select' in mosql_ast: + select_clause = self._parse_select(mosql_ast['select']) + if 'from' in mosql_ast: + from_clause = self._parse_from(mosql_ast['from']) + if 'where' in mosql_ast: + where_clause = self._parse_where(mosql_ast['where']) + if 'groupby' in mosql_ast: + group_by_clause = self._parse_group_by(mosql_ast['groupby']) + if 'having' in mosql_ast: + having_clause = self._parse_having(mosql_ast['having']) + if 'orderby' in mosql_ast: + order_by_clause = self._parse_order_by(mosql_ast['orderby']) + if 'limit' in mosql_ast: + limit_clause = LimitNode(mosql_ast['limit']) + if 'offset' in mosql_ast: + offset_clause = OffsetNode(mosql_ast['offset']) + + return QueryNode( + _select=select_clause, + _from=from_clause, + _where=where_clause, + _group_by=group_by_clause, + _having=having_clause, + _order_by=order_by_clause, + _limit=limit_clause, + _offset=offset_clause + ) + + def _parse_select(self, select_list: list) -> SelectNode: + items = set() + for item in select_list: + if 'value' in item: + expression = self._parse_expression(item['value']) + # Handle alias + if 'name' in item: + alias = item['name'] + if isinstance(expression, ColumnNode): + expression.alias = alias + elif isinstance(expression, FunctionNode): + expression.alias = alias + # Add other types if needed + + self.aliases[alias] = expression + + items.add(expression) + + return SelectNode(items) + + def _parse_from(self, from_list: list) -> FromNode: + sources = set() + left_table = None + + for item in from_list: + if 'value' in item: + table_name = item['value'] + alias = item.get('name') + table_node = TableNode(table_name, alias) + # Track table alias + if alias: + self.aliases[alias] = table_node + + if left_table is None: + # First table becomes the left table + left_table = table_node + else: + # This shouldn't happen in normal SQL, but handle it + sources.add(table_node) + + elif 'join' in item: + if left_table is None: + raise ValueError("JOIN found without a left table") + + join_info = item['join'] + table_name = join_info['value'] + alias = join_info.get('name') + right_table = TableNode(table_name, alias) + # Track table alias + if alias: + self.aliases[alias] = right_table + + on_condition = None + if 'on' in item: + on_condition = self._parse_expression(item['on']) + + join_node = JoinNode(left_table, right_table, "INNER", on_condition) + sources.add(join_node) + # Reset for potential chained JOINs + left_table = None + + # Add any remaining left table + if left_table is not None: + sources.add(left_table) + + return FromNode(sources) + + def _parse_where(self, where_dict: dict) -> WhereNode: + predicates = set() + predicates.add(self._parse_expression(where_dict)) + return WhereNode(predicates) + + def _parse_group_by(self, group_by_list: list) -> GroupByNode: + items = [] + for item in group_by_list: + if 'value' in item: + expr = self._parse_expression(item['value']) + # Resolve aliases + expr = self._resolve_aliases(expr) + items.append(expr) + return GroupByNode(items) + + def _parse_having(self, having_dict: dict) -> HavingNode: + predicates = set() + expr = self._parse_expression(having_dict) + # Check if this expression references an aliased function from SELECT + expr = self._resolve_aliases(expr) + + predicates.add(expr) + return HavingNode(predicates) + + def _parse_order_by(self, order_by_list: list) -> OrderByNode: + items = [] + for item in order_by_list: + if 'value' in item: + value = item['value'] + # Check if this is an alias reference + if value in self.aliases: + column = self.aliases[value] + else: + # Parse normally for other cases + column = self._parse_expression(value) + items.append(column) + return OrderByNode(items) + + def _resolve_aliases(self, expr: Node) -> Node: + if isinstance(expr, OperatorNode): + # Recursively resolve aliases in operator operands + left = self._resolve_aliases(expr.children[0]) + right = self._resolve_aliases(expr.children[1]) + return OperatorNode(left, expr.name, right) + elif isinstance(expr, FunctionNode): + # Check if this function matches an aliased function from SELECT + if expr.alias is None: + for alias, aliased_expr in self.aliases.items(): + if isinstance(aliased_expr, FunctionNode): + if (expr.name == aliased_expr.name and + len(expr.children) == len(aliased_expr.children) and + all(expr.children[i] == aliased_expr.children[i] + for i in range(len(expr.children)))): + # This function matches an aliased one, use the alias + expr.alias = alias + break + return expr + elif isinstance(expr, ColumnNode): + # Check if this column matches an aliased column from SELECT + if expr.alias is None: + for alias, aliased_expr in self.aliases.items(): + if isinstance(aliased_expr, ColumnNode): + if (expr.name == aliased_expr.name and + expr.parent_alias == aliased_expr.parent_alias): + # This column matches an aliased one, use the alias + expr.alias = alias + break + return expr + else: + return expr + + def _parse_expression(self, expr) -> Node: + if isinstance(expr, str): + # Column reference + if '.' in expr: + parts = expr.split('.') + if len(parts) == 2: + return ColumnNode(parts[1], _parent_alias=parts[0]) + else: + return ColumnNode(expr) + else: + return ColumnNode(expr) + elif isinstance(expr, (int, float, bool)): + # Literal value + return LiteralNode(expr) + elif isinstance(expr, dict): + # Complex expression + if 'count' in expr: + # COUNT function + if expr['count'] == '*': + return FunctionNode("COUNT", [ColumnNode("*")]) + else: + return FunctionNode("COUNT", [self._parse_expression(expr['count'])]) + elif 'eq' in expr: + # Equality operator + left = self._parse_expression(expr['eq'][0]) + right = self._parse_expression(expr['eq'][1]) + return OperatorNode(left, "=", right) + elif 'gt' in expr: + # Greater than operator + left = self._parse_expression(expr['gt'][0]) + right = self._parse_expression(expr['gt'][1]) + return OperatorNode(left, ">", right) + elif 'lt' in expr: + # Less than operator + left = self._parse_expression(expr['lt'][0]) + right = self._parse_expression(expr['lt'][1]) + return OperatorNode(left, "<", right) + elif 'and' in expr: + # AND operator + conditions = expr['and'] + if len(conditions) == 2: + left = self._parse_expression(conditions[0]) + right = self._parse_expression(conditions[1]) + return OperatorNode(left, "AND", right) + else: + # Handle multiple AND conditions + result = self._parse_expression(conditions[0]) + for condition in conditions[1:]: + result = OperatorNode(result, "AND", self._parse_expression(condition)) + return result + else: + # Unknown expression type + return LiteralNode(str(expr)) + else: + return LiteralNode(str(expr)) + def format(self, query: QueryNode) -> str: # Implement formatting logic to convert AST back to SQL string diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index c3e7b61..0d9f143 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -3,13 +3,80 @@ from core.ast.node import ( QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, - OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode + OrderByNode, LimitNode, OffsetNode, JoinNode ) from core.ast.enums import NodeType, JoinType, SortOrder from data.queries import get_query parser = QueryParser() + +def test_basic_parse(): + + # Construct input query text + sql = """ + SELECT e.name, d.name as dept_name, COUNT(*) as emp_count + FROM employees e JOIN departments d ON e.department_id = d.id + WHERE e.salary > 40000 AND e.age < 60 + GROUP BY d.id, d.name + HAVING COUNT(*) > 2 + ORDER BY dept_name, emp_count DESC + LIMIT 10 OFFSET 5 + """ + + # Construct expected AST + # Tables + emp_table = TableNode("employees", "e") + dept_table = TableNode("departments", "d") + # Columns + emp_name = ColumnNode("name", _parent_alias="e") + emp_salary = ColumnNode("salary", _parent_alias="e") + emp_age = ColumnNode("age", _parent_alias="e") + emp_dept_id = ColumnNode("department_id", _parent_alias="e") + + dept_name = ColumnNode("name", _alias="dept_name", _parent_alias="d") + dept_id = ColumnNode("id", _parent_alias="d") + + count_star = FunctionNode("COUNT", _alias="emp_count", _args=[ColumnNode("*")]) + + # SELECT clause + select_clause = SelectNode({emp_name, dept_name, count_star}) + # FROM clause with JOIN + join_condition = OperatorNode(emp_dept_id, "=", dept_id) + join_node = JoinNode(emp_table, dept_table, "INNER", join_condition) + from_clause = FromNode({join_node}) + # WHERE clause + salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) + age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) + where_condition = OperatorNode(salary_condition, "AND", age_condition) + where_clause = WhereNode({where_condition}) + # GROUP BY clause + group_by_clause = GroupByNode([dept_id, dept_name]) + # HAVING clause + having_condition = OperatorNode(count_star, ">", LiteralNode(2)) + having_clause = HavingNode({having_condition}) + # ORDER BY clause -> desc and asc are not supported yet!! + order_by_clause = OrderByNode([dept_name, count_star]) + # LIMIT and OFFSET + limit_clause = LimitNode(10) + offset_clause = OffsetNode(5) + # Complete query + expected_ast = QueryNode( + _select=select_clause, + _from=from_clause, + _where=where_clause, + _group_by=group_by_clause, + _having=having_clause, + _order_by=order_by_clause, + _limit=limit_clause, + _offset=offset_clause + ) + + ast = parser.parse(sql) + + assert ast == expected_ast + + def test_parse_1(): query = get_query(1) sql = query['pattern'] From 74ba8637b8999ee36d13f5a62490ae52da34f377 Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:47:36 -0500 Subject: [PATCH 2/8] clean up --- core/query_parser.py | 357 +++++++++++++++++++++++-------------- tests/test_query_parser.py | 13 +- 2 files changed, 233 insertions(+), 137 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index 3793d2a..6f03e87 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -1,11 +1,31 @@ from core.ast.node import ( Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, - OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode + OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode ) +from core.ast.enums import JoinType, SortOrder import mo_sql_parsing as mosql class QueryParser: + @staticmethod + def normalize_to_list(value): + """Normalize mo_sql_parsing output to a list format. + + mo_sql_parsing returns: + - list when multiple items + - dict when single item with structure + - str when single simple value + + This normalizes all cases to a list. + """ + if value is None: + return [] + elif isinstance(value, list): + return value + elif isinstance(value, (dict, str)): + return [value] + else: + return [value] def parse(self, query: str) -> QueryNode: # [1] Call mo_sql_parser @@ -26,17 +46,17 @@ def parse(self, query: str) -> QueryNode: offset_clause = None if 'select' in mosql_ast: - select_clause = self._parse_select(mosql_ast['select']) + select_clause = self.parse_select(self.normalize_to_list(mosql_ast['select'])) if 'from' in mosql_ast: - from_clause = self._parse_from(mosql_ast['from']) + from_clause = self.parse_from(self.normalize_to_list(mosql_ast['from'])) if 'where' in mosql_ast: - where_clause = self._parse_where(mosql_ast['where']) + where_clause = self.parse_where(mosql_ast['where']) if 'groupby' in mosql_ast: - group_by_clause = self._parse_group_by(mosql_ast['groupby']) + group_by_clause = self.parse_group_by(self.normalize_to_list(mosql_ast['groupby'])) if 'having' in mosql_ast: - having_clause = self._parse_having(mosql_ast['having']) + having_clause = self.parse_having(mosql_ast['having']) if 'orderby' in mosql_ast: - order_by_clause = self._parse_order_by(mosql_ast['orderby']) + order_by_clause = self.parse_order_by(self.normalize_to_list(mosql_ast['orderby'])) if 'limit' in mosql_ast: limit_clause = LimitNode(mosql_ast['limit']) if 'offset' in mosql_ast: @@ -53,116 +73,159 @@ def parse(self, query: str) -> QueryNode: _offset=offset_clause ) - def _parse_select(self, select_list: list) -> SelectNode: + def parse_select(self, select_list: list) -> SelectNode: items = set() for item in select_list: - if 'value' in item: - expression = self._parse_expression(item['value']) - # Handle alias + if isinstance(item, dict) and 'value' in item: + expression = self.parse_expression(item['value']) + # Handle alias - set for any node that has alias attribute if 'name' in item: alias = item['name'] - if isinstance(expression, ColumnNode): - expression.alias = alias - elif isinstance(expression, FunctionNode): + if hasattr(expression, 'alias'): expression.alias = alias - # Add other types if needed - self.aliases[alias] = expression items.add(expression) + else: + # Handle direct expression (string, int, etc.) + expression = self.parse_expression(item) + items.add(expression) return SelectNode(items) - def _parse_from(self, from_list: list) -> FromNode: + def parse_from(self, from_list: list) -> FromNode: sources = set() - left_table = None + left_source = None # Can be a table or the result of a previous join for item in from_list: - if 'value' in item: - table_name = item['value'] - alias = item.get('name') - table_node = TableNode(table_name, alias) - # Track table alias - if alias: - self.aliases[alias] = table_node + # Check for JOIN first (before checking for 'value') + if isinstance(item, dict): + # Look for any join key + join_key = next((k for k in item.keys() if 'join' in k.lower()), None) - if left_table is None: - # First table becomes the left table - left_table = table_node + if join_key: + # This is a JOIN + if left_source is None: + raise ValueError("JOIN found without a left table") + + join_info = item[join_key] + # Handle both string and dict join_info + if isinstance(join_info, str): + table_name = join_info + alias = None + else: + table_name = join_info['value'] if isinstance(join_info, dict) else join_info + alias = join_info.get('name') if isinstance(join_info, dict) else None + + right_table = TableNode(table_name, alias) + # Track table alias + if alias: + self.aliases[alias] = right_table + + on_condition = None + if 'on' in item: + on_condition = self.parse_expression(item['on']) + + # Create join node - left_source might be a table or a previous join + join_type = self.parse_join_type(join_key) + join_node = JoinNode(left_source, right_table, join_type, on_condition) + # The result of this JOIN becomes the new left source for potential next JOIN + left_source = join_node + elif 'value' in item: + # This is a table reference + table_name = item['value'] + alias = item.get('name') + table_node = TableNode(table_name, alias) + # Track table alias + if alias: + self.aliases[alias] = table_node + + if left_source is None: + # First table becomes the left source + left_source = table_node + else: + # Multiple tables without explicit JOIN (cross join) + sources.add(table_node) + elif isinstance(item, str): + # Simple string table name + table_node = TableNode(item) + if left_source is None: + left_source = table_node else: - # This shouldn't happen in normal SQL, but handle it sources.add(table_node) - - elif 'join' in item: - if left_table is None: - raise ValueError("JOIN found without a left table") - - join_info = item['join'] - table_name = join_info['value'] - alias = join_info.get('name') - right_table = TableNode(table_name, alias) - # Track table alias - if alias: - self.aliases[alias] = right_table - - on_condition = None - if 'on' in item: - on_condition = self._parse_expression(item['on']) - - join_node = JoinNode(left_table, right_table, "INNER", on_condition) - sources.add(join_node) - # Reset for potential chained JOINs - left_table = None - # Add any remaining left table - if left_table is not None: - sources.add(left_table) + # Add the final left source (which might be a single table or chain of joins) + if left_source is not None: + sources.add(left_source) return FromNode(sources) - def _parse_where(self, where_dict: dict) -> WhereNode: + def parse_where(self, where_dict: dict) -> WhereNode: predicates = set() - predicates.add(self._parse_expression(where_dict)) + predicates.add(self.parse_expression(where_dict)) return WhereNode(predicates) - def _parse_group_by(self, group_by_list: list) -> GroupByNode: + def parse_group_by(self, group_by_list: list) -> GroupByNode: items = [] for item in group_by_list: - if 'value' in item: - expr = self._parse_expression(item['value']) + if isinstance(item, dict) and 'value' in item: + expr = self.parse_expression(item['value']) # Resolve aliases - expr = self._resolve_aliases(expr) + expr = self.resolve_aliases(expr) items.append(expr) + else: + # Handle direct expression (string, int, etc.) + expr = self.parse_expression(item) + expr = self.resolve_aliases(expr) + items.append(expr) + return GroupByNode(items) - def _parse_having(self, having_dict: dict) -> HavingNode: + def parse_having(self, having_dict: dict) -> HavingNode: predicates = set() - expr = self._parse_expression(having_dict) + expr = self.parse_expression(having_dict) # Check if this expression references an aliased function from SELECT - expr = self._resolve_aliases(expr) + expr = self.resolve_aliases(expr) predicates.add(expr) + return HavingNode(predicates) - def _parse_order_by(self, order_by_list: list) -> OrderByNode: + def parse_order_by(self, order_by_list: list) -> OrderByNode: items = [] for item in order_by_list: - if 'value' in item: + if isinstance(item, dict) and 'value' in item: value = item['value'] # Check if this is an alias reference - if value in self.aliases: + if isinstance(value, str) and value in self.aliases: column = self.aliases[value] else: # Parse normally for other cases - column = self._parse_expression(value) - items.append(column) + column = self.parse_expression(value) + + # Get sort order (default is ASC) + sort_order = SortOrder.ASC + if 'sort' in item: + sort_str = item['sort'].upper() + if sort_str == 'DESC': + sort_order = SortOrder.DESC + + # Wrap in OrderByItemNode + order_by_item = OrderByItemNode(column, sort_order) + items.append(order_by_item) + else: + # Handle direct expression (string, int, etc.) + column = self.parse_expression(item) + order_by_item = OrderByItemNode(column, SortOrder.ASC) + items.append(order_by_item) + return OrderByNode(items) - def _resolve_aliases(self, expr: Node) -> Node: + def resolve_aliases(self, expr: Node) -> Node: if isinstance(expr, OperatorNode): # Recursively resolve aliases in operator operands - left = self._resolve_aliases(expr.children[0]) - right = self._resolve_aliases(expr.children[1]) + left = self.resolve_aliases(expr.children[0]) + right = self.resolve_aliases(expr.children[1]) return OperatorNode(left, expr.name, right) elif isinstance(expr, FunctionNode): # Check if this function matches an aliased function from SELECT @@ -191,69 +254,101 @@ def _resolve_aliases(self, expr: Node) -> Node: else: return expr - def _parse_expression(self, expr) -> Node: + def parse_expression(self, expr) -> Node: if isinstance(expr, str): # Column reference if '.' in expr: - parts = expr.split('.') - if len(parts) == 2: - return ColumnNode(parts[1], _parent_alias=parts[0]) - else: - return ColumnNode(expr) - else: - return ColumnNode(expr) - elif isinstance(expr, (int, float, bool)): - # Literal value + parts = expr.split('.', 1) + return ColumnNode(parts[1], _parent_alias=parts[0]) + return ColumnNode(expr) + + if isinstance(expr, (int, float, bool)): return LiteralNode(expr) - elif isinstance(expr, dict): - # Complex expression - if 'count' in expr: - # COUNT function - if expr['count'] == '*': - return FunctionNode("COUNT", [ColumnNode("*")]) - else: - return FunctionNode("COUNT", [self._parse_expression(expr['count'])]) - elif 'eq' in expr: - # Equality operator - left = self._parse_expression(expr['eq'][0]) - right = self._parse_expression(expr['eq'][1]) - return OperatorNode(left, "=", right) - elif 'gt' in expr: - # Greater than operator - left = self._parse_expression(expr['gt'][0]) - right = self._parse_expression(expr['gt'][1]) - return OperatorNode(left, ">", right) - elif 'lt' in expr: - # Less than operator - left = self._parse_expression(expr['lt'][0]) - right = self._parse_expression(expr['lt'][1]) - return OperatorNode(left, "<", right) - elif 'and' in expr: - # AND operator - conditions = expr['and'] - if len(conditions) == 2: - left = self._parse_expression(conditions[0]) - right = self._parse_expression(conditions[1]) - return OperatorNode(left, "AND", right) - else: - # Handle multiple AND conditions - result = self._parse_expression(conditions[0]) - for condition in conditions[1:]: - result = OperatorNode(result, "AND", self._parse_expression(condition)) + + if isinstance(expr, list): + # List literals (for IN clauses) - convert to tuple for hashability + parsed = tuple(self.parse_expression(item) for item in expr) + return LiteralNode(parsed) + + if isinstance(expr, dict): + # Special cases first + if 'all_columns' in expr: + return ColumnNode('*') + if 'literal' in expr: + return LiteralNode(expr['literal']) + + # Skip metadata keys + skip_keys = {'value', 'name', 'on', 'sort'} + + # Find the operator/function key + for key in expr.keys(): + if key in skip_keys: + continue + + value = expr[key] + op_name = self.normalize_operator_name(key) + + # Pattern 1: Binary/N-ary operator with list of operands + if isinstance(value, list): + if len(value) == 0: + return LiteralNode(None) + if len(value) == 1: + return self.parse_expression(value[0]) + + # Parse all operands + operands = [self.parse_expression(v) for v in value] + + # Chain multiple operands with the same operator + result = operands[0] + for operand in operands[1:]: + result = OperatorNode(result, op_name, operand) return result - else: - # Unknown expression type - return LiteralNode(str(expr)) - else: - return LiteralNode(str(expr)) - - - def format(self, query: QueryNode) -> str: - # Implement formatting logic to convert AST back to SQL string - pass - - # [1] Our new code - # AST (QueryNode) -> JSON - - # [2] Call mo_sql_format - # Any (JSON) -> str \ No newline at end of file + + # Pattern 2: Unary operator + if key == 'not': + return OperatorNode(self.parse_expression(value), 'NOT') + + # Pattern 3: Function call + # Special case: COUNT(*), SUM(*), etc. + if value == '*': + return FunctionNode(op_name, [ColumnNode('*')]) + + # Regular function + args = [self.parse_expression(value)] + return FunctionNode(op_name, args) + + # No valid key found + import json + return LiteralNode(json.dumps(expr, sort_keys=True)) + + # Other types + return LiteralNode(expr) + + @staticmethod + def normalize_operator_name(key: str) -> str: + """Convert mo_sql_parsing operator keys to SQL operator names.""" + mapping = { + 'eq': '=', 'neq': '!=', 'ne': '!=', + 'gt': '>', 'gte': '>=', + 'lt': '<', 'lte': '<=', + 'and': 'AND', 'or': 'OR', + } + return mapping.get(key.lower(), key.upper()) + + @staticmethod + def parse_join_type(join_key: str) -> JoinType: + """Extract JoinType from mo_sql_parsing join key.""" + key_lower = join_key.lower().replace(' ', '_') + + if 'inner' in key_lower: + return JoinType.INNER + elif 'left' in key_lower: + return JoinType.LEFT + elif 'right' in key_lower: + return JoinType.RIGHT + elif 'full' in key_lower: + return JoinType.FULL + elif 'cross' in key_lower: + return JoinType.CROSS + + return JoinType.INNER # By default \ No newline at end of file diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index 0d9f143..382c3cf 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -1,11 +1,10 @@ -import mo_sql_parsing as mosql from core.query_parser import QueryParser from core.ast.node import ( QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, - OrderByNode, LimitNode, OffsetNode, JoinNode + OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode ) -from core.ast.enums import NodeType, JoinType, SortOrder +from core.ast.enums import JoinType, SortOrder from data.queries import get_query parser = QueryParser() @@ -43,7 +42,7 @@ def test_basic_parse(): select_clause = SelectNode({emp_name, dept_name, count_star}) # FROM clause with JOIN join_condition = OperatorNode(emp_dept_id, "=", dept_id) - join_node = JoinNode(emp_table, dept_table, "INNER", join_condition) + join_node = JoinNode(emp_table, dept_table, JoinType.INNER, join_condition) from_clause = FromNode({join_node}) # WHERE clause salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) @@ -55,8 +54,10 @@ def test_basic_parse(): # HAVING clause having_condition = OperatorNode(count_star, ">", LiteralNode(2)) having_clause = HavingNode({having_condition}) - # ORDER BY clause -> desc and asc are not supported yet!! - order_by_clause = OrderByNode([dept_name, count_star]) + # ORDER BY clause + order_by_item1 = OrderByItemNode(dept_name, SortOrder.ASC) + order_by_item2 = OrderByItemNode(count_star, SortOrder.DESC) + order_by_clause = OrderByNode([order_by_item1, order_by_item2]) # LIMIT and OFFSET limit_clause = LimitNode(10) offset_clause = OffsetNode(5) From fbbbbfcfc4424450677ac90f8b191d19c70ac46e Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:17:04 -0500 Subject: [PATCH 3/8] update select from set to list --- core/query_parser.py | 6 +++--- tests/test_query_parser.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index 6f03e87..1b8174d 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -74,7 +74,7 @@ def parse(self, query: str) -> QueryNode: ) def parse_select(self, select_list: list) -> SelectNode: - items = set() + items = [] for item in select_list: if isinstance(item, dict) and 'value' in item: expression = self.parse_expression(item['value']) @@ -85,11 +85,11 @@ def parse_select(self, select_list: list) -> SelectNode: expression.alias = alias self.aliases[alias] = expression - items.add(expression) + items.append(expression) else: # Handle direct expression (string, int, etc.) expression = self.parse_expression(item) - items.add(expression) + items.append(expression) return SelectNode(items) diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index 382c3cf..489e28d 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -39,7 +39,7 @@ def test_basic_parse(): count_star = FunctionNode("COUNT", _alias="emp_count", _args=[ColumnNode("*")]) # SELECT clause - select_clause = SelectNode({emp_name, dept_name, count_star}) + select_clause = SelectNode([emp_name, dept_name, count_star]) # FROM clause with JOIN join_condition = OperatorNode(emp_dept_id, "=", dept_id) join_node = JoinNode(emp_table, dept_table, JoinType.INNER, join_condition) From 0bb311384c9c58443c23dc3efdc2cf991b29926f Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:21:36 -0500 Subject: [PATCH 4/8] Apply suggestions from code review first batch fix Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- core/query_parser.py | 7 +++++-- tests/test_query_parser.py | 9 ++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index 1b8174d..c7e25b9 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -25,7 +25,10 @@ def normalize_to_list(value): elif isinstance(value, (dict, str)): return [value] else: - return [value] + raise TypeError( + f"normalize_to_list: Unexpected type {type(value).__name__} for value {value!r}. " + "Expected None, list, dict, or str." + ) def parse(self, query: str) -> QueryNode: # [1] Call mo_sql_parser @@ -106,7 +109,7 @@ def parse_from(self, from_list: list) -> FromNode: if join_key: # This is a JOIN if left_source is None: - raise ValueError("JOIN found without a left table") + raise ValueError(f"JOIN found without a left table. join_key={join_key}, item={item}") join_info = item[join_key] # Handle both string and dict join_info diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index 489e28d..7d6c35e 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -11,6 +11,9 @@ def test_basic_parse(): + """ + Test parsing of a complex SQL query with JOINs, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, and OFFSET clauses. + """ # Construct input query text sql = """ @@ -43,17 +46,17 @@ def test_basic_parse(): # FROM clause with JOIN join_condition = OperatorNode(emp_dept_id, "=", dept_id) join_node = JoinNode(emp_table, dept_table, JoinType.INNER, join_condition) - from_clause = FromNode({join_node}) + from_clause = FromNode([join_node]) # WHERE clause salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) where_condition = OperatorNode(salary_condition, "AND", age_condition) - where_clause = WhereNode({where_condition}) + where_clause = WhereNode([where_condition]) # GROUP BY clause group_by_clause = GroupByNode([dept_id, dept_name]) # HAVING clause having_condition = OperatorNode(count_star, ">", LiteralNode(2)) - having_clause = HavingNode({having_condition}) + having_clause = HavingNode([having_condition]) # ORDER BY clause order_by_item1 = OrderByItemNode(dept_name, SortOrder.ASC) order_by_item2 = OrderByItemNode(count_star, SortOrder.DESC) From d2ce8de10786d155c3c7b6ea95a78e54feb204bb Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:36:15 -0500 Subject: [PATCH 5/8] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- core/query_parser.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index c7e25b9..fe6ed4b 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -185,12 +185,12 @@ def parse_group_by(self, group_by_list: list) -> GroupByNode: return GroupByNode(items) def parse_having(self, having_dict: dict) -> HavingNode: - predicates = set() + predicates = [] expr = self.parse_expression(having_dict) # Check if this expression references an aliased function from SELECT expr = self.resolve_aliases(expr) - predicates.add(expr) + predicates.append(expr) return HavingNode(predicates) @@ -269,9 +269,9 @@ def parse_expression(self, expr) -> Node: return LiteralNode(expr) if isinstance(expr, list): - # List literals (for IN clauses) - convert to tuple for hashability - parsed = tuple(self.parse_expression(item) for item in expr) - return LiteralNode(parsed) + # List literals (for IN clauses) + parsed = [self.parse_expression(item) for item in expr] + return parsed if isinstance(expr, dict): # Special cases first From f95a6595e76105494b87124241e83587aef9008d Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:49:45 -0500 Subject: [PATCH 6/8] batch update 3 --- core/query_parser.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index fe6ed4b..bc6519f 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -3,8 +3,10 @@ LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode ) +# TODO: implement SubqueryNode, VarNode, VarSetNode from core.ast.enums import JoinType, SortOrder import mo_sql_parsing as mosql +import json class QueryParser: @staticmethod @@ -164,8 +166,8 @@ def parse_from(self, from_list: list) -> FromNode: return FromNode(sources) def parse_where(self, where_dict: dict) -> WhereNode: - predicates = set() - predicates.add(self.parse_expression(where_dict)) + predicates = [] + predicates.append(self.parse_expression(where_dict)) return WhereNode(predicates) def parse_group_by(self, group_by_list: list) -> GroupByNode: @@ -314,14 +316,13 @@ def parse_expression(self, expr) -> Node: # Pattern 3: Function call # Special case: COUNT(*), SUM(*), etc. if value == '*': - return FunctionNode(op_name, [ColumnNode('*')]) + return FunctionNode(op_name, _args=[ColumnNode('*')]) # Regular function args = [self.parse_expression(value)] - return FunctionNode(op_name, args) + return FunctionNode(op_name, _args=args) # No valid key found - import json return LiteralNode(json.dumps(expr, sort_keys=True)) # Other types From 7614b6c0297f35f90b34a9c093fd56bef74c8a9e Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Sat, 20 Dec 2025 17:31:03 -0500 Subject: [PATCH 7/8] update inconsistency --- core/ast/node.py | 4 ++-- core/query_parser.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/core/ast/node.py b/core/ast/node.py index befe799..5e6231e 100644 --- a/core/ast/node.py +++ b/core/ast/node.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Set, Optional +from typing import List, Set, Optional, Union from abc import ABC from .enums import NodeType, JoinType, SortOrder @@ -160,7 +160,7 @@ def __hash__(self): class JoinNode(Node): """JOIN clause node""" - def __init__(self, _left_table: 'TableNode', _right_table: 'TableNode', _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs): + def __init__(self, _left_table: Union['TableNode', 'JoinNode'], _right_table: 'TableNode', _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs): children = [_left_table, _right_table] if _on_condition: children.append(_on_condition) diff --git a/core/query_parser.py b/core/query_parser.py index bc6519f..766766a 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -99,7 +99,7 @@ def parse_select(self, select_list: list) -> SelectNode: return SelectNode(items) def parse_from(self, from_list: list) -> FromNode: - sources = set() + sources = [] left_source = None # Can be a table or the result of a previous join for item in from_list: @@ -150,18 +150,18 @@ def parse_from(self, from_list: list) -> FromNode: left_source = table_node else: # Multiple tables without explicit JOIN (cross join) - sources.add(table_node) + sources.append(table_node) elif isinstance(item, str): # Simple string table name table_node = TableNode(item) if left_source is None: left_source = table_node else: - sources.add(table_node) + sources.append(table_node) # Add the final left source (which might be a single table or chain of joins) if left_source is not None: - sources.add(left_source) + sources.append(left_source) return FromNode(sources) @@ -231,6 +231,13 @@ def resolve_aliases(self, expr: Node) -> Node: # Recursively resolve aliases in operator operands left = self.resolve_aliases(expr.children[0]) right = self.resolve_aliases(expr.children[1]) + # TODO - When the OperatorNode is created with unary operators (e.g., 'NOT'), + # only one operand is passed (line 309). + # The resolve_aliases method assumes binary operators and always accesses + # children[0] and children[1] (lines 227-228). + # This will cause an IndexError when resolving aliases for unary operators. + # Add a check for the number of children before accessing indices. + # (and test it once we have such test cases) return OperatorNode(left, expr.name, right) elif isinstance(expr, FunctionNode): # Check if this function matches an aliased function from SELECT From a9c4214eff21f3d1aa1d500f796eb0382456ac17 Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Sun, 21 Dec 2025 19:49:21 -0500 Subject: [PATCH 8/8] thread safe alias tracking --- core/query_parser.py | 69 ++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index 766766a..deee3c1 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -39,7 +39,8 @@ def parse(self, query: str) -> QueryNode: # [2] Our new code # Any (JSON) -> AST (QueryNode) - self.aliases = {} + # Aliases dictionary + aliases = {} select_clause = None from_clause = None @@ -51,17 +52,17 @@ def parse(self, query: str) -> QueryNode: offset_clause = None if 'select' in mosql_ast: - select_clause = self.parse_select(self.normalize_to_list(mosql_ast['select'])) + select_clause = self.parse_select(self.normalize_to_list(mosql_ast['select']), aliases) if 'from' in mosql_ast: - from_clause = self.parse_from(self.normalize_to_list(mosql_ast['from'])) + from_clause = self.parse_from(self.normalize_to_list(mosql_ast['from']), aliases) if 'where' in mosql_ast: - where_clause = self.parse_where(mosql_ast['where']) + where_clause = self.parse_where(mosql_ast['where'], aliases) if 'groupby' in mosql_ast: - group_by_clause = self.parse_group_by(self.normalize_to_list(mosql_ast['groupby'])) + group_by_clause = self.parse_group_by(self.normalize_to_list(mosql_ast['groupby']), aliases) if 'having' in mosql_ast: - having_clause = self.parse_having(mosql_ast['having']) + having_clause = self.parse_having(mosql_ast['having'], aliases) if 'orderby' in mosql_ast: - order_by_clause = self.parse_order_by(self.normalize_to_list(mosql_ast['orderby'])) + order_by_clause = self.parse_order_by(self.normalize_to_list(mosql_ast['orderby']), aliases) if 'limit' in mosql_ast: limit_clause = LimitNode(mosql_ast['limit']) if 'offset' in mosql_ast: @@ -78,7 +79,7 @@ def parse(self, query: str) -> QueryNode: _offset=offset_clause ) - def parse_select(self, select_list: list) -> SelectNode: + def parse_select(self, select_list: list, aliases: dict) -> SelectNode: items = [] for item in select_list: if isinstance(item, dict) and 'value' in item: @@ -88,7 +89,7 @@ def parse_select(self, select_list: list) -> SelectNode: alias = item['name'] if hasattr(expression, 'alias'): expression.alias = alias - self.aliases[alias] = expression + aliases[alias] = expression items.append(expression) else: @@ -98,7 +99,7 @@ def parse_select(self, select_list: list) -> SelectNode: return SelectNode(items) - def parse_from(self, from_list: list) -> FromNode: + def parse_from(self, from_list: list, aliases: dict) -> FromNode: sources = [] left_source = None # Can be a table or the result of a previous join @@ -125,7 +126,7 @@ def parse_from(self, from_list: list) -> FromNode: right_table = TableNode(table_name, alias) # Track table alias if alias: - self.aliases[alias] = right_table + aliases[alias] = right_table on_condition = None if 'on' in item: @@ -143,7 +144,7 @@ def parse_from(self, from_list: list) -> FromNode: table_node = TableNode(table_name, alias) # Track table alias if alias: - self.aliases[alias] = table_node + aliases[alias] = table_node if left_source is None: # First table becomes the left source @@ -165,45 +166,45 @@ def parse_from(self, from_list: list) -> FromNode: return FromNode(sources) - def parse_where(self, where_dict: dict) -> WhereNode: + def parse_where(self, where_dict: dict, aliases: dict) -> WhereNode: predicates = [] predicates.append(self.parse_expression(where_dict)) return WhereNode(predicates) - def parse_group_by(self, group_by_list: list) -> GroupByNode: + def parse_group_by(self, group_by_list: list, aliases: dict) -> GroupByNode: items = [] for item in group_by_list: if isinstance(item, dict) and 'value' in item: expr = self.parse_expression(item['value']) # Resolve aliases - expr = self.resolve_aliases(expr) + expr = self.resolve_aliases(expr, aliases) items.append(expr) else: # Handle direct expression (string, int, etc.) expr = self.parse_expression(item) - expr = self.resolve_aliases(expr) + expr = self.resolve_aliases(expr, aliases) items.append(expr) return GroupByNode(items) - def parse_having(self, having_dict: dict) -> HavingNode: + def parse_having(self, having_dict: dict, aliases: dict) -> HavingNode: predicates = [] expr = self.parse_expression(having_dict) # Check if this expression references an aliased function from SELECT - expr = self.resolve_aliases(expr) + expr = self.resolve_aliases(expr, aliases) predicates.append(expr) return HavingNode(predicates) - def parse_order_by(self, order_by_list: list) -> OrderByNode: + def parse_order_by(self, order_by_list: list, aliases: dict) -> OrderByNode: items = [] for item in order_by_list: if isinstance(item, dict) and 'value' in item: value = item['value'] # Check if this is an alias reference - if isinstance(value, str) and value in self.aliases: - column = self.aliases[value] + if isinstance(value, str) and value in aliases: + column = aliases[value] else: # Parse normally for other cases column = self.parse_expression(value) @@ -226,23 +227,23 @@ def parse_order_by(self, order_by_list: list) -> OrderByNode: return OrderByNode(items) - def resolve_aliases(self, expr: Node) -> Node: + def resolve_aliases(self, expr: Node, aliases: dict) -> Node: if isinstance(expr, OperatorNode): # Recursively resolve aliases in operator operands - left = self.resolve_aliases(expr.children[0]) - right = self.resolve_aliases(expr.children[1]) - # TODO - When the OperatorNode is created with unary operators (e.g., 'NOT'), - # only one operand is passed (line 309). - # The resolve_aliases method assumes binary operators and always accesses - # children[0] and children[1] (lines 227-228). - # This will cause an IndexError when resolving aliases for unary operators. - # Add a check for the number of children before accessing indices. - # (and test it once we have such test cases) - return OperatorNode(left, expr.name, right) + if len(expr.children) >= 2: + left = self.resolve_aliases(expr.children[0], aliases) + right = self.resolve_aliases(expr.children[1], aliases) + return OperatorNode(left, expr.name, right) + elif len(expr.children) == 1: + # Unary operator (e.g., NOT) + operand = self.resolve_aliases(expr.children[0], aliases) + return OperatorNode(operand, expr.name) + else: + raise ValueError(f"OperatorNode has {len(expr.children)} children, expected 2 for binary operators or 1 for unary operators") elif isinstance(expr, FunctionNode): # Check if this function matches an aliased function from SELECT if expr.alias is None: - for alias, aliased_expr in self.aliases.items(): + for alias, aliased_expr in aliases.items(): if isinstance(aliased_expr, FunctionNode): if (expr.name == aliased_expr.name and len(expr.children) == len(aliased_expr.children) and @@ -255,7 +256,7 @@ def resolve_aliases(self, expr: Node) -> Node: elif isinstance(expr, ColumnNode): # Check if this column matches an aliased column from SELECT if expr.alias is None: - for alias, aliased_expr in self.aliases.items(): + for alias, aliased_expr in aliases.items(): if isinstance(aliased_expr, ColumnNode): if (expr.name == aliased_expr.name and expr.parent_alias == aliased_expr.parent_alias):