Skip to content

Commit 8d3b5e9

Browse files
authored
Recover PR 181 and add pull ready tag (#183)
1 parent 3682212 commit 8d3b5e9

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

jetstream/tests/core/test_orchestrator.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"""
4343

4444
import unittest
45+
from parameterized import parameterized
4546
from jetstream.core import orchestrator
4647
from jetstream.core.proto import jetstream_pb2
4748
from jetstream.core.utils.return_sample import ReturnSample
@@ -50,7 +51,7 @@
5051

5152
class OrchestratorTest(unittest.IsolatedAsyncioTestCase):
5253

53-
def _setup_driver_interleaved_mode(self):
54+
def _setup_driver(self, interleaved_mode: bool = True):
5455
prefill_engine = mock_engine.TestEngine(
5556
batch_size=32, cache_length=256, weight=2.0
5657
)
@@ -64,13 +65,14 @@ def _setup_driver_interleaved_mode(self):
6465
generate_engines=[generate_engine],
6566
prefill_params=[prefill_engine.load_params()],
6667
generate_params=[generate_engine.load_params()],
67-
interleaved_mode=True,
68+
interleaved_mode=interleaved_mode,
6869
)
6970
return driver
7071

71-
async def test_orchestrator_interleaved_mode(self):
72+
@parameterized.expand([True, False])
73+
async def test_orchestrator(self, interleaved_mode: bool):
7274
"""Test the multithreaded orchestration."""
73-
driver = self._setup_driver_interleaved_mode()
75+
driver = self._setup_driver(interleaved_mode)
7476
client = orchestrator.LLMOrchestrator(driver=driver)
7577

7678
# The string representation of np.array([[65, 66]]), [2] will be prepend
@@ -97,9 +99,10 @@ async def test_orchestrator_interleaved_mode(self):
9799
driver.stop()
98100
print("Orchestrator driver stopped.")
99101

100-
async def test_orchestrator_interleaved_mode_client_tokenization(self):
102+
@parameterized.expand([True, False])
103+
async def test_orchestrator_client_tokenization(self, interleaved_mode: bool):
101104
"""Test the multithreaded orchestration."""
102-
driver = self._setup_driver_interleaved_mode()
105+
driver = self._setup_driver(interleaved_mode)
103106
client = orchestrator.LLMOrchestrator(driver=driver)
104107

105108
# The token ids of string "AB", [2] will be prepend
@@ -128,8 +131,9 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self):
128131
driver.stop()
129132
print("Orchestrator driver stopped.")
130133

131-
def test_should_buffer_response(self):
132-
driver = self._setup_driver_interleaved_mode()
134+
@parameterized.expand([True, False])
135+
def test_should_buffer_response(self, interleaved_mode: bool):
136+
driver = self._setup_driver(interleaved_mode)
133137
client = orchestrator.LLMOrchestrator(driver=driver)
134138
self.assertTrue(
135139
client.should_buffer_response(

0 commit comments

Comments
 (0)