Skip to content

Commit 056daa8

Browse files
committed
Support trio with httpx
1 parent 5a8e2a7 commit 056daa8

30 files changed

+197
-181
lines changed

elasticsearch/_async/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import asyncio
18+
import anyio
1919
import logging
2020
from typing import (
2121
Any,
@@ -245,7 +245,7 @@ async def map_actions() -> AsyncIterable[_TYPE_BULK_ACTION_HEADER_AND_BODY]:
245245
]
246246
] = []
247247
if attempt:
248-
await asyncio.sleep(
248+
await anyio.sleep(
249249
min(max_backoff, initial_backoff * 2 ** (attempt - 1))
250250
)
251251

pyproject.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,18 @@ keywords = [
4141
]
4242
dynamic = ["version"]
4343
dependencies = [
44-
"elastic-transport>=9.1.0,<10",
44+
# TODO revert before merging/releasing
45+
"elastic-transport @ git+https://github.com/pquentin/elastic-transport-python.git@trio-support",
4546
"python-dateutil",
4647
"typing-extensions",
48+
"anyio",
49+
"sniffio",
4750
]
4851

52+
# TODO revert before merging/releasing
53+
[tool.hatch.metadata]
54+
allow-direct-references = true
55+
4956
[project.optional-dependencies]
5057
async = ["aiohttp>=3,<4"]
5158
requests = ["requests>=2.4.0, !=2.32.2, <3.0.0"]
@@ -56,6 +63,7 @@ vectorstore_mmr = ["numpy>=1", "simsimd>=3"]
5663
dev = [
5764
"requests>=2, <3",
5865
"aiohttp",
66+
"httpx",
5967
"pytest",
6068
"pytest-cov",
6169
"pytest-mock",
@@ -78,6 +86,7 @@ dev = [
7886
"mapbox-vector-tile",
7987
"jinja2",
8088
"tqdm",
89+
"trio",
8190
"mypy",
8291
"pyright",
8392
"types-python-dateutil",

test_elasticsearch/test_async/test_server/conftest.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,29 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import sniffio
1819
import pytest
19-
import pytest_asyncio
2020

2121
import elasticsearch
2222

2323
from ...utils import CA_CERTS, wipe_cluster
2424

25-
pytestmark = pytest.mark.asyncio
2625

2726

28-
@pytest_asyncio.fixture(scope="function")
27+
@pytest.fixture(scope="function")
2928
async def async_client_factory(elasticsearch_url):
30-
31-
if not hasattr(elasticsearch, "AsyncElasticsearch"):
32-
pytest.skip("test requires 'AsyncElasticsearch' and aiohttp to be installed")
33-
29+
print("async!", elasticsearch_url)
30+
kwargs = {}
31+
if sniffio.current_async_library() == "trio":
32+
kwargs["node_class"] = "httpxasync"
3433
# Unfortunately the asyncio client needs to be rebuilt every
3534
# test execution due to how pytest-asyncio manages
3635
# event loops (one per test!)
3736
client = None
3837
try:
39-
client = elasticsearch.AsyncElasticsearch(elasticsearch_url, ca_certs=CA_CERTS)
38+
client = elasticsearch.AsyncElasticsearch(
39+
elasticsearch_url, ca_certs=CA_CERTS, **kwargs
40+
)
4041
yield client
4142
finally:
4243
if client:

test_elasticsearch/test_async/test_server/test_clients.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import pytest
2020

21-
pytestmark = pytest.mark.asyncio
21+
pytestmark = pytest.mark.anyio
2222

2323

2424
@pytest.mark.parametrize("kwargs", [{"body": {"text": "привет"}}, {"text": "привет"}])

test_elasticsearch/test_async/test_server/test_helpers.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,19 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import asyncio
18+
import anyio
1919
import logging
2020
from datetime import datetime, timedelta, timezone
2121
from unittest.mock import MagicMock, call, patch
2222

2323
import pytest
24-
import pytest_asyncio
2524
from elastic_transport import ApiResponseMeta, ObjectApiResponse
2625

2726
from elasticsearch import helpers
2827
from elasticsearch.exceptions import ApiError
2928
from elasticsearch.helpers import ScanError
3029

31-
pytestmark = [pytest.mark.asyncio]
30+
pytestmark = pytest.mark.anyio
3231

3332

3433
class AsyncMock(MagicMock):
@@ -92,7 +91,7 @@ async def test_all_documents_get_inserted(self, async_client):
9291
async def test_documents_data_types(self, async_client):
9392
async def async_gen():
9493
for x in range(100):
95-
await asyncio.sleep(0)
94+
await anyio.sleep(0)
9695
yield {"answer": x, "_id": x}
9796

9897
def sync_gen():
@@ -491,7 +490,7 @@ def __await__(self):
491490
return self().__await__()
492491

493492

494-
@pytest_asyncio.fixture(scope="function")
493+
@pytest.fixture(scope="function")
495494
async def scan_teardown(async_client):
496495
yield
497496
await async_client.clear_scroll(scroll_id="_all")
@@ -915,7 +914,7 @@ async def test_scan_from_keyword_is_aliased(async_client, scan_kwargs):
915914
assert "from" not in search_mock.call_args[1]
916915

917916

918-
@pytest_asyncio.fixture(scope="function")
917+
@pytest.fixture(scope="function")
919918
async def reindex_setup(async_client):
920919
bulk = []
921920
for x in range(100):
@@ -993,7 +992,7 @@ async def test_all_documents_get_moved(self, async_client, reindex_setup):
993992
)["_source"]
994993

995994

996-
@pytest_asyncio.fixture(scope="function")
995+
@pytest.fixture(scope="function")
997996
async def parent_reindex_setup(async_client):
998997
body = {
999998
"settings": {"number_of_shards": 1, "number_of_replicas": 0},
@@ -1054,7 +1053,7 @@ async def test_children_are_reindexed_correctly(
10541053
} == q
10551054

10561055

1057-
@pytest_asyncio.fixture(scope="function")
1056+
@pytest.fixture(scope="function")
10581057
async def reindex_data_stream_setup(async_client):
10591058
dt = datetime.now(tz=timezone.utc)
10601059
bulk = []

test_elasticsearch/test_async/test_server/test_mapbox_vector_tile.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
# under the License.
1717

1818
import pytest
19-
import pytest_asyncio
2019

2120
from elasticsearch import RequestError
2221

23-
pytestmark = pytest.mark.asyncio
22+
pytestmark = pytest.mark.anyio
2423

2524

26-
@pytest_asyncio.fixture(scope="function")
25+
@pytest.fixture(scope="function")
2726
async def mvt_setup(async_client):
2827
await async_client.indices.create(
2928
index="museums",

test_elasticsearch/test_async/test_server/test_rest_api_spec.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import warnings
2626

2727
import pytest
28-
import pytest_asyncio
2928

3029
from elasticsearch import ElasticsearchWarning, RequestError
3130

@@ -39,7 +38,7 @@
3938
)
4039
from ...utils import parse_version
4140

42-
pytestmark = pytest.mark.asyncio
41+
pytestmark = pytest.mark.anyio
4342

4443
XPACK_FEATURES = None
4544
ES_VERSION = None
@@ -240,7 +239,7 @@ async def _feature_enabled(self, name):
240239
return name in XPACK_FEATURES
241240

242241

243-
@pytest_asyncio.fixture(scope="function")
242+
@pytest.fixture(scope="function")
244243
def async_runner(async_client_factory):
245244
return AsyncYamlRunner(async_client_factory)
246245

test_elasticsearch/test_async/test_transport.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
# under the License.
1717

1818

19-
import asyncio
19+
import time
20+
import anyio
2021
import re
2122
import warnings
2223
from typing import Any, Dict, Optional
@@ -40,8 +41,6 @@
4041
UnsupportedProductError,
4142
)
4243

43-
pytestmark = pytest.mark.asyncio
44-
4544

4645
class DummyNode(BaseAsyncNode):
4746
def __init__(self, config: NodeConfig):
@@ -175,6 +174,7 @@ def mark_live(self, connection):
175174
}"""
176175

177176

177+
@pytest.mark.anyio
178178
class TestTransport:
179179
async def test_request_timeout_extracted_from_params_and_passed(self):
180180
client = AsyncElasticsearch(
@@ -378,6 +378,9 @@ async def test_override_mark_dead_mark_live(self):
378378
assert len(client.transport.node_pool._alive_nodes) == 2
379379
assert len(client.transport.node_pool._dead_consecutive_failures) == 0
380380

381+
382+
@pytest.mark.asyncio
383+
class TestSniffing:
381384
@pytest.mark.parametrize(
382385
["nodes_info_response", "node_host"],
383386
[(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")],
@@ -528,23 +531,22 @@ async def test_sniff_on_node_failure_triggers(self, extra_key, extra_value):
528531
assert len(client.transport.node_pool) == 3
529532

530533
async def test_sniff_after_n_seconds(self):
531-
event_loop = asyncio.get_running_loop()
532534
client = AsyncElasticsearch( # noqa: F821
533535
[NodeConfig("http", "localhost", 9200, _extras={"data": CLUSTER_NODES})],
534536
node_class=DummyNode,
535537
min_delay_between_sniffing=5,
536538
)
537-
client.transport._last_sniffed_at = event_loop.time()
539+
client.transport._last_sniffed_at = time.monotonic()
538540

539541
await client.info()
540542

541543
for _ in range(4):
542544
await client.info()
543-
await asyncio.sleep(0)
545+
await anyio.sleep(0)
544546

545547
assert 1 == len(client.transport.node_pool)
546548

547-
client.transport._last_sniffed_at = event_loop.time() - 5.1
549+
client.transport._last_sniffed_at = time.monotonic() - 5.1
548550

549551
await client.info()
550552
await client.transport._sniffing_task # Need to wait for the sniffing task to complete
@@ -554,9 +556,9 @@ async def test_sniff_after_n_seconds(self):
554556
node.base_url for node in client.transport.node_pool.all()
555557
)
556558
assert (
557-
event_loop.time() - 1
559+
time.monotonic() - 1
558560
< client.transport._last_sniffed_at
559-
< event_loop.time() + 0.01
561+
< time.monotonic() + 0.01
560562
)
561563

562564
@pytest.mark.parametrize(
@@ -580,8 +582,10 @@ async def test_sniffing_disabled_on_elastic_cloud(self, kwargs):
580582
== "Sniffing should not be enabled when connecting to Elastic Cloud"
581583
)
582584

583-
async def test_sniff_on_start_close_unlocks_async_calls(self):
584-
event_loop = asyncio.get_running_loop()
585+
async def test_sniff_on_start_close_unlocks_async_calls(self, anyio_backend):
586+
if anyio_backend == "trio":
587+
pytest.skip("trio does not support sniffing")
588+
585589
client = AsyncElasticsearch( # noqa: F821
586590
[
587591
NodeConfig(
@@ -596,20 +600,17 @@ async def test_sniff_on_start_close_unlocks_async_calls(self):
596600
)
597601

598602
# Start making _async_calls() before we cancel
599-
tasks = []
600-
start_time = event_loop.time()
601-
for _ in range(3):
602-
tasks.append(event_loop.create_task(client.info()))
603-
await asyncio.sleep(0)
604-
605-
# Close the transport while the sniffing task is active! :(
606-
await client.transport.close()
607-
608-
# Now we start waiting on all those _async_calls()
609-
await asyncio.gather(*tasks)
610-
end_time = event_loop.time()
611-
duration = end_time - start_time
603+
async with anyio.create_task_group() as tg:
604+
start_time = time.monotonic()
605+
for _ in range(3):
606+
tg.start_soon(client.info)
607+
await anyio.sleep(0)
608+
609+
# Close the transport while the sniffing task is active! :(
610+
await client.transport.close()
612611

612+
end_time = time.monotonic()
613+
duration = end_time - start_time
613614
# A lot quicker than 10 seconds defined in 'delay'
614615
assert duration < 1
615616

@@ -661,6 +662,7 @@ def sniffed_node_callback(
661662
assert ports == {9200, 124}
662663

663664

665+
@pytest.mark.anyio
664666
@pytest.mark.parametrize("headers", [{}, {"X-elastic-product": "BAD HEADER"}])
665667
async def test_unsupported_product_error(headers):
666668
client = AsyncElasticsearch(
@@ -690,6 +692,7 @@ async def test_unsupported_product_error(headers):
690692
)
691693

692694

695+
@pytest.mark.anyio
693696
@pytest.mark.parametrize("status", [401, 403, 413, 500])
694697
async def test_unsupported_product_error_not_raised_on_non_2xx(status):
695698
client = AsyncElasticsearch(
@@ -709,6 +712,7 @@ async def test_unsupported_product_error_not_raised_on_non_2xx(status):
709712
assert e.meta.status == status
710713

711714

715+
@pytest.mark.anyio
712716
@pytest.mark.parametrize("status", [404, 500])
713717
async def test_api_error_raised_before_product_error(status):
714718
client = AsyncElasticsearch(
@@ -737,6 +741,7 @@ async def test_api_error_raised_before_product_error(status):
737741
assert calls[0][0] == ("GET", "/")
738742

739743

744+
@pytest.mark.anyio
740745
@pytest.mark.parametrize(
741746
"headers",
742747
[

test_elasticsearch/test_dsl/_async/test_document.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,21 +582,21 @@ def test_meta_fields_can_be_set_directly_in_init() -> None:
582582
assert md.meta.id is p
583583

584584

585-
@pytest.mark.asyncio
585+
@pytest.mark.anyio
586586
async def test_save_no_index(async_mock_client: Any) -> None:
587587
md = MyDoc()
588588
with raises(ValidationException):
589589
await md.save(using="mock")
590590

591591

592-
@pytest.mark.asyncio
592+
@pytest.mark.anyio
593593
async def test_delete_no_index(async_mock_client: Any) -> None:
594594
md = MyDoc()
595595
with raises(ValidationException):
596596
await md.delete(using="mock")
597597

598598

599-
@pytest.mark.asyncio
599+
@pytest.mark.anyio
600600
async def test_update_no_fields() -> None:
601601
md = MyDoc()
602602
with raises(IllegalOperation):

test_elasticsearch/test_dsl/_async/test_index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def test_index_template_can_have_order() -> None:
190190
assert {"index_patterns": ["i-*"], "order": 2} == it.to_dict()
191191

192192

193-
@pytest.mark.asyncio
193+
@pytest.mark.anyio
194194
async def test_index_template_save_result(async_mock_client: Any) -> None:
195195
it = AsyncIndexTemplate("test-template", "test-*")
196196

0 commit comments

Comments
 (0)