Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/azul/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def hot_entity_types(self) -> Iterable[str]:

@property
def facets(self) -> Sequence[str]:
return [self.special_fields.source_id.name]
return []

@property
@abstractmethod
Expand Down
65 changes: 44 additions & 21 deletions src/azul/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,25 @@ def to_json(self) -> JSON:
def update(self, filters: FiltersJSON) -> Self:
return attr.evolve(self, explicit={**self.explicit, **filters})

def reify_implicit_only(self, plugin: MetadataPlugin) -> FiltersJSON:
"""
Construct filters for *only* the implicit restriction on accessible
sources. This is conceptually equivalent to the set difference
`self.reify(limit_access=True) - self.reify(limit_access=False)`.

:param plugin: Metadata plugin for the current request's catalog
"""
source_id_field = plugin.special_fields.source_id.name
filters = copy_json(self.explicit)
explicit_sources = self._extract_filter(filters, source_id_field, default=None)
if explicit_sources is not None:
self._forbid_explicit_inaccessible_sources(explicit_sources)
return {
source_id_field: {
'is': sorted(self.source_ids),
}
}

def reify(self,
plugin: MetadataPlugin,
*,
Expand All @@ -361,33 +380,20 @@ def reify(self,
filters = copy_json(self.explicit)
special_fields = plugin.special_fields

def extract_filter(field: str, *, default: set | None) -> set | None:
filter = filters.pop(field, {})
# Other operators are not supported on string or boolean fields
assert filter.keys() <= {'is'}, filter
try:
values = filter['is']
except KeyError:
return default
else:
return set(values)

explicit_sources = extract_filter(special_fields.source_id.name,
default=None)
accessible = extract_filter(special_fields.accessible.name,
default={False, True})
explicit_sources = self._extract_filter(filters,
special_fields.source_id.name,
default=None)
accessible = self._extract_filter(filters,
special_fields.accessible.name,
default={False, True})
source_relation = 'is'

if limit_access:
if explicit_sources is None:
sources = self.source_ids if True in accessible else []
else:
forbidden_sources = explicit_sources - self.source_ids
if forbidden_sources:
raise ForbiddenError('Cannot filter by inaccessible sources',
forbidden_sources)
else:
sources = explicit_sources if True in accessible else []
self._forbid_explicit_inaccessible_sources(explicit_sources)
sources = explicit_sources if True in accessible else []
else:
if accessible == set():
sources = []
Expand Down Expand Up @@ -417,6 +423,23 @@ def extract_filter(field: str, *, default: set | None) -> set | None:

return filters

def _extract_filter(self, filters, field: str, *, default: set | None) -> set | None:
filter = filters.pop(field, {})
# Other operators are not supported on string or boolean fields
assert filter.keys() <= {'is'}, filter
try:
values = filter['is']
except KeyError:
return default
else:
return set(values)

def _forbid_explicit_inaccessible_sources(self, explicit_sources: set[str]):
forbidden_sources = explicit_sources - self.source_ids
if forbidden_sources:
raise ForbiddenError('Cannot filter by inaccessible sources',
forbidden_sources)


class BadArgumentException(Exception):

Expand Down
24 changes: 19 additions & 5 deletions src/azul/service/elasticsearch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,11 @@ class FilterStage(_ElasticsearchStage[Response, Response]):
post_filter: bool

def prepare_request(self, request: Search) -> Search:
query = self.prepare_query()
query = self.prepare_query(self.prepared_filters)
if self.post_filter:
if self.service.always_limit_access or self._limit_access:
access_query = self.prepare_query(self.prepared_access_filter)
request = request.query(access_query)
request = request.post_filter(query)
else:
request = request.query(query)
Expand All @@ -191,10 +194,17 @@ def process_response(self, response: Response) -> Response:

@cached_property
def prepared_filters(self) -> TranslatedFilters:
limit_access = self.service.always_limit_access or self._limit_access
# The implicit source filter is always applied via a query, and would
# therefore be redundant in the post_filter
limit_access = (self.service.always_limit_access or self._limit_access) and not self.post_filter
filters_json = self.filters.reify(self.plugin, limit_access=limit_access)
return self._translate_filters(filters_json)

@cached_property
def prepared_access_filter(self) -> TranslatedFilters:
filters_json = self.filters.reify_implicit_only(self.plugin)
return self._translate_filters(filters_json)

@property
@abstractmethod
def _limit_access(self) -> bool:
Expand Down Expand Up @@ -223,12 +233,15 @@ def _translate_filters(self, filters: FiltersJSON) -> TranslatedFilters:
translated_filters[field] = {relation: list(values)}
return translated_filters

def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query:
def prepare_query(self,
prepared_filters: TranslatedFilters,
skip_field_paths: tuple[FieldPath] = ()
) -> Query:
"""
Converts the given filters into an Elasticsearch DSL Query object.
"""
filter_list = []
for field_path, relation_and_values in self.prepared_filters.items():
for field_path, relation_and_values in prepared_filters.items():
if field_path not in skip_field_paths:
relation, values = one(relation_and_values.items())
# Note that `is_not` is only used internally (for filtering by
Expand Down Expand Up @@ -320,7 +333,8 @@ def _prepare_aggregation(self, *, facet: str, facet_path: FieldPath) -> Agg:
"""
# Create a filter agg using a query that represents all filters
# except for the current facet.
query = self.filter_stage.prepare_query(skip_field_paths=(facet_path,))
query = self.filter_stage.prepare_query(self.filter_stage.prepared_filters,
skip_field_paths=(facet_path,))
agg = A('filter', query)

field_type = self.service.field_type(self.catalog, facet_path)
Expand Down
35 changes: 28 additions & 7 deletions test/service/test_request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ def test_create_request(self):
Tests creation of a simple request
"""
expected_output = {
'query': {
'bool': {
'must': [
self.sources_filter
]
}
},
'post_filter': {
'bool': {
'must': [
Expand All @@ -121,7 +128,6 @@ def test_create_request(self):
}
}
},
self.sources_filter
]
}
}
Expand Down Expand Up @@ -154,6 +160,13 @@ def test_create_request_complex(self):
Tests creation of a complex request.
"""
expected_output = {
'query': {
'bool': {
'must': [
self.sources_filter
]
}
},
'post_filter': {
'bool': {
'must': [
Expand All @@ -168,7 +181,6 @@ def test_create_request_complex(self):
}
}
},
self.sources_filter
]
}
}
Expand All @@ -186,6 +198,13 @@ def test_create_request_missing_values(self):
Tests creation of a request for facets that do not have a value
"""
expected_output = {
'query': {
'bool': {
'must': [
self.sources_filter
]
}
},
'post_filter': {
'bool': {
'must': [
Expand Down Expand Up @@ -221,7 +240,6 @@ def test_create_request_missing_values(self):
}
}
},
self.sources_filter
]
}
}
Expand All @@ -236,6 +254,13 @@ def test_create_request_terms_and_missing_values(self):
not have a value
"""
expected_output = {
'query': {
'bool': {
'must': [
self.sources_filter
]
}
},
'post_filter': {
'bool': {
'must': [
Expand Down Expand Up @@ -300,7 +325,6 @@ def test_create_request_terms_and_missing_values(self):
}
}
},
self.sources_filter
]
}
}
Expand Down Expand Up @@ -346,9 +370,6 @@ def test_create_aggregate(self):
expected_output = {
'filter': {
'bool': {
'must': [
self.sources_filter
]
}
},
'aggs': {
Expand Down
Loading