1616# under the License.
1717
1818
19- import asyncio
19+ import time
20+ import anyio
2021import re
2122import warnings
2223from typing import Any , Dict , Optional
4041 UnsupportedProductError ,
4142)
4243
43- pytestmark = pytest .mark .asyncio
44-
4544
4645class DummyNode (BaseAsyncNode ):
4746 def __init__ (self , config : NodeConfig ):
@@ -175,6 +174,7 @@ def mark_live(self, connection):
175174}"""
176175
177176
177+ @pytest .mark .anyio
178178class TestTransport :
179179 async def test_request_timeout_extracted_from_params_and_passed (self ):
180180 client = AsyncElasticsearch (
@@ -378,6 +378,9 @@ async def test_override_mark_dead_mark_live(self):
378378 assert len (client .transport .node_pool ._alive_nodes ) == 2
379379 assert len (client .transport .node_pool ._dead_consecutive_failures ) == 0
380380
381+
382+ @pytest .mark .asyncio
383+ class TestSniffing :
381384 @pytest .mark .parametrize (
382385 ["nodes_info_response" , "node_host" ],
383386 [(CLUSTER_NODES , "1.1.1.1" ), (CLUSTER_NODES_7x_PUBLISH_HOST , "somehost.tld" )],
@@ -528,23 +531,22 @@ async def test_sniff_on_node_failure_triggers(self, extra_key, extra_value):
528531 assert len (client .transport .node_pool ) == 3
529532
530533 async def test_sniff_after_n_seconds (self ):
531- event_loop = asyncio .get_running_loop ()
532534 client = AsyncElasticsearch ( # noqa: F821
533535 [NodeConfig ("http" , "localhost" , 9200 , _extras = {"data" : CLUSTER_NODES })],
534536 node_class = DummyNode ,
535537 min_delay_between_sniffing = 5 ,
536538 )
537- client .transport ._last_sniffed_at = event_loop . time ()
539+ client .transport ._last_sniffed_at = time . monotonic ()
538540
539541 await client .info ()
540542
541543 for _ in range (4 ):
542544 await client .info ()
543- await asyncio .sleep (0 )
545+ await anyio .sleep (0 )
544546
545547 assert 1 == len (client .transport .node_pool )
546548
547- client .transport ._last_sniffed_at = event_loop . time () - 5.1
549+ client .transport ._last_sniffed_at = time . monotonic () - 5.1
548550
549551 await client .info ()
550552 await client .transport ._sniffing_task # Need to wait for the sniffing task to complete
@@ -554,9 +556,9 @@ async def test_sniff_after_n_seconds(self):
554556 node .base_url for node in client .transport .node_pool .all ()
555557 )
556558 assert (
557- event_loop . time () - 1
559+ time . monotonic () - 1
558560 < client .transport ._last_sniffed_at
559- < event_loop . time () + 0.01
561+ < time . monotonic () + 0.01
560562 )
561563
562564 @pytest .mark .parametrize (
@@ -580,8 +582,10 @@ async def test_sniffing_disabled_on_elastic_cloud(self, kwargs):
580582 == "Sniffing should not be enabled when connecting to Elastic Cloud"
581583 )
582584
583- async def test_sniff_on_start_close_unlocks_async_calls (self ):
584- event_loop = asyncio .get_running_loop ()
585+ async def test_sniff_on_start_close_unlocks_async_calls (self , anyio_backend ):
586+ if anyio_backend == "trio" :
587+ pytest .skip ("trio does not support sniffing" )
588+
585589 client = AsyncElasticsearch ( # noqa: F821
586590 [
587591 NodeConfig (
@@ -596,20 +600,17 @@ async def test_sniff_on_start_close_unlocks_async_calls(self):
596600 )
597601
598602 # Start making _async_calls() before we cancel
599- tasks = []
600- start_time = event_loop .time ()
601- for _ in range (3 ):
602- tasks .append (event_loop .create_task (client .info ()))
603- await asyncio .sleep (0 )
604-
605- # Close the transport while the sniffing task is active! :(
606- await client .transport .close ()
607-
608- # Now we start waiting on all those _async_calls()
609- await asyncio .gather (* tasks )
610- end_time = event_loop .time ()
611- duration = end_time - start_time
603+ async with anyio .create_task_group () as tg :
604+ start_time = time .monotonic ()
605+ for _ in range (3 ):
606+ tg .start_soon (client .info )
607+ await anyio .sleep (0 )
608+
609+ # Close the transport while the sniffing task is active! :(
610+ await client .transport .close ()
612611
612+ end_time = time .monotonic ()
613+ duration = end_time - start_time
613614 # A lot quicker than 10 seconds defined in 'delay'
614615 assert duration < 1
615616
@@ -661,6 +662,7 @@ def sniffed_node_callback(
661662 assert ports == {9200 , 124 }
662663
663664
665+ @pytest .mark .anyio
664666@pytest .mark .parametrize ("headers" , [{}, {"X-elastic-product" : "BAD HEADER" }])
665667async def test_unsupported_product_error (headers ):
666668 client = AsyncElasticsearch (
@@ -690,6 +692,7 @@ async def test_unsupported_product_error(headers):
690692 )
691693
692694
695+ @pytest .mark .anyio
693696@pytest .mark .parametrize ("status" , [401 , 403 , 413 , 500 ])
694697async def test_unsupported_product_error_not_raised_on_non_2xx (status ):
695698 client = AsyncElasticsearch (
@@ -709,6 +712,7 @@ async def test_unsupported_product_error_not_raised_on_non_2xx(status):
709712 assert e .meta .status == status
710713
711714
715+ @pytest .mark .anyio
712716@pytest .mark .parametrize ("status" , [404 , 500 ])
713717async def test_api_error_raised_before_product_error (status ):
714718 client = AsyncElasticsearch (
@@ -737,6 +741,7 @@ async def test_api_error_raised_before_product_error(status):
737741 assert calls [0 ][0 ] == ("GET" , "/" )
738742
739743
744+ @pytest .mark .anyio
740745@pytest .mark .parametrize (
741746 "headers" ,
742747 [
0 commit comments