4242"""
4343
4444import unittest
45+ from parameterized import parameterized
4546from jetstream .core import orchestrator
4647from jetstream .core .proto import jetstream_pb2
4748from jetstream .core .utils .return_sample import ReturnSample
5051
5152class OrchestratorTest (unittest .IsolatedAsyncioTestCase ):
5253
53- def _setup_driver_interleaved_mode (self ):
54+ def _setup_driver (self , interleaved_mode : bool = True ):
5455 prefill_engine = mock_engine .TestEngine (
5556 batch_size = 32 , cache_length = 256 , weight = 2.0
5657 )
@@ -64,13 +65,14 @@ def _setup_driver_interleaved_mode(self):
6465 generate_engines = [generate_engine ],
6566 prefill_params = [prefill_engine .load_params ()],
6667 generate_params = [generate_engine .load_params ()],
67- interleaved_mode = True ,
68+ interleaved_mode = interleaved_mode ,
6869 )
6970 return driver
7071
71- async def test_orchestrator_interleaved_mode (self ):
72+ @parameterized .expand ([True , False ])
73+ async def test_orchestrator (self , interleaved_mode : bool ):
7274 """Test the multithreaded orchestration."""
73- driver = self ._setup_driver_interleaved_mode ( )
75+ driver = self ._setup_driver ( interleaved_mode )
7476 client = orchestrator .LLMOrchestrator (driver = driver )
7577
7678 # The string representation of np.array([[65, 66]]), [2] will be prepend
@@ -97,9 +99,10 @@ async def test_orchestrator_interleaved_mode(self):
9799 driver .stop ()
98100 print ("Orchestrator driver stopped." )
99101
100- async def test_orchestrator_interleaved_mode_client_tokenization (self ):
102+ @parameterized .expand ([True , False ])
103+ async def test_orchestrator_client_tokenization (self , interleaved_mode : bool ):
101104 """Test the multithreaded orchestration."""
102- driver = self ._setup_driver_interleaved_mode ( )
105+ driver = self ._setup_driver ( interleaved_mode )
103106 client = orchestrator .LLMOrchestrator (driver = driver )
104107
105108 # The token ids of string "AB", [2] will be prepend
@@ -128,8 +131,9 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self):
128131 driver .stop ()
129132 print ("Orchestrator driver stopped." )
130133
131- def test_should_buffer_response (self ):
132- driver = self ._setup_driver_interleaved_mode ()
134+ @parameterized .expand ([True , False ])
135+ def test_should_buffer_response (self , interleaved_mode : bool ):
136+ driver = self ._setup_driver (interleaved_mode )
133137 client = orchestrator .LLMOrchestrator (driver = driver )
134138 self .assertTrue (
135139 client .should_buffer_response (
0 commit comments