Skip to content

Commit 3682212

Browse files
Yijia Jinjetstream authors
authored andcommitted
Internal test
PiperOrigin-RevId: 724067864
1 parent 8179a49 commit 3682212

File tree

2 files changed

+8
-15
lines changed

2 files changed

+8
-15
lines changed

jetstream/core/config_lib.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
16-
# testtesttest
17-
1815
"""Configs of engines for the orchestrator to load."""
1916

2017
import dataclasses

jetstream/tests/core/test_orchestrator.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
"""
4343

4444
import unittest
45-
from parameterized import parameterized
4645
from jetstream.core import orchestrator
4746
from jetstream.core.proto import jetstream_pb2
4847
from jetstream.core.utils.return_sample import ReturnSample
@@ -51,7 +50,7 @@
5150

5251
class OrchestratorTest(unittest.IsolatedAsyncioTestCase):
5352

54-
def _setup_driver(self, interleaved_mode: bool = True):
53+
def _setup_driver_interleaved_mode(self):
5554
prefill_engine = mock_engine.TestEngine(
5655
batch_size=32, cache_length=256, weight=2.0
5756
)
@@ -65,14 +64,13 @@ def _setup_driver(self, interleaved_mode: bool = True):
6564
generate_engines=[generate_engine],
6665
prefill_params=[prefill_engine.load_params()],
6766
generate_params=[generate_engine.load_params()],
68-
interleaved_mode=interleaved_mode,
67+
interleaved_mode=True,
6968
)
7069
return driver
7170

72-
@parameterized.expand([True, False])
73-
async def test_orchestrator(self, interleaved_mode: bool):
71+
async def test_orchestrator_interleaved_mode(self):
7472
"""Test the multithreaded orchestration."""
75-
driver = self._setup_driver(interleaved_mode)
73+
driver = self._setup_driver_interleaved_mode()
7674
client = orchestrator.LLMOrchestrator(driver=driver)
7775

7876
# The string representation of np.array([[65, 66]]), [2] will be prepend
@@ -99,10 +97,9 @@ async def test_orchestrator(self, interleaved_mode: bool):
9997
driver.stop()
10098
print("Orchestrator driver stopped.")
10199

102-
@parameterized.expand([True, False])
103-
async def test_orchestrator_client_tokenization(self, interleaved_mode: bool):
100+
async def test_orchestrator_interleaved_mode_client_tokenization(self):
104101
"""Test the multithreaded orchestration."""
105-
driver = self._setup_driver(interleaved_mode)
102+
driver = self._setup_driver_interleaved_mode()
106103
client = orchestrator.LLMOrchestrator(driver=driver)
107104

108105
# The token ids of string "AB", [2] will be prepend
@@ -131,9 +128,8 @@ async def test_orchestrator_client_tokenization(self, interleaved_mode: bool):
131128
driver.stop()
132129
print("Orchestrator driver stopped.")
133130

134-
@parameterized.expand([True, False])
135-
def test_should_buffer_response(self, interleaved_mode: bool):
136-
driver = self._setup_driver(interleaved_mode)
131+
def test_should_buffer_response(self):
132+
driver = self._setup_driver_interleaved_mode()
137133
client = orchestrator.LLMOrchestrator(driver=driver)
138134
self.assertTrue(
139135
client.should_buffer_response(

0 commit comments

Comments
 (0)