import pytest
from unittest.mock import MagicMock, AsyncMock, patch

from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import _base_vertex_proxy_route
from litellm.types.router import DeploymentTypedDict


@pytest.mark.asyncio
async def test_vertex_passthrough_load_balancing():
    """
    Test that _base_vertex_proxy_route uses llm_router.get_available_deployment
    instead of get_model_list, so that load balancing across deployments works.
    """
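    # All external dependencies (auth, credential lookup, header preparation, and the
    # downstream pass-through call) are mocked out, so only the deployment-selection
    # logic is exercised here.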
    # Setup mocks
    mock_request = MagicMock()
    mock_response = MagicMock()
    mock_handler = MagicMock()

    # Mock the router
    mock_router = MagicMock()
    mock_deployment = {
        "litellm_params": {
            "model": "vertex_ai/gemini-pro",
            "vertex_project": "test-project-lb",
            "vertex_location": "us-central1-lb",
            "use_in_pass_through": True,
        }
    }
    mock_router.get_available_deployment.return_value = mock_deployment

    # Mock get_vertex_model_id_from_url to return a model ID
    with patch("litellm.llms.vertex_ai.common_utils.get_vertex_model_id_from_url", return_value="gemini-pro"), \
         patch("litellm.proxy.proxy_server.llm_router", mock_router), \
         patch("litellm.llms.vertex_ai.common_utils.get_vertex_project_id_from_url", return_value=None), \
         patch("litellm.llms.vertex_ai.common_utils.get_vertex_location_from_url", return_value=None), \
         patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router") as mock_pt_router, \
         patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints._prepare_vertex_auth_headers", new_callable=AsyncMock) as mock_prep_headers, \
         patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route") as mock_create_route, \
         patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth", new_callable=AsyncMock) as mock_auth:

        # Setup additional mocks to avoid side effects
        mock_pt_router.get_vertex_credentials.return_value = MagicMock()
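        # Assumed return shape of _prepare_vertex_auth_headers:
        # (headers, updated_url, is_streaming, vertex_project, vertex_location);
        # adjust this stub if the helper returns a different tuple.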
        mock_prep_headers.return_value = ({}, "https://test.url", False, "test-project-lb", "us-central1-lb")

        mock_endpoint_func = AsyncMock()
        mock_create_route.return_value = mock_endpoint_func
        mock_auth.return_value = {}

        # Execute
        await _base_vertex_proxy_route(
            endpoint="https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent",
            request=mock_request,
            fastapi_response=mock_response,
            get_vertex_pass_through_handler=mock_handler,
        )

        # Verify
        # 1. Check that get_available_deployment was called with the correct model ID
        mock_router.get_available_deployment.assert_called_once_with(model="gemini-pro")

        # 2. Check that get_model_list was NOT called (ensures the old lookup logic is not used)
        mock_router.get_model_list.assert_not_called()

        # 3. Verify that the project and location from the selected deployment were
        #    forwarded to _prepare_vertex_auth_headers. The helper receives the request,
        #    credentials, and the vertex_project / vertex_location values; we inspect
        #    the latter two via the call's keyword arguments.
        call_kwargs = mock_prep_headers.call_args.kwargs
        assert call_kwargs["vertex_project"] == "test-project-lb"
        assert call_kwargs["vertex_location"] == "us-central1-lb"