diff --git a/core/ast/__init__.py b/core/ast/__init__.py
index 9d54b22..804646a 100644
--- a/core/ast/__init__.py
+++ b/core/ast/__init__.py
@@ -4,7 +4,7 @@
 This module provides the node types and classes for representing SQL query structures.
 """
 
-from .node_type import NodeType
+from .enums import NodeType
 from .node import (
     Node,
     TableNode,
@@ -18,6 +18,8 @@
     SelectNode,
     FromNode,
     WhereNode,
+    JoinNode,
     GroupByNode,
     HavingNode,
+    OrderByItemNode,
     OrderByNode,
@@ -40,9 +42,11 @@
     'SelectNode',
     'FromNode',
     'WhereNode',
+    'JoinNode',
     'GroupByNode',
     'HavingNode',
     'OrderByNode',
+    'OrderByItemNode',
     'LimitNode',
     'OffsetNode',
     'QueryNode'
diff --git a/core/ast/enums.py b/core/ast/enums.py
new file mode 100644
index 0000000..4a07d54
--- /dev/null
+++ b/core/ast/enums.py
@@ -0,0 +1,60 @@
+from enum import Enum
+
+# ============================================================================
+# Node Type Enumeration
+# ============================================================================
+
+class NodeType(Enum):
+    """Node type enumeration"""
+
+    # Operands
+    TABLE = "table"
+    SUBQUERY = "subquery"
+    COLUMN = "column"
+    LITERAL = "literal"
+    # VarSQL specific
+    VAR = "var"
+    VARSET = "varset"
+
+    # Operators
+    OPERATOR = "operator"
+    FUNCTION = "function"
+
+    # Query structure
+    SELECT = "select"
+    FROM = "from"
+    WHERE = "where"
+    JOIN = "join"
+    GROUP_BY = "group_by"
+    HAVING = "having"
+    ORDER_BY = "order_by"
+    ORDER_BY_ITEM = "order_by_item"
+    LIMIT = "limit"
+    OFFSET = "offset"
+    QUERY = "query"
+
+# ============================================================================
+# Join Type Enumeration
+# ============================================================================
+
+class JoinType(Enum):
+    """Join type enumeration"""
+    INNER = "inner"
+    OUTER = "outer"
+    LEFT = "left"
+    RIGHT = "right"
+    FULL = "full"
+    CROSS = "cross"
+    NATURAL = "natural"
+    SEMI = "semi"
+    ANTI = "anti"
+
+
+# ============================================================================
+# Sort Order Enumeration
+# ============================================================================
+
+class SortOrder(Enum):
+    """Sort order enum"""
+    ASC = "ASC"
+    DESC = "DESC"
diff --git a/core/ast/node.py b/core/ast/node.py
index d27584b..befe799 100644
--- a/core/ast/node.py
+++ b/core/ast/node.py
@@ -2,7 +2,7 @@
 from typing import List, Set, Optional
 from abc import ABC
 
-from .node_type import NodeType
+from .enums import NodeType, JoinType, SortOrder
 
 # ============================================================================
 # Base Node Structure
@@ -13,6 +13,38 @@ class Node(ABC):
     def __init__(self, type: NodeType, children: Optional[Set['Node']|List['Node']] = None):
         self.type = type
         self.children = children if children is not None else set()
+
+    def __eq__(self, other):
+        """Return True when `other` is a Node with the same type and equal children.
+
+        A set of children never equals a list of children, even when both
+        hold the same nodes.
+        """
+        if not isinstance(other, Node):
+            return False
+        if self.type != other.type:
+            return False
+        if len(self.children) != len(other.children):
+            return False
+        # Compare children
+        if isinstance(self.children, set) and isinstance(other.children, set):
+            return self.children == other.children
+        elif isinstance(self.children, list) and isinstance(other.children, list):
+            return self.children == other.children
+        else:
+            return False
+
+    def __hash__(self):
+        """Return a hash consistent with __eq__, built from the type and children."""
+        if isinstance(self.children, set):
+            # Hash set children order-independently. Sorting them by str() is
+            # not stable: unless a subclass overrides __str__, the key embeds
+            # the object address, so equal nodes could hash differently.
+            children_hash = hash(frozenset(self.children))
+        else:
+            # For lists, order is significant, so hash the tuple directly
+            children_hash = hash(tuple(self.children))
+        return hash((self.type, children_hash))
 
 
 # ============================================================================
@@ -25,6 +57,16 @@ def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs):
         super().__init__(NodeType.TABLE, **kwargs)
         self.name = _name
         self.alias = _alias
+
+    def __eq__(self, other):
+        if not isinstance(other, TableNode):
+            return False
+        return (super().__eq__(other) and
+                self.name == other.name and
+                self.alias == other.alias)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.name, self.alias))
 
 
 # TODO - including query structure arguments (similar to QueryNode) in constructor.
@@ -43,6 +85,17 @@ def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Opti
         self.alias = _alias
         self.parent_alias = _parent_alias
         self.parent = _parent
+
+    def __eq__(self, other):
+        if not isinstance(other, ColumnNode):
+            return False
+        return (super().__eq__(other) and
+                self.name == other.name and
+                self.alias == other.alias and
+                self.parent_alias == other.parent_alias)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.name, self.alias, self.parent_alias))
 
 
 class LiteralNode(Node):
@@ -51,6 +104,15 @@ def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs):
         super().__init__(NodeType.LITERAL, **kwargs)
         self.value = _value
 
+    def __eq__(self, other):
+        if not isinstance(other, LiteralNode):
+            return False
+        return (super().__eq__(other) and
+                self.value == other.value)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.value))
+
 
 class VarNode(Node):
     """VarSQL variable node"""
@@ -72,16 +134,59 @@ def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwa
         children = [_left, _right] if _right else [_left]
         super().__init__(NodeType.OPERATOR, children=children, **kwargs)
         self.name = _name
+
+    def __eq__(self, other):
+        if not isinstance(other, OperatorNode):
+            return False
+        return (super().__eq__(other) and
+                self.name == other.name)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.name))
 
 
 class FunctionNode(Node):
     """Function call node"""
-    def __init__(self, _name: str, _args: Optional[List[Node]] = None, **kwargs):
+    def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[str] = None, **kwargs):
         if _args is None:
             _args = []
         super().__init__(NodeType.FUNCTION, children=_args, **kwargs)
         self.name = _name
-
+        self.alias = _alias
+
+    def __eq__(self, other):
+        if not isinstance(other, FunctionNode):
+            return False
+        return (super().__eq__(other) and
+                self.name == other.name and
+                self.alias == other.alias)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.name, self.alias))
+
+
+class JoinNode(Node):
+    """JOIN clause node"""
+    def __init__(self, _left_table: 'TableNode', _right_table: 'TableNode', _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs):
+        children = [_left_table, _right_table]
+        if _on_condition is not None:
+            children.append(_on_condition)
+        super().__init__(NodeType.JOIN, children=children, **kwargs)
+        self.left_table = _left_table
+        self.right_table = _right_table
+        self.join_type = _join_type
+        self.on_condition = _on_condition
+
+    def __eq__(self, other):
+        if not isinstance(other, JoinNode):
+            return False
+        # Both tables and the optional ON condition are stored in children,
+        # so the base-class comparison already covers them.
+        return (super().__eq__(other) and
+                self.join_type == other.join_type)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.join_type))
 
 # ============================================================================
 # Query Structure Nodes
@@ -89,20 +194,20 @@ def __init__(self, _name: str, _args: Optional[List[Node]] = None, **kwargs):
 
 class SelectNode(Node):
     """SELECT clause node"""
-    def __init__(self, _items: Set['Node'], **kwargs):
+    def __init__(self, _items: List['Node'], **kwargs):
         super().__init__(NodeType.SELECT, children=_items, **kwargs)
 
 
 # TODO - confine the valid NodeTypes as children of FromNode
 class FromNode(Node):
     """FROM clause node"""
-    def __init__(self, _sources: Set['Node'], **kwargs):
+    def __init__(self, _sources: List['Node'], **kwargs):
         super().__init__(NodeType.FROM, children=_sources, **kwargs)
 
 
 class WhereNode(Node):
     """WHERE clause node"""
-    def __init__(self, _predicates: Set['Node'], **kwargs):
+    def __init__(self, _predicates: List['Node'], **kwargs):
         super().__init__(NodeType.WHERE, children=_predicates, **kwargs)
 
 
@@ -114,13 +219,28 @@ def __init__(self, _items: List['Node'], **kwargs):
 
 class HavingNode(Node):
     """HAVING clause node"""
-    def __init__(self, _predicates: Set['Node'], **kwargs):
+    def __init__(self, _predicates: List['Node'], **kwargs):
         super().__init__(NodeType.HAVING, children=_predicates, **kwargs)
 
 
+class OrderByItemNode(Node):
+    """Single ORDER BY item"""
+    def __init__(self, _column: Node, _sort: SortOrder = SortOrder.ASC, **kwargs):
+        super().__init__(NodeType.ORDER_BY_ITEM, children=[_column], **kwargs)
+        self.sort = _sort
+
+    def __eq__(self, other):
+        if not isinstance(other, OrderByItemNode):
+            return False
+        return (super().__eq__(other) and
+                self.sort == other.sort)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.sort))
+
 class OrderByNode(Node):
     """ORDER BY clause node"""
-    def __init__(self, _items: List['Node'], **kwargs):
+    def __init__(self, _items: List[OrderByItemNode], **kwargs):
         super().__init__(NodeType.ORDER_BY, children=_items, **kwargs)
 
 
@@ -129,6 +249,15 @@ class LimitNode(Node):
     def __init__(self, _limit: int, **kwargs):
         super().__init__(NodeType.LIMIT, **kwargs)
         self.limit = _limit
+
+    def __eq__(self, other):
+        if not isinstance(other, LimitNode):
+            return False
+        return (super().__eq__(other) and
+                self.limit == other.limit)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.limit))
 
 
 class OffsetNode(Node):
@@ -136,6 +265,15 @@ class OffsetNode(Node):
     def __init__(self, _offset: int, **kwargs):
         super().__init__(NodeType.OFFSET, **kwargs)
         self.offset = _offset
+
+    def __eq__(self, other):
+        if not isinstance(other, OffsetNode):
+            return False
+        return (super().__eq__(other) and
+                self.offset == other.offset)
+
+    def __hash__(self):
+        return hash((super().__hash__(), self.offset))
 
 
 class QueryNode(Node):
diff --git a/core/ast/node_type.py b/core/ast/node_type.py
deleted file mode 100644
index 2bae729..0000000
--- a/core/ast/node_type.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from enum import Enum
-
-# ============================================================================
-# Node Type Enumeration
-# ============================================================================
-
-class NodeType(Enum):
-    """Node type enumeration"""
-
-    # Operands
-    TABLE = "table"
-    SUBQUERY = "subquery"
-    COLUMN = "column"
-    LITERAL = "literal"
-    # VarSQL specific
-    VAR = "var"
-    VARSET = "varset"
-
-    # Operators
-    OPERATOR = "operator"
-    FUNCTION = "function"
-
-    # Query structure
-    SELECT = "select"
-    FROM = "from"
-    WHERE = "where"
-    GROUP_BY = "group_by"
-    HAVING = "having"
-    ORDER_BY = "order_by"
-    LIMIT = "limit"
-    OFFSET = "offset"
-    QUERY = "query"
diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py
index 8b176f9..c3e7b61 100644
--- a/tests/test_query_parser.py
+++ b/tests/test_query_parser.py
@@ -5,7 +5,7 @@
     LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
     OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode
 )
-from core.ast.node_type import NodeType
+from core.ast.enums import NodeType, JoinType, SortOrder
 from data.queries import get_query
 
 parser = QueryParser()