
Commit 8a4f674

test(proxy): add test for vertex passthrough load balancing
Add a test that verifies _base_vertex_proxy_route uses get_available_deployment for proper load balancing instead of get_model_list. This ensures the correct deployment is selected from the router and vertex credentials are properly fetched.

Also refactor the implementation to:
- Use get_available_deployment instead of get_model_list
- Add error handling for deployment retrieval
- Improve code structure with a try-except block
1 parent 587126a commit 8a4f674

2 files changed: +80 -7 lines changed


litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

Lines changed: 11 additions & 7 deletions
@@ -1584,13 +1584,17 @@ async def _base_vertex_proxy_route(
     from litellm.proxy.proxy_server import llm_router
 
     if llm_router:
-        deployments = llm_router.get_model_list(model_name=model_id)
-        if deployments:
-            deployment = deployments[0]
-            litellm_params = deployment.get("litellm_params", {})
-            if litellm_params.get("use_in_pass_through"):
-                vertex_project = litellm_params.get("vertex_project")
-                vertex_location = litellm_params.get("vertex_location")
+        try:
+            deployment = llm_router.get_available_deployment(model=model_id)
+            if deployment:
+                litellm_params = deployment.get("litellm_params", {})
+                if litellm_params.get("use_in_pass_through"):
+                    vertex_project = litellm_params.get("vertex_project")
+                    vertex_location = litellm_params.get("vertex_location")
+        except Exception as e:
+            verbose_proxy_logger.debug(
+                f"Error getting available deployment for model {model_id}: {e}"
+            )
 
     vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
         project_id=vertex_project,
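
For context, here is a minimal sketch (not part of this commit, and with illustrative project and location names) of why the switch matters, assuming a litellm Router configured with two deployments that share one model_name: get_model_list returns the static deployment list, so indexing [0] always pins passthrough traffic to the first entry, whereas get_available_deployment asks the router's routing strategy to pick a deployment on each call.

from litellm import Router

# Hypothetical config (illustrative names): two Vertex AI deployments
# registered under the same model_name so the router can load-balance them.
router = Router(
    model_list=[
        {
            "model_name": "gemini-pro",
            "litellm_params": {
                "model": "vertex_ai/gemini-pro",
                "vertex_project": "project-a",
                "vertex_location": "us-central1",
                "use_in_pass_through": True,
            },
        },
        {
            "model_name": "gemini-pro",
            "litellm_params": {
                "model": "vertex_ai/gemini-pro",
                "vertex_project": "project-b",
                "vertex_location": "europe-west4",
                "use_in_pass_through": True,
            },
        },
    ],
    routing_strategy="simple-shuffle",
)

# Old behaviour: get_model_list returns the static list, and [0] always
# selects the same deployment regardless of the routing strategy.
pinned = router.get_model_list(model_name="gemini-pro")[0]
print(pinned.get("litellm_params", {}).get("vertex_project"))

# New behaviour: the router picks a deployment per call, so repeated calls
# should spread across project-a and project-b.
for _ in range(3):
    deployment = router.get_available_deployment(model="gemini-pro")
    print(deployment.get("litellm_params", {}).get("vertex_project"))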
Lines changed: 69 additions & 0 deletions
import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import _base_vertex_proxy_route
from litellm.types.router import DeploymentTypedDict


@pytest.mark.asyncio
async def test_vertex_passthrough_load_balancing():
    """
    Test that _base_vertex_proxy_route uses llm_router.get_available_deployment
    instead of get_model_list to ensure load balancing works.
    """
    # Setup mocks
    mock_request = MagicMock()
    mock_response = MagicMock()
    mock_handler = MagicMock()

    # Mock the router
    mock_router = MagicMock()
    mock_deployment = {
        "litellm_params": {
            "model": "vertex_ai/gemini-pro",
            "vertex_project": "test-project-lb",
            "vertex_location": "us-central1-lb",
            "use_in_pass_through": True,
        }
    }
    mock_router.get_available_deployment.return_value = mock_deployment

    # Mock get_vertex_model_id_from_url to return a model ID
    with patch("litellm.llms.vertex_ai.common_utils.get_vertex_model_id_from_url", return_value="gemini-pro"), \
         patch("litellm.proxy.proxy_server.llm_router", mock_router), \
         patch("litellm.llms.vertex_ai.common_utils.get_vertex_project_id_from_url", return_value=None), \
         patch("litellm.llms.vertex_ai.common_utils.get_vertex_location_from_url", return_value=None), \
         patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router") as mock_pt_router, \
         patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints._prepare_vertex_auth_headers", new_callable=AsyncMock) as mock_prep_headers, \
         patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route") as mock_create_route, \
         patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth", new_callable=AsyncMock) as mock_auth:

        # Setup additional mocks to avoid side effects
        mock_pt_router.get_vertex_credentials.return_value = MagicMock()
        mock_prep_headers.return_value = ({}, "https://test.url", False, "test-project-lb", "us-central1-lb")

        mock_endpoint_func = AsyncMock()
        mock_create_route.return_value = mock_endpoint_func
        mock_auth.return_value = {}

        # Execute
        await _base_vertex_proxy_route(
            endpoint="https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent",
            request=mock_request,
            fastapi_response=mock_response,
            get_vertex_pass_through_handler=mock_handler,
        )

        # Verify
        # 1. Check that get_available_deployment was called with the correct model ID
        mock_router.get_available_deployment.assert_called_once_with(model="gemini-pro")

        # 2. Check that get_model_list was NOT called (this ensures we aren't doing the old logic)
        mock_router.get_model_list.assert_not_called()

        # 3. Verify that the project and location from the deployment were used
        #    (passed to _prepare_vertex_auth_headers as keyword arguments)
        call_args = mock_prep_headers.call_args
        assert call_args[1]["vertex_project"] == "test-project-lb"
        assert call_args[1]["vertex_location"] == "us-central1-lb"
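
Because every external dependency of _base_vertex_proxy_route is mocked, the test should run in isolation, for example with pytest -k test_vertex_passthrough_load_balancing.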
