diff --git a/examples/eql/cache.md b/examples/eql/cache.md index 78a3810..b05a70f 100644 --- a/examples/eql/cache.md +++ b/examples/eql/cache.md @@ -28,7 +28,8 @@ from dataclasses import dataclass from typing_extensions import List -from krrood.entity_query_language.entity import entity, an, let, contains, Symbol +from krrood.entity_query_language.entity import entity, let, contains, Symbol +from krrood.entity_query_language.quantify_entity import an @dataclass diff --git a/examples/eql/comparators.md b/examples/eql/comparators.md index dabe2fe..f85534f 100644 --- a/examples/eql/comparators.md +++ b/examples/eql/comparators.md @@ -21,9 +21,10 @@ from dataclasses import dataclass from typing_extensions import List from krrood.entity_query_language.entity import ( - entity, an, let, Symbol, + entity, let, Symbol, in_, contains, not_, and_, or_, ) +from krrood.entity_query_language.quantify_entity import an @dataclass diff --git a/examples/eql/domain_mapping.md b/examples/eql/domain_mapping.md index 57b3ca8..a57a217 100644 --- a/examples/eql/domain_mapping.md +++ b/examples/eql/domain_mapping.md @@ -29,12 +29,11 @@ from typing_extensions import List, Dict from krrood.entity_query_language.entity import ( entity, set_of, - an, let, flatten, Symbol, ) - +from krrood.entity_query_language.quantify_entity import an @dataclass class Body(Symbol): diff --git a/examples/eql/eql_for_sql_experts.md b/examples/eql/eql_for_sql_experts.md index 4d8f037..7d0ad62 100644 --- a/examples/eql/eql_for_sql_experts.md +++ b/examples/eql/eql_for_sql_experts.md @@ -55,7 +55,8 @@ from dataclasses import dataclass, field from typing_extensions import List -from krrood.entity_query_language.entity import let, Symbol, entity, an, and_, in_, contains, set_of +from krrood.entity_query_language.entity import let, Symbol, entity, and_, in_, contains, set_of +from krrood.entity_query_language.quantify_entity import an @dataclass diff --git a/examples/eql/intro.md b/examples/eql/intro.md index 83db5ce..708fb7f 100644 --- a/examples/eql/intro.md +++ b/examples/eql/intro.md @@ -28,7 +28,8 @@ from dataclasses import dataclass from typing_extensions import List -from krrood.entity_query_language.entity import entity, an, let, contains, Symbol +from krrood.entity_query_language.entity import entity, let, contains, Symbol +from krrood.entity_query_language.quantify_entity import an @dataclass diff --git a/examples/eql/logical_operators.md b/examples/eql/logical_operators.md index 5261853..f27aefa 100644 --- a/examples/eql/logical_operators.md +++ b/examples/eql/logical_operators.md @@ -21,7 +21,8 @@ from dataclasses import dataclass from typing_extensions import List -from krrood.entity_query_language.entity import entity, an, or_, Symbol, let, not_, and_ +from krrood.entity_query_language.entity import entity, or_, Symbol, let, not_, and_ +from krrood.entity_query_language.quantify_entity import an @dataclass diff --git a/examples/eql/match.md b/examples/eql/match.md index ec9a60f..927d56f 100644 --- a/examples/eql/match.md +++ b/examples/eql/match.md @@ -21,12 +21,17 @@ The following example shows how nested patterns translate into an equivalent manual query built with `entity(...)` and predicates. ```{code-cell} ipython3 +from krrood.entity_query_language.symbol_graph import SymbolGraph from dataclasses import dataclass from typing_extensions import List from krrood.entity_query_language.entity import ( - let, entity, the, - match, entity_matching, Symbol, + let, entity, Symbol, +) +from krrood.entity_query_language.quantify_entity import the, an +from krrood.entity_query_language.match import ( + match, + entity_matching, ) from krrood.entity_query_language.predicate import HasType @@ -61,7 +66,19 @@ class FixedConnection(Connection): @dataclass class World: connections: List[Connection] + +@dataclass +class Drawer(Symbol): + handle: Handle + container: Container + + +@dataclass +class Cabinet(Symbol): + container: Container + drawers: List[Drawer] +SymbolGraph() # Build a small world with a few connections c1 = Container("Container1") @@ -123,3 +140,131 @@ Notes: - Use `entity_matching` for the outer pattern when a domain is involved; inner attributes use `match`. - Nested `match(...)` can be composed arbitrarily deep following your object graph. - `entity_matching` is a syntactic sugar over the explicit `entity` + predicates form, so both are interchangeable. + +## Selecting inner objects with `select()` + +Use `select(Type)` when you want the matched inner objects to appear in the result. The evaluation then +returns a mapping from the selected variables to the concrete objects (a unification dictionary). + +```{code-cell} ipython3 +from krrood.entity_query_language.match import select + +container, handle = select(Container), select(Handle) +fixed_connection_query = the( + entity_matching(FixedConnection, world.connections)( + parent=container(name="Container1"), + child=handle(name="Handle1"), + ) +) + +answers = fixed_connection_query.evaluate() +print(answers[container].name, answers[handle].name) +``` + +## Existential matches in collections with `match_any()` + +When matching a container-like attribute (for example, a list), use `match_any(pattern)` to express that +at least one element of the collection should satisfy the given pattern. + +Below we add two simple view classes and build a small scene of drawers and a cabinet. + +```{code-cell} ipython3 +from krrood.entity_query_language.match import match_any + +# Build a simple set of views +drawer1 = Drawer(handle=h1, container=c1) +drawer2 = Drawer(handle=Handle("OtherHandle"), container=other_c) +cabinet1 = Cabinet(container=c1, drawers=[drawer1, drawer2]) +cabinet2 = Cabinet(container=other_c, drawers=[drawer2]) +views = [drawer1, drawer2, cabinet1, cabinet2] + +# Query: find the cabinet that has any drawer from the set {drawer1, drawer2} +cabinet_query = an(entity_matching(Cabinet, views)(drawers=match_any([drawer1, drawer2]))) + +found_cabinets = list(cabinet_query.evaluate()) +assert len(found_cabinets) == 2 +print(found_cabinets[0].container.name, found_cabinets[0].drawers[0].handle.name) +print(found_cabinets[1].container.name, found_cabinets[1].drawers[0].handle.name) +``` + +## Selecting elements from collections with `select_any()` + +If you want to retrieve a specific element from a collection attribute while matching, use `select_any(Type)`. +It behaves like `match_any(Type)` but also selects the matched element so you can access it in the result. + +```{code-cell} ipython3 +from krrood.entity_query_language.match import select_any + +selected_drawers = select_any([drawer1, drawer2]) +# Query: find the cabinet that has any drawer from the set {drawer1, drawer2} +cabinet_query = an(entity_matching(Cabinet, views)(drawers=selected_drawers)) + +ans = list(cabinet_query.evaluate()) +assert len(ans) == 2 +print(ans) +``` + +## Selecting inner objects with `select()` + +Use `select(Type)` when you want the matched inner objects to appear in the result. The evaluation then +returns a mapping from the selected variables to the concrete objects (a unification dictionary). + +```{code-cell} ipython3 +from krrood.entity_query_language.match import select + +container, handle = select(Container), select(Handle) +fixed_connection_query = the( + entity_matching(FixedConnection, world.connections)( + parent=container(name="Container1"), + child=handle(name="Handle1"), + ) +) + +answers = fixed_connection_query.evaluate() +print(answers[container].name, answers[handle].name) +``` + +## Existential matches in collections with `match_any()` + +When having multiple possible matches, and you care only if at least the attribute matches one possibility, use +`match_any(IterableOfPossibleValues)` to express that +at least one element of the collection should satisfy the given pattern. + +Below we add two simple view classes and build a small scene of drawers and a cabinet. + +```{code-cell} ipython3 +from krrood.entity_query_language.match import match_any + +# Build a simple set of views +drawer1 = Drawer(handle=h1, container=c1) +drawer2 = Drawer(handle=Handle("OtherHandle"), container=other_c) +cabinet1 = Cabinet(container=c1, drawers=[drawer1, drawer2]) +cabinet2 = Cabinet(container=other_c, drawers=[drawer2]) +views = [drawer1, drawer2, cabinet1, cabinet2] + +# Query: find the cabinet that has any drawer from the set {drawer1, drawer2} +cabinet_query = an(entity_matching(Cabinet, views)(drawers=match_any([drawer1, drawer2]))) + +found_cabinets = list(cabinet_query.evaluate()) +assert len(found_cabinets) == 2 +print(found_cabinets[0].container.name, found_cabinets[0].drawers[0].handle.name) +print(found_cabinets[1].container.name, found_cabinets[1].drawers[0].handle.name) +``` + +## Selecting elements from collections with `select_any()` + +If you want to retrieve a specific element from a collection attribute while matching, use `select_any(Type)`. +It behaves like `match_any(Type)` but also selects the matched element so you can access it in the result. + +```{code-cell} ipython3 +from krrood.entity_query_language.match import select_any, entity_selection + +selected_drawers = select_any([drawer1, drawer2]) +# Query: find the cabinet that has any drawer from the set {drawer1, drawer2} +cabinet = entity_selection(Cabinet, views) +cabinet_query = an(cabinet(drawers=selected_drawers)) + +ans = list(cabinet_query.evaluate()) +assert len(ans) == 2 +print(ans) +``` diff --git a/examples/eql/predicate_and_symbolic_function.md b/examples/eql/predicate_and_symbolic_function.md index 98f697e..4030147 100644 --- a/examples/eql/predicate_and_symbolic_function.md +++ b/examples/eql/predicate_and_symbolic_function.md @@ -26,8 +26,9 @@ Lets first define our model and some sample data. from dataclasses import dataclass from typing_extensions import List -from krrood.entity_query_language.entity import entity, let, an, Symbol +from krrood.entity_query_language.entity import entity, let, Symbol from krrood.entity_query_language.predicate import Predicate, symbolic_function +from krrood.entity_query_language.quantify_entity import an @dataclass diff --git a/examples/eql/result_quantifiers.md b/examples/eql/result_quantifiers.md index b360b6a..3b74f4a 100644 --- a/examples/eql/result_quantifiers.md +++ b/examples/eql/result_quantifiers.md @@ -27,7 +27,8 @@ from dataclasses import dataclass from typing_extensions import List -from krrood.entity_query_language.entity import entity, let, the, Symbol, an +from krrood.entity_query_language.entity import entity, let, Symbol +from krrood.entity_query_language.quantify_entity import an, the from krrood.entity_query_language.result_quantification_constraint import AtLeast, AtMost, Exactly, Range from krrood.entity_query_language.failures import MultipleSolutionFound, LessThanExpectedNumberOfSolutions, GreaterThanExpectedNumberOfSolutions @@ -81,7 +82,7 @@ Below we reuse the same `World` and `Body` setup from above. The world contains exactly two bodies, so all the following examples will evaluate successfully. ```{code-cell} ipython3 -# Require at least two results +# Require at least one result query = an( entity(body := let(Body, domain=world.bodies)), quantification=AtLeast(1), diff --git a/examples/eql/writing_queries.md b/examples/eql/writing_queries.md index 14fc754..2613cd2 100644 --- a/examples/eql/writing_queries.md +++ b/examples/eql/writing_queries.md @@ -35,14 +35,15 @@ This approach ensures that your class definitions remain pure and decoupled from outside the explicit symbolic context. Consequently, your classes can focus exclusively on their domain logic, leading to better adherence to the [Single Responsibility Principle](https://realpython.com/solid-principles-python/#single-responsibility-principle-srp). -Here is a query that does work due to the missing `let` statement: +Here is a query example that finds all bodies in a world whose name starts with "B": ```{code-cell} ipython3 from dataclasses import dataclass from typing_extensions import List -from krrood.entity_query_language.entity import entity, an, let, Symbol +from krrood.entity_query_language.entity import entity, let, Symbol +from krrood.entity_query_language.quantify_entity import an @dataclass diff --git a/examples/eql/writing_rule_trees.md b/examples/eql/writing_rule_trees.md index b5da65b..b258b33 100644 --- a/examples/eql/writing_rule_trees.md +++ b/examples/eql/writing_rule_trees.md @@ -24,7 +24,8 @@ Lets define our domain model and build a small world. We will then build a rule instances to the world. ```{code-cell} ipython3 -from krrood.entity_query_language.entity import entity, an, let, and_, Symbol, inference +from krrood.entity_query_language.entity import entity, let, and_, Symbol, inference +from krrood.entity_query_language.quantify_entity import an from krrood.entity_query_language.rule import refinement, alternative from krrood.entity_query_language.conclusion import Add diff --git a/scripts/test_documentation.sh b/scripts/test_documentation.sh old mode 100644 new mode 100755 diff --git a/src/krrood/class_diagrams/class_diagram.py b/src/krrood/class_diagrams/class_diagram.py index b425d7c..3fef988 100644 --- a/src/krrood/class_diagrams/class_diagram.py +++ b/src/krrood/class_diagrams/class_diagram.py @@ -471,7 +471,7 @@ def get_wrapped_class(self, clazz: Type) -> Optional[WrappedClass]: except KeyError: raise ClassIsUnMappedInClassDiagram(clazz) - def add_node(self, clazz: WrappedClass): + def add_node(self, clazz: Union[Type, WrappedClass]): """ Adds a new node to the dependency graph for the specified wrapped class. @@ -481,6 +481,12 @@ class to the wrapped class. :param clazz: The wrapped class object to be added to the dependency graph. """ + try: + clazz = self.get_wrapped_class(clazz) + except ClassIsUnMappedInClassDiagram: + clazz = WrappedClass(clazz) + if clazz.index is not None: + return clazz.index = self._dependency_graph.add_node(clazz) clazz._class_diagram = self self._cls_wrapped_cls_map[clazz.clazz] = clazz diff --git a/src/krrood/entity_query_language/entity.py b/src/krrood/entity_query_language/entity.py index 852f052..d7c192b 100644 --- a/src/krrood/entity_query_language/entity.py +++ b/src/krrood/entity_query_language/entity.py @@ -1,9 +1,5 @@ from __future__ import annotations -from dataclasses import dataclass, field -from functools import cached_property - - from .hashed_data import T from .symbol_graph import SymbolGraph from .utils import is_iterable @@ -18,21 +14,17 @@ Optional, Union, Iterable, - Dict, - Generic, Type, Tuple, List, Callable, - TypeVar, + TYPE_CHECKING, ) from .symbolic import ( SymbolicExpression, Entity, SetOf, - The, - An, AND, Comparator, chained_logic, @@ -44,79 +36,22 @@ ForAll, Exists, Literal, - ResultQuantifier, ) -from .result_quantification_constraint import ResultQuantificationConstraint from .predicate import ( Predicate, # type: ignore Symbol, # type: ignore - HasType, ) +if TYPE_CHECKING: + pass + ConditionType = Union[SymbolicExpression, bool, Predicate] """ The possible types for conditions. """ -EntityType = Union[ - SetOf[T], Entity[T], T, Iterable[T], Type[T] -] # include Match[T] after moving match to a module @bass -""" -The possible types for entities. -""" - - -def an( - entity_: EntityType, - quantification: Optional[ResultQuantificationConstraint] = None, -) -> Union[An[T], T]: - """ - Select a single element satisfying the given entity description. - - :param entity_: An entity or a set expression to quantify over. - :param quantification: Optional quantification constraint. - :return: A quantifier representing "an" element. - :rtype: An[T] - """ - return _quantify_entity(An, entity_, _quantification_constraint_=quantification) - - -a = an -""" -This is an alias to accommodate for words not starting with vowels. -""" - - -def the( - entity_: EntityType, -) -> Union[The[T], T]: - """ - Select the unique element satisfying the given entity description. - - :param entity_: An entity or a set expression to quantify over. - :return: A quantifier representing "an" element. - :rtype: The[T] - """ - return _quantify_entity(The, entity_) - - -def _quantify_entity( - quantifier: Type[ResultQuantifier], entity_: EntityType, **quantifier_kwargs -) -> Union[ResultQuantifier[T], T]: - """ - Apply the given quantifier to the given entity. - - :param quantifier: The quantifier to apply. - :param entity_: The entity to quantify. - :param quantifier_kwargs: Keyword arguments to pass to the quantifier. - :return: The quantified entity. - """ - if isinstance(entity_, Match): - entity_ = entity_.expression - return quantifier(entity_, **quantifier_kwargs) - def entity( selected_variable: T, @@ -267,7 +202,9 @@ def not_(operand: SymbolicExpression): return operand._invert_() -def contains(container: Union[Iterable, CanBehaveLikeAVariable[T]], item: Any): +def contains( + container: Union[Iterable, CanBehaveLikeAVariable[T]], item: Any +) -> Comparator: """ Check whether a container contains an item. @@ -345,111 +282,3 @@ def inference( return lambda **kwargs: Variable( _type_=type_, _name__=type_.__name__, _kwargs_=kwargs, _is_inferred_=True ) - - -@dataclass -class Match(Generic[T]): - """ - Construct a query that looks for the pattern provided by the type and the keyword arguments. - """ - - type_: Type[T] - """ - The type of the variable. - """ - kwargs: Dict[str, Any] - """ - The keyword arguments to match against. - """ - variable: CanBehaveLikeAVariable[T] = field(init=False) - """ - The created variable from the type and kwargs. - """ - conditions: List[ConditionType] = field(init=False, default_factory=list) - """ - The conditions that define the match. - """ - - def _resolve(self, variable: Optional[CanBehaveLikeAVariable] = None): - """ - Resolve the match by creating the variable and conditions expressions. - - :param variable: An optional pre-existing variable to use for the match; if not provided, a new variable will be created. - :return: - """ - self.variable = variable if variable else self._create_variable() - for k, v in self.kwargs.items(): - attr = getattr(self.variable, k) - if isinstance(v, Match): - v._resolve(attr) - self.conditions.append(HasType(attr, v.type_)) - self.conditions.extend(v.conditions) - else: - self.conditions.append(attr == v) - - def _create_variable(self) -> Variable[T]: - """ - Create a variable with the given type. - """ - return let(self.type_, None) - - @cached_property - def expression(self) -> Entity[T]: - """ - Return the entity expression corresponding to the match query. - """ - self._resolve() - return entity(self.variable, *self.conditions) - - -@dataclass -class MatchEntity(Match[T]): - """ - A match that can also take a domain and should be used as the outermost match in a nested match statement. - This is because the inner match statements derive their domain from the outer match as they are basically attributes - of the outer match variable. - """ - - domain: DomainType - """ - The domain to use for the variable created by the match. - """ - - def _create_variable(self) -> Variable[T]: - """ - Create a variable with the given type and domain. - """ - return let(self.type_, self.domain) - - -def match(type_: Type[T]) -> Union[Type[T], Callable[..., Match[T]]]: - """ - This returns a factory function that creates a Match instance that looks for the pattern provided by the type and the - keyword arguments. - - :param type_: The type of the variable (i.e., The class you want to instantiate). - :return: The factory function for creating the match query. - """ - - def match_factory(**kwargs) -> Match[T]: - return Match(type_, kwargs) - - return match_factory - - -def entity_matching( - type_: Type[T], domain: DomainType -) -> Union[Type[T], Callable[..., MatchEntity[T]]]: - """ - Same as :py:func:`krrood.entity_query_language.entity.match` but with a domain to use for the variable created - by the match. - - :param type_: The type of the variable (i.e., The class you want to instantiate). - :param domain: The domain used for the variable created by the match. - :return: The factory function for creating the match query. - """ - - def match_factory(**kwargs) -> MatchEntity[T]: - return MatchEntity(type_, kwargs, domain) - - return match_factory diff --git a/src/krrood/entity_query_language/failures.py b/src/krrood/entity_query_language/failures.py index 4cbb95b..c7271c8 100644 --- a/src/krrood/entity_query_language/failures.py +++ b/src/krrood/entity_query_language/failures.py @@ -120,9 +120,7 @@ class UnSupportedOperand(UnsupportedOperation): """ def __post_init__(self): - self.message = ( - f"{self.unsupported_operand} cannot be used as an operand for {self.operation} operations." - ) + self.message = f"{self.unsupported_operand} cannot be used as an operand for {self.operation} operations." super().__post_init__() @@ -191,3 +189,24 @@ def __post_init__(self): f"e.g. Entity, or SetOf" ) super().__post_init__() + + +@dataclass +class ClassDiagramError(DataclassException): + """ + An error related to the class diagram. + """ + + +@dataclass +class NoneWrappedFieldError(ClassDiagramError): + """ + Raised when a field of a class is not wrapped by a WrappedField. + """ + + clazz: Type + attr_name: str + + def __post_init__(self): + self.message = f"Field '{self.attr_name}' of class '{self.clazz.__name__}' is not wrapped by a WrappedField." + super().__post_init__() diff --git a/src/krrood/entity_query_language/match.py b/src/krrood/entity_query_language/match.py new file mode 100644 index 0000000..67088eb --- /dev/null +++ b/src/krrood/entity_query_language/match.py @@ -0,0 +1,465 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from functools import cached_property +from typing import Generic, Optional, Type, Dict, Any, List, Union, Self, Iterable + +from krrood.entity_query_language.symbolic import Exists + +from .entity import ( + ConditionType, + contains, + in_, + flatten, + let, + set_of, + entity, + DomainType, + exists, +) +from .failures import NoneWrappedFieldError +from .hashed_data import T, HashedValue +from .predicate import HasType +from .symbolic import ( + CanBehaveLikeAVariable, + Attribute, + Comparator, + Flatten, + QueryObjectDescriptor, + Selectable, + SymbolicExpression, + OperationResult, + Literal, + SetOf, + Entity, +) +from .utils import is_iterable + + +@dataclass +class Match(Generic[T]): + """ + Construct a query that looks for the pattern provided by the type and the keyword arguments. + Example usage where we look for an object of type Drawer with body of type Body that has the name"drawer_1": + >>> @dataclass + >>> class Body: + >>> name: str + >>> @dataclass + >>> class Drawer: + >>> body: Body + >>> drawer = match(Drawer)(body=match(Body)(name="drawer_1")) + """ + + type_: Optional[Type[T]] = None + """ + The type of the variable. + """ + domain: DomainType = field(default=None, kw_only=True) + """ + The domain to use for the variable created by the match. + """ + kwargs: Dict[str, Any] = field(init=False, default_factory=dict) + """ + The keyword arguments to match against. + """ + variable: Optional[CanBehaveLikeAVariable[T]] = field(kw_only=True, default=None) + """ + The created variable from the type and kwargs. + """ + conditions: List[ConditionType] = field(init=False, default_factory=list) + """ + The conditions that define the match. + """ + selected_variables: List[CanBehaveLikeAVariable] = field( + init=False, default_factory=list + ) + """ + A list of selected attributes. + """ + parent: Optional[Match] = field(init=False, default=None) + """ + The parent match if this is a nested match. + """ + is_selected: bool = field(default=False, kw_only=True) + """ + Whether the variable should be selected in the result. + """ + existential: bool = field(default=False, kw_only=True) + """ + Whether the match is an existential match check or not. + """ + universal: bool = field(default=False, kw_only=True) + """ + Whether the match is a universal match (i.e., must match for all values of the variable/attribute) check or not. + """ + + def __call__(self, **kwargs) -> Union[Self, T, CanBehaveLikeAVariable[T]]: + """ + Update the match with new keyword arguments to constrain the type we are matching with. + + :param kwargs: The keyword arguments to match against. + :return: The current match instance after updating it with the new keyword arguments. + """ + self.kwargs = kwargs + return self + + def _resolve( + self, + variable: Optional[CanBehaveLikeAVariable] = None, + parent: Optional[Match] = None, + ): + """ + Resolve the match by creating the variable and conditions expressions. + + :param variable: An optional pre-existing variable to use for the match; if not provided, a new variable will + be created. + :param parent: The parent match if this is a nested match. + :return: + """ + self._update_fields(variable, parent) + for attr_name, attr_assigned_value in self.kwargs.items(): + attr_assignment = AttributeAssignment( + attr_name, self.variable, attr_assigned_value + ) + if isinstance(attr_assigned_value, Select): + self._update_selected_variables(attr_assignment.attr) + attr_assigned_value._var_ = attr_assignment.attr + if attr_assignment.is_an_unresolved_match: + attr_assignment.resolve(self) + self.conditions.extend(attr_assignment.conditions) + else: + condition = ( + attr_assignment.infer_condition_between_attribute_and_assigned_value() + ) + self.conditions.append(condition) + + def _update_fields( + self, + variable: Optional[CanBehaveLikeAVariable] = None, + parent: Optional[Match] = None, + ): + """ + Update the match variable, parent, is_selected, and type_ fields. + + :param variable: The variable to use for the match. + If None, a new variable will be created. + :param parent: The parent match if this is a nested match. + """ + + if variable is not None: + self.variable = variable + elif self.variable is None: + self.variable = let(self.type_, self.domain) + + self.parent = parent + + if self.is_selected: + self._update_selected_variables(self.variable) + + if not self.type_: + self.type_ = self.variable._type_ + + def _update_selected_variables(self, variable: CanBehaveLikeAVariable): + """ + Update the selected variables of the match by adding the given variable to the root Match selected variables. + """ + if self.parent: + self.parent._update_selected_variables(variable) + elif hash(variable) not in map(hash, self.selected_variables): + self.selected_variables.append(variable) + + @cached_property + def expression(self) -> QueryObjectDescriptor[T]: + """ + Return the entity expression corresponding to the match query. + """ + self._resolve() + if len(self.selected_variables) > 1: + return set_of(self.selected_variables, *self.conditions) + else: + if not self.selected_variables: + self.selected_variables.append(self.variable) + return entity(self.selected_variables[0], *self.conditions) + + +@dataclass +class AttributeAssignment: + """ + A class representing an attribute assignment in a Match statement. + """ + + attr_name: str + """ + The name of the attribute to assign the value to. + """ + variable: CanBehaveLikeAVariable + """ + The variable whose attribute is being assigned. + """ + assigned_value: Union[Literal, Match] + """ + The value to assign to the attribute, which can be a Match instance or a Literal. + """ + conditions: List[ConditionType] = field(init=False, default_factory=list) + """ + The conditions that define attribute assignment. + """ + + def resolve(self, parent_match: Match): + """ + Resolve the attribute assignment by creating the conditions and applying the necessary mappings + to the attribute. + + :param parent_match: The parent match of the attribute assignment. + """ + possibly_flattened_attr = self.attr + if self.attr._is_iterable_ and ( + self.assigned_value.kwargs or self.is_type_filter_needed + ): + possibly_flattened_attr = flatten(self.attr) + + self.assigned_value._resolve(possibly_flattened_attr, parent_match) + + if self.is_type_filter_needed: + self.conditions.append( + HasType(possibly_flattened_attr, self.assigned_value.type_) + ) + + self.conditions.extend(self.assigned_value.conditions) + + def infer_condition_between_attribute_and_assigned_value( + self, + ) -> Union[Comparator, Exists]: + """ + Find and return the appropriate condition for the attribute and its assigned value. This can be one of contains, + in_, or == depending on the type of the assigned value and the type of the attribute. In addition, if the + assigned value is a Match instance with an existential flag set, an Exists expression is created over the + comparator condition. + + :return: A Comparator or an Exists expression representing the condition. + """ + if self.attr._is_iterable_ and not self.is_iterable_value: + condition = contains(self.attr, self.assigned_variable) + elif not self.attr._is_iterable_ and self.is_iterable_value: + condition = in_(self.attr, self.assigned_variable) + elif ( + self.attr._is_iterable_ + and self.is_iterable_value + and not ( + isinstance(self.assigned_value, Match) and self.assigned_value.universal + ) + ): + condition = contains(self.assigned_variable, flatten(self.attr)) + else: + condition = self.attr == self.assigned_variable + + if isinstance(self.assigned_value, Match) and self.assigned_value.existential: + condition = exists(self.attr, condition) + + return condition + + @cached_property + def assigned_variable(self) -> CanBehaveLikeAVariable: + """ + :return: The symbolic variable representing the assigned value. + """ + return ( + self.assigned_value.variable + if isinstance(self.assigned_value, Match) + else self.assigned_value + ) + + @cached_property + def attr(self) -> Attribute: + """ + :return: the attribute of the variable. + :raises NoneWrappedFieldError: If the attribute does not have a WrappedField. + """ + attr: Attribute = getattr(self.variable, self.attr_name) + if not attr._wrapped_field_: + raise NoneWrappedFieldError(self.variable._type_, self.attr_name) + return attr + + @property + def is_an_unresolved_match(self) -> bool: + """ + :return: True if the value is an unresolved Match instance, else False. + """ + return ( + isinstance(self.assigned_value, Match) and not self.assigned_value.variable + ) + + @cached_property + def is_iterable_value(self) -> bool: + """ + :return: True if the value is an iterable or a Match instance with an iterable type, else False. + """ + if isinstance(self.assigned_value, CanBehaveLikeAVariable): + return self.assigned_value._is_iterable_ + elif not isinstance(self.assigned_value, Match) and is_iterable( + self.assigned_value + ): + return True + elif ( + isinstance(self.assigned_value, Match) + and self.assigned_value.variable._is_iterable_ + ): + return True + return False + + @cached_property + def is_type_filter_needed(self): + """ + :return: True if a type filter condition is needed for the attribute assignment, else False. + """ + attr_type = self.attr._type_ + return (not attr_type) or ( + (self.assigned_value.type_ and self.assigned_value.type_ is not attr_type) + and issubclass(self.assigned_value.type_, attr_type) + ) + + +@dataclass +class Select(Match[T], Selectable[T]): + """ + This is a Match with the addition that the matched entity is selected in the result. + """ + + _var_: CanBehaveLikeAVariable[T] = field(init=False) + is_selected: bool = field(init=False, default=True) + + def __post_init__(self): + """ + This is needed to prevent the SymbolicExpression __post_init__ from being called which will make a node out of + this instance, and that is not what we want. + """ + ... + + def _resolve( + self, + variable: Optional[CanBehaveLikeAVariable] = None, + parent: Optional[Match] = None, + ): + super()._resolve(variable, parent) + variable = variable or self.variable + if not self._var_: + self._var_ = variable + + def _evaluate__( + self, + sources: Optional[Dict[int, HashedValue]] = None, + parent: Optional[SymbolicExpression] = None, + ) -> Iterable[OperationResult]: + yield from self.variable._evaluate__(sources, parent) + + @property + def _name_(self) -> str: + return self._var_._name_ + + @cached_property + def _all_variable_instances_(self) -> List[CanBehaveLikeAVariable[T]]: + return self._var_._all_variable_instances_ + + +def match( + type_: Union[Type[T], CanBehaveLikeAVariable[T], Any, None] = None, +) -> Union[Type[T], CanBehaveLikeAVariable[T], Match[T]]: + """ + Create and return a Match instance that looks for the pattern provided by the type and the + keyword arguments. + + :param type_: The type of the variable (i.e., The class you want to instantiate). + :return: The Match instance. + """ + return entity_matching(type_, None) + + +def match_any( + type_: Union[Type[T], CanBehaveLikeAVariable[T], Any, None] = None, +) -> Union[Type[T], CanBehaveLikeAVariable[T], Match[T]]: + """ + Equivalent to match(type_) but for existential checks. + """ + match_ = match(type_) + match_.existential = True + return match_ + + +def match_all( + type_: Union[Type[T], CanBehaveLikeAVariable[T], Any, None] = None, +) -> Union[Type[T], CanBehaveLikeAVariable[T], Match[T]]: + """ + Equivalent to match(type_) but for universal checks. + """ + match_ = match(type_) + match_.universal = True + return match_ + + +def select( + type_: Union[Type[T], CanBehaveLikeAVariable[T], Any, None] = None, +) -> Union[Type[T], CanBehaveLikeAVariable[T], Select[T]]: + """ + Equivalent to match(type_) and selecting the variable to be included in the result. + """ + return entity_selection(type_, None) + + +def select_any( + type_: Union[Type[T], CanBehaveLikeAVariable[T], Any, None] = None, +) -> Union[Type[T], CanBehaveLikeAVariable[T], Select[T]]: + """ + Equivalent to select(type_) but for existential checks. + """ + select_ = select(type_) + select_.existential = True + return select_ + + +def select_all( + type_: Union[Type[T], CanBehaveLikeAVariable[T], Any, None] = None, +) -> Union[Type[T], CanBehaveLikeAVariable[T], Select[T]]: + """ + Equivalent to select(type_) but for universal checks. + """ + select_ = select(type_) + select_.universal = True + return select_ + + +def entity_matching( + type_: Union[Type[T], CanBehaveLikeAVariable[T]], domain: DomainType +) -> Union[Type[T], CanBehaveLikeAVariable[T], Match[T]]: + """ + Same as :py:func:`krrood.entity_query_language.match.match` but with a domain to use for the variable created + by the match. + + :param type_: The type of the variable (i.e., The class you want to instantiate). + :param domain: The domain used for the variable created by the match. + :return: The MatchEntity instance. + """ + if isinstance(type_, CanBehaveLikeAVariable): + return Match(type_._type_, domain=domain, variable=type_) + elif type_ and not isinstance(type_, type): + return Match(type_, domain=domain, variable=Literal(type_)) + return Match(type_, domain=domain) + + +def entity_selection( + type_: Union[Type[T], CanBehaveLikeAVariable[T]], domain: DomainType +) -> Union[Type[T], CanBehaveLikeAVariable[T], Select[T]]: + """ + Same as :py:func:`krrood.entity_query_language.match.entity_matching` but also selecting the variable to be + included in the result. + """ + if isinstance(type_, CanBehaveLikeAVariable): + return Select(type_._type_, domain=domain, variable=type_) + elif type_ and not isinstance(type_, type): + return Select(type_, domain=domain, variable=Literal(type_)) + return Select(type_, domain=domain) + + +EntityType = Union[SetOf[T], Entity[T], T, Iterable[T], Type[T], Match[T]] +""" +The possible types for entities. +""" diff --git a/src/krrood/entity_query_language/predicate.py b/src/krrood/entity_query_language/predicate.py index 7a70c3e..70feafb 100644 --- a/src/krrood/entity_query_language/predicate.py +++ b/src/krrood/entity_query_language/predicate.py @@ -68,7 +68,7 @@ def wrapper(*args, **kwargs) -> Optional[Any]: @dataclass(eq=False) class Symbol: - """Base class for things that can be described by property descriptors.""" + """Base class for things that can be cached in the symbol graph.""" def __new__(cls, *args, **kwargs): instance = super().__new__(cls) @@ -145,49 +145,6 @@ class HasTypes(HasType): """ -def extract_selected_variable_and_expression( - symbolic_cls: Type, - domain: Optional[From] = None, - predicate_type: Optional[PredicateType] = None, - **kwargs, -): - """ - Extracts a variable and constructs its expression tree for the given symbolic class. - - This function generates a variable of the specified `symbolic_cls` and uses the - provided domain, predicate type, and additional arguments to create its expression - tree. The domain can optionally be filtered when iterating through its elements - if specified or retrieved from the cache keys associated with the symbolic class. - - :param symbolic_cls: The symbolic class type to be used for variable creation. - :param domain: Optional domain to provide constraints for the variable. - :param predicate_type: Optional predicate type associated with the variable. - :param kwargs: Additional properties to define and construct the variable. - :return: A tuple containing the generated variable and its corresponding expression tree. - """ - cache_keys = [symbolic_cls] + recursive_subclasses(symbolic_cls) - if not domain and cache_keys: - domain = From( - ( - instance - for instance in SymbolGraph()._class_to_wrapped_instances[symbolic_cls] - ) - ) - elif domain and is_iterable(domain.domain): - domain.domain = filter(lambda v: isinstance(v, symbolic_cls), domain.domain) - - var = Variable( - _name__=symbolic_cls.__name__, - _type_=symbolic_cls, - _domain_source_=domain, - _predicate_type_=predicate_type, - ) - - expression = properties_to_expression_tree(var, kwargs) - - return var, expression - - def update_cache(instance: Symbol): """ Updates the cache with the given instance of a symbolic type. diff --git a/src/krrood/entity_query_language/quantify_entity.py b/src/krrood/entity_query_language/quantify_entity.py new file mode 100644 index 0000000..a464848 --- /dev/null +++ b/src/krrood/entity_query_language/quantify_entity.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Optional, Union, Type + +from .hashed_data import T +from .match import EntityType, Match +from .result_quantification_constraint import ( + ResultQuantificationConstraint, +) +from .symbolic import An, The, ResultQuantifier + + +def an( + entity_: EntityType, + quantification: Optional[ResultQuantificationConstraint] = None, +) -> Union[An[T], T]: + """ + Select a single element satisfying the given entity description. + + :param entity_: An entity or a set expression to quantify over. + :param quantification: Optional quantification constraint. + :return: A quantifier representing "an" element. + :rtype: An[T] + """ + return _quantify_entity(An, entity_, _quantification_constraint_=quantification) + + +a = an +""" +This is an alias to accommodate for words not starting with vowels. +""" + + +def the( + entity_: EntityType, +) -> Union[The[T], T]: + """ + Select the unique element satisfying the given entity description. + + :param entity_: An entity or a set expression to quantify over. + :return: A quantifier representing "an" element. + :rtype: The[T] + """ + return _quantify_entity(The, entity_) + + +def _quantify_entity( + quantifier: Type[ResultQuantifier], entity_: EntityType, **quantifier_kwargs +) -> Union[ResultQuantifier[T], T]: + """ + Apply the given quantifier to the given entity. + + :param quantifier: The quantifier to apply. + :param entity_: The entity to quantify. + :param quantifier_kwargs: Keyword arguments to pass to the quantifier. + :return: The quantified entity. + """ + if isinstance(entity_, Match) and not entity_.variable: + entity_ = entity_.expression + return quantifier(entity_, **quantifier_kwargs) diff --git a/src/krrood/entity_query_language/result_quantification_constraint.py b/src/krrood/entity_query_language/result_quantification_constraint.py index 3f4f865..59fc9f4 100644 --- a/src/krrood/entity_query_language/result_quantification_constraint.py +++ b/src/krrood/entity_query_language/result_quantification_constraint.py @@ -12,7 +12,7 @@ ) if TYPE_CHECKING: - from .symbolic import An, ResultQuantifier + from .symbolic import ResultQuantifier @dataclass diff --git a/src/krrood/entity_query_language/symbol_graph.py b/src/krrood/entity_query_language/symbol_graph.py index 4f71eba..211e39f 100644 --- a/src/krrood/entity_query_language/symbol_graph.py +++ b/src/krrood/entity_query_language/symbol_graph.py @@ -1,5 +1,6 @@ from __future__ import annotations +from copy import copy import os import weakref from collections import defaultdict @@ -228,7 +229,7 @@ def remove_node(self, wrapped_instance: WrappedInstance): """ self._instance_index.pop(id(wrapped_instance.instance), None) self._class_to_wrapped_instances[wrapped_instance.instance_type].remove( - wrapped_instance, + wrapped_instance ) self._instance_graph.remove_node(wrapped_instance.index) diff --git a/src/krrood/entity_query_language/symbolic.py b/src/krrood/entity_query_language/symbolic.py index 8d947a7..aac3e55 100644 --- a/src/krrood/entity_query_language/symbolic.py +++ b/src/krrood/entity_query_language/symbolic.py @@ -52,7 +52,7 @@ ) from .rxnode import RWXNode, ColorLegend from .symbol_graph import SymbolGraph -from .utils import IDGenerator, is_iterable, generate_combinations +from .utils import IDGenerator, is_iterable, generate_combinations, make_list, make_set from ..class_diagrams import ClassRelation from ..class_diagrams.class_diagram import Association, WrappedClass from ..class_diagrams.failures import ClassIsUnMappedInClassDiagram @@ -111,6 +111,10 @@ class OperationResult: def is_true(self): return not self.is_false + @property + def value(self) -> Optional[HashedValue]: + return self.bindings.get(self.operand._id_, None) + def __contains__(self, item): return item in self.bindings @@ -374,13 +378,9 @@ def __repr__(self): @dataclass(eq=False, repr=False) -class CanBehaveLikeAVariable(SymbolicExpression[T], ABC): - """ - This class adds the monitoring/tracking behaviour on variables that tracks attribute access, calling, - and comparison operations. - """ +class Selectable(SymbolicExpression[T], ABC): - _var_: CanBehaveLikeAVariable[T] = field(init=False, default=None) + _var_: CanBehaveLikeAVariable = field(init=False, default=None) """ A variable that is used if the child class to this class want to provide a variable to be tracked other than itself, this is specially useful for child classes that holds a variable instead of being a variable and want @@ -388,12 +388,32 @@ class CanBehaveLikeAVariable(SymbolicExpression[T], ABC): For example, this is the case for the ResultQuantifiers & QueryDescriptors that operate on a single selected variable. """ + + @property + def _is_iterable_(self): + """ + Whether the selectable is iterable. + + :return: True if the selectable is iterable, False otherwise. + """ + if self._var_ and self._var_ is not self: + return self._var_._is_iterable_ + return False + + +@dataclass(eq=False, repr=False) +class CanBehaveLikeAVariable(Selectable[T], ABC): + """ + This class adds the monitoring/tracking behaviour on variables that tracks attribute access, calling, + and comparison operations. + """ + _path_: List[ClassRelation] = field(init=False, default_factory=list) """ The path of the variable in the symbol graph as a sequence of relation instances. """ - _type_: Type = field(init=False, default=None) + _type_: Type[T] = field(init=False, default=None) """ The type of the variable. """ @@ -956,6 +976,11 @@ def _evaluate__( self._eval_parent_ = parent sources = sources or {} if self._id_ in sources: + if ( + isinstance(self._parent_, LogicalBinaryOperator) + or self is self._conditions_root_ + ): + self._is_false_ = not bool(sources[self._id_]) yield OperationResult(sources, not bool(sources[self._id_]), self) elif self._domain_: for v in self._domain_: @@ -1018,6 +1043,10 @@ def _all_variable_instances_(self) -> List[Variable]: variables.extend(v._all_variable_instances_) return variables + @property + def _is_iterable_(self): + return bool(self._domain_) + @property def _plot_color_(self) -> ColorLegend: if self._plot_color__: @@ -1042,10 +1071,9 @@ def __init__( ): original_data = data data = [data] - if not is_iterable(data): - data = HashedIterable([data]) if not type_: - first_value = next(iter(data), None) + original_data_lst = make_list(original_data) + first_value = original_data_lst[0] if len(original_data_lst) > 0 else None type_ = type(first_value) if first_value else None if name is None: if type_: @@ -1177,6 +1205,12 @@ def _relation_(self): ) return None + @property + def _is_iterable_(self): + if not self._wrapped_field_: + return False + return self._wrapped_field_.is_iterable + @cached_property def _wrapped_type_(self): try: @@ -1215,6 +1249,8 @@ def _type_(self) -> Optional[Type]: @cached_property def _wrapped_field_(self) -> Optional[WrappedField]: + if self._wrapped_owner_class_ is None: + return None return self._wrapped_owner_class_._wrapped_field_name_map_.get( self._attr_name_, None ) @@ -1230,7 +1266,7 @@ def _wrapped_owner_class_(self): return None def _apply_mapping_(self, value: HashedValue) -> Iterable[HashedValue]: - yield HashedValue(id_=value.id_, value=getattr(value.value, self._attr_name_)) + yield HashedValue(getattr(value.value, self._attr_name_)) @property def _name_(self): @@ -1299,6 +1335,13 @@ def _apply_mapping_(self, value: HashedValue) -> Iterable[HashedValue]: def _name_(self): return f"Flatten({self._child_._name_})" + @property + def _is_iterable_(self): + """ + :return: False as Flatten does not preserve the original iterable structure. + """ + return False + @dataclass(eq=False, repr=False) class BinaryOperator(SymbolicExpression, ABC): @@ -1407,9 +1450,18 @@ def _evaluate__( ) def apply_operation(self, operand_values: OperationResult) -> bool: - res = self.operation( - operand_values[self.left._id_].value, operand_values[self.right._id_].value + left_value, right_value = ( + operand_values.bindings[self.left._id_], + operand_values.bindings[self.right._id_], ) + if ( + self.operation in [operator.eq, operator.ne] + and is_iterable(left_value.value) + and is_iterable(right_value.value) + ): + left_value = HashedValue(make_set(left_value.value)) + right_value = HashedValue(make_set(right_value.value)) + res = self.operation(left_value.value, right_value.value) self._is_false_ = not res operand_values[self._id_] = HashedValue(res) return res @@ -1417,6 +1469,12 @@ def apply_operation(self, operand_values: OperationResult) -> bool: def get_first_second_operands( self, sources: Dict[int, HashedValue] ) -> Tuple[SymbolicExpression, SymbolicExpression]: + left_has_the = any(isinstance(desc, The) for desc in self.left._descendants_) + right_has_the = any(isinstance(desc, The) for desc in self.right._descendants_) + if left_has_the and not right_has_the: + return self.left, self.right + elif not left_has_the and right_has_the: + return self.right, self.left if sources and any( v.value._var_._id_ in sources for v in self.right._unique_variables_ ): @@ -1730,18 +1788,12 @@ def _evaluate__( ) -> Iterable[OperationResult]: sources = sources or {} self._eval_parent_ = parent - for var_val in self.variable._evaluate__(sources, parent=self): - yield from self.evaluate_condition(var_val.bindings) - - def evaluate_condition( - self, sources: Dict[int, HashedValue] - ) -> Iterable[OperationResult]: - # Evaluate the condition under this particular universal value - for condition_val in self.condition._evaluate__(sources, parent=self): - self._is_false_ = condition_val.is_false - if not self._is_false_: - yield OperationResult(condition_val.bindings, False, self) - break + seen_var_values = [] + for val in self.condition._evaluate__(sources, parent=self): + var_val = val[self.variable._id_] + if val.is_true and var_val.value not in seen_var_values: + seen_var_values.append(var_val.value) + yield OperationResult(val.bindings, False, self) def _invert_(self): return ForAll(self.variable, self.condition._invert_()) diff --git a/src/krrood/entity_query_language/utils.py b/src/krrood/entity_query_language/utils.py index a5f72c3..acd7e8b 100644 --- a/src/krrood/entity_query_language/utils.py +++ b/src/krrood/entity_query_language/utils.py @@ -18,7 +18,7 @@ except ImportError: Source = None -from typing_extensions import Set, Any, List +from typing_extensions import Set, Any, List, Type class IDGenerator: diff --git a/test/test_eql/test_aggregations.py b/test/test_eql/test_aggregations.py index a4ae0ff..2cc9bd6 100644 --- a/test/test_eql/test_aggregations.py +++ b/test/test_eql/test_aggregations.py @@ -1,14 +1,13 @@ from krrood.entity_query_language.entity import ( flatten, entity, - an, not_, in_, - the, for_all, let, exists, ) +from krrood.entity_query_language.quantify_entity import an, the from ..dataset.example_classes import VectorsWithProperty from ..dataset.semantic_world_like_classes import View, Drawer, Container, Cabinet diff --git a/test/test_eql/test_core/test_queries.py b/test/test_eql/test_core/test_queries.py index b842d5f..ff59553 100644 --- a/test/test_eql/test_core/test_queries.py +++ b/test/test_eql/test_core/test_queries.py @@ -8,18 +8,14 @@ not_, contains, in_, - an, entity, set_of, let, - the, or_, - a, exists, flatten, - match, - entity_matching, ) +from krrood.entity_query_language.quantify_entity import an, a, the from krrood.entity_query_language.failures import ( MultipleSolutionFound, UnsupportedNegation, @@ -31,7 +27,6 @@ symbolic_function, Predicate, ) -from krrood.entity_query_language.symbol_graph import SymbolGraph from krrood.entity_query_language.result_quantification_constraint import ( ResultQuantificationConstraint, Exactly, @@ -748,34 +743,3 @@ def get_quantified_query(quantification: ResultQuantificationConstraint): list(get_quantified_query(Exactly(2)).evaluate()) with pytest.raises(LessThanExpectedNumberOfSolutions): list(get_quantified_query(Exactly(4)).evaluate()) - - -def test_match(handles_and_containers_world): - world = handles_and_containers_world - - fixed_connection_query = the( - entity_matching(FixedConnection, world.connections)( - parent=match(Container)(name="Container1"), - child=match(Handle)(name="Handle1"), - ) - ) - - fixed_connection_query_manual = the( - entity( - fc := let(FixedConnection, domain=None), - HasType(fc.parent, Container), - HasType(fc.child, Handle), - fc.parent.name == "Container1", - fc.child.name == "Handle1", - ) - ) - - assert fixed_connection_query == fixed_connection_query_manual - - fixed_connection_query.visualize() - - fixed_connection = fixed_connection_query.evaluate() - assert isinstance(fixed_connection, FixedConnection) - assert fixed_connection.parent.name == "Container1" - assert isinstance(fixed_connection.child, Handle) - assert fixed_connection.child.name == "Handle1" diff --git a/test/test_eql/test_core/test_rules.py b/test/test_eql/test_core/test_rules.py index d1941da..142b6a7 100644 --- a/test/test_eql/test_core/test_rules.py +++ b/test/test_eql/test_core/test_rules.py @@ -1,5 +1,6 @@ from krrood.entity_query_language.conclusion import Add -from krrood.entity_query_language.entity import let, an, entity, and_, inference +from krrood.entity_query_language.entity import let, entity, and_, inference +from krrood.entity_query_language.quantify_entity import an from krrood.entity_query_language.predicate import HasType from krrood.entity_query_language.rule import refinement, alternative, next_rule from ...dataset.semantic_world_like_classes import ( diff --git a/test/test_eql/test_indexing.py b/test/test_eql/test_indexing.py index 1891f11..5714aa9 100644 --- a/test/test_eql/test_indexing.py +++ b/test/test_eql/test_indexing.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from typing_extensions import Dict, List -from krrood.entity_query_language.entity import an, entity, let, From +from krrood.entity_query_language.entity import entity, let, From +from krrood.entity_query_language.quantify_entity import an from krrood.entity_query_language.predicate import Symbol from krrood.entity_query_language.symbol_graph import SymbolGraph diff --git a/test/test_eql/test_match.py b/test/test_eql/test_match.py new file mode 100644 index 0000000..43d4da9 --- /dev/null +++ b/test/test_eql/test_match.py @@ -0,0 +1,119 @@ +import pytest + +from krrood.entity_query_language.entity import ( + entity, + let, +) +from krrood.entity_query_language.quantify_entity import an, the +from krrood.entity_query_language.match import ( + match, + match_any, + select, + entity_matching, + match_all, +) +from krrood.entity_query_language.predicate import HasType +from krrood.entity_query_language.symbolic import UnificationDict, SetOf +from ..dataset.semantic_world_like_classes import ( + FixedConnection, + Container, + Handle, + Cabinet, + Drawer, +) + + +def test_match(handles_and_containers_world): + world = handles_and_containers_world + + fixed_connection_query = the( + entity_matching(FixedConnection, world.connections)( + parent=match(Container)(name="Container1"), + child=match(Handle)(name="Handle1"), + ) + ) + + fixed_connection_query_manual = the( + entity( + fc := let(FixedConnection, domain=None), + HasType(fc.parent, Container), + HasType(fc.child, Handle), + fc.parent.name == "Container1", + fc.child.name == "Handle1", + ) + ) + + assert fixed_connection_query == fixed_connection_query_manual + + fixed_connection = fixed_connection_query.evaluate() + assert isinstance(fixed_connection, FixedConnection) + assert fixed_connection.parent.name == "Container1" + assert isinstance(fixed_connection.child, Handle) + assert fixed_connection.child.name == "Handle1" + + +def test_select(handles_and_containers_world): + world = handles_and_containers_world + container, handle = select(Container), select(Handle) + fixed_connection_query = the( + entity_matching(FixedConnection, world.connections)( + parent=container(name="Container1"), + child=handle(name="Handle1"), + ) + ) + + assert isinstance(fixed_connection_query._child_, SetOf) + + answers = fixed_connection_query.evaluate() + assert isinstance(answers, UnificationDict) + assert answers[container].name == "Container1" + assert answers[handle].name == "Handle1" + + +@pytest.fixture +def world_and_cabinets_and_specific_drawer(handles_and_containers_world): + world = handles_and_containers_world + my_drawer = Drawer(handle=Handle("Handle2"), container=Container("Container1")) + drawers = list(filter(lambda v: isinstance(v, Drawer), world.views)) + my_cabinet_1 = Cabinet( + container=Container("container2"), drawers=[my_drawer] + drawers + ) + my_cabinet_2 = Cabinet(container=Container("container2"), drawers=[my_drawer]) + my_cabinet_3 = Cabinet(container=Container("container2"), drawers=drawers) + return world, [my_cabinet_1, my_cabinet_2, my_cabinet_3], my_drawer + + +def test_match_any(world_and_cabinets_and_specific_drawer): + world, cabinets, my_drawer = world_and_cabinets_and_specific_drawer + cabinet = an(entity_matching(Cabinet, cabinets)(drawers=match_any([my_drawer]))) + found_cabinets = list(cabinet.evaluate()) + assert len(found_cabinets) == 2 + assert cabinets[0] in found_cabinets + assert cabinets[1] in found_cabinets + + +def test_match_all(world_and_cabinets_and_specific_drawer): + world, cabinets, my_drawer = world_and_cabinets_and_specific_drawer + cabinet = the(entity_matching(Cabinet, cabinets)(drawers=match_all([my_drawer]))) + found_cabinet = cabinet.evaluate() + assert found_cabinet is cabinets[1] + + +def test_match_any_on_collection_returns_unique_parent_entities(): + # setup from the notebook example + c1 = Container("Container1") + other_c = Container("ContainerX") + h1 = Handle("Handle1") + + drawer1 = Drawer(handle=h1, container=c1) + drawer2 = Drawer(handle=Handle("OtherHandle"), container=other_c) + cabinet1 = Cabinet(container=c1, drawers=[drawer1, drawer2]) + cabinet2 = Cabinet(container=other_c, drawers=[drawer2]) + views = [drawer1, drawer2, cabinet1, cabinet2] + + q = an(entity_matching(Cabinet, views)(drawers=match_any([drawer1, drawer2]))) + + results = list(q.evaluate()) + # Expect exactly the two cabinets, no duplicates + assert len(results) == 2 + assert {id(x) for x in results} == {id(cabinet1), id(cabinet2)} diff --git a/test/test_eql/test_rendering.py b/test/test_eql/test_rendering.py index 878aea3..9795651 100644 --- a/test/test_eql/test_rendering.py +++ b/test/test_eql/test_rendering.py @@ -22,10 +22,10 @@ from krrood.entity_query_language.entity import ( entity, let, - an, inference, and_, ) +from krrood.entity_query_language.quantify_entity import an from krrood.entity_query_language.conclusion import Add from krrood.entity_query_language.predicate import HasType diff --git a/test/test_eql/test_symbol_graph.py b/test/test_eql/test_symbol_graph.py index 70b9f9e..e3e008e 100644 --- a/test/test_eql/test_symbol_graph.py +++ b/test/test_eql/test_symbol_graph.py @@ -2,7 +2,8 @@ import pytest -from krrood.entity_query_language.entity import an, entity, let +from krrood.entity_query_language.entity import entity, let +from krrood.entity_query_language.quantify_entity import an from krrood.entity_query_language.symbol_graph import SymbolGraph from ..dataset.example_classes import Position diff --git a/test/test_ormatic/test_eql.py b/test/test_ormatic/test_eql.py index c742327..7e8856b 100644 --- a/test/test_ormatic/test_eql.py +++ b/test/test_ormatic/test_eql.py @@ -22,14 +22,13 @@ ) from krrood.entity_query_language.entity import ( let, - an, entity, - the, contains, and_, or_, in_, ) +from krrood.entity_query_language.quantify_entity import an, the from krrood.ormatic.dao import to_dao from krrood.ormatic.eql_interface import eql_to_sql