4242"""
4343
4444import unittest
45- from parameterized import parameterized
4645from jetstream .core import orchestrator
4746from jetstream .core .proto import jetstream_pb2
4847from jetstream .core .utils .return_sample import ReturnSample
5150
5251class OrchestratorTest (unittest .IsolatedAsyncioTestCase ):
5352
54- def _setup_driver (self , interleaved_mode : bool = True ):
53+ def _setup_driver_interleaved_mode (self ):
5554 prefill_engine = mock_engine .TestEngine (
5655 batch_size = 32 , cache_length = 256 , weight = 2.0
5756 )
@@ -65,14 +64,13 @@ def _setup_driver(self, interleaved_mode: bool = True):
6564 generate_engines = [generate_engine ],
6665 prefill_params = [prefill_engine .load_params ()],
6766 generate_params = [generate_engine .load_params ()],
68- interleaved_mode = interleaved_mode ,
67+ interleaved_mode = True ,
6968 )
7069 return driver
7170
72- @parameterized .expand ([True , False ])
73- async def test_orchestrator (self , interleaved_mode : bool ):
71+ async def test_orchestrator_interleaved_mode (self ):
7472 """Test the multithreaded orchestration."""
75- driver = self ._setup_driver ( interleaved_mode )
73+ driver = self ._setup_driver_interleaved_mode ( )
7674 client = orchestrator .LLMOrchestrator (driver = driver )
7775
7876 # The string representation of np.array([[65, 66]]), [2] will be prepend
@@ -99,10 +97,9 @@ async def test_orchestrator(self, interleaved_mode: bool):
9997 driver .stop ()
10098 print ("Orchestrator driver stopped." )
10199
102- @parameterized .expand ([True , False ])
103- async def test_orchestrator_client_tokenization (self , interleaved_mode : bool ):
100+ async def test_orchestrator_interleaved_mode_client_tokenization (self ):
104101 """Test the multithreaded orchestration."""
105- driver = self ._setup_driver ( interleaved_mode )
102+ driver = self ._setup_driver_interleaved_mode ( )
106103 client = orchestrator .LLMOrchestrator (driver = driver )
107104
108105 # The token ids of string "AB", [2] will be prepend
@@ -131,9 +128,8 @@ async def test_orchestrator_client_tokenization(self, interleaved_mode: bool):
131128 driver .stop ()
132129 print ("Orchestrator driver stopped." )
133130
134- @parameterized .expand ([True , False ])
135- def test_should_buffer_response (self , interleaved_mode : bool ):
136- driver = self ._setup_driver (interleaved_mode )
131+ def test_should_buffer_response (self ):
132+ driver = self ._setup_driver_interleaved_mode ()
137133 client = orchestrator .LLMOrchestrator (driver = driver )
138134 self .assertTrue (
139135 client .should_buffer_response (
0 commit comments