Skip to content

Commit cb2c137

Browse files
committed
fix(vertex_ai): improve passthrough endpoint url parsing and construction (#17402)
1 parent dc7c2b9 commit cb2c137

File tree

3 files changed

+85
-7
lines changed

3 files changed

+85
-7
lines changed

litellm/llms/vertex_ai/common_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,16 @@ def get_vertex_location_from_url(url: str) -> Optional[str]:
647647
return match.group(1) if match else None
648648

649649

650+
def get_vertex_model_id_from_url(url: str) -> Optional[str]:
651+
"""
652+
Get the vertex model id from the url
653+
654+
`https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
655+
"""
656+
match = re.search(r"/models/([^/:]+)", url)
657+
return match.group(1) if match else None
658+
659+
650660
def replace_project_and_location_in_route(
651661
requested_route: str, vertex_project: str, vertex_location: str
652662
) -> str:
@@ -696,6 +706,15 @@ def construct_target_url(
696706
if "cachedContent" in requested_route:
697707
vertex_version = "v1beta1"
698708

709+
# Check if the requested route starts with a version
710+
# e.g. /v1beta1/publishers/google/models/gemini-3-pro-preview:streamGenerateContent
711+
if requested_route.startswith("/v1/"):
712+
vertex_version = "v1"
713+
requested_route = requested_route.replace("/v1/", "/", 1)
714+
elif requested_route.startswith("/v1beta1/"):
715+
vertex_version = "v1beta1"
716+
requested_route = requested_route.replace("/v1beta1/", "/", 1)
717+
699718
base_requested_route = "{}/projects/{}/locations/{}".format(
700719
vertex_version, vertex_project, vertex_location
701720
)

litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,6 +1547,7 @@ async def _base_vertex_proxy_route(
15471547
from litellm.llms.vertex_ai.common_utils import (
15481548
construct_target_url,
15491549
get_vertex_location_from_url,
1550+
get_vertex_model_id_from_url,
15501551
get_vertex_project_id_from_url,
15511552
)
15521553

@@ -1576,6 +1577,21 @@ async def _base_vertex_proxy_route(
15761577
vertex_location=vertex_location,
15771578
)
15781579

1580+
if vertex_project is None or vertex_location is None:
1581+
# Check if model is in router config
1582+
model_id = get_vertex_model_id_from_url(endpoint)
1583+
if model_id:
1584+
from litellm.proxy.proxy_server import llm_router
1585+
1586+
if llm_router:
1587+
deployments = llm_router.get_model_list(model_name=model_id)
1588+
if deployments:
1589+
deployment = deployments[0]
1590+
litellm_params = deployment.get("litellm_params", {})
1591+
if litellm_params.get("use_in_pass_through"):
1592+
vertex_project = litellm_params.get("vertex_project")
1593+
vertex_location = litellm_params.get("vertex_location")
1594+
15791595
vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
15801596
project_id=vertex_project,
15811597
location=vertex_location,

tests/test_litellm/llms/vertex_ai/test_vertex_ai_common_utils.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import sys
3-
from typing import Any, Dict
4-
from unittest.mock import MagicMock, call, patch
3+
from unittest.mock import patch
54

65
import pytest
76

@@ -11,7 +10,6 @@
1110
0, os.path.abspath("../../..")
1211
) # Adds the parent directory to the system path
1312

14-
import litellm
1513
from litellm.llms.vertex_ai.common_utils import (
1614
_get_vertex_url,
1715
convert_anyof_null_to_nullable,
@@ -798,9 +796,54 @@ def test_fix_enum_empty_strings():
798796
assert "mobile" in enum_values
799797
assert "tablet" in enum_values
800798

801-
# 3. Other properties preserved
802-
assert input_schema["properties"]["user_agent_type"]["type"] == "string"
803-
assert input_schema["properties"]["user_agent_type"]["description"] == "Device type for user agent"
799+
800+
def test_get_vertex_model_id_from_url():
801+
"""Test get_vertex_model_id_from_url with various URLs"""
802+
from litellm.llms.vertex_ai.common_utils import get_vertex_model_id_from_url
803+
804+
# Test with valid URL
805+
url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
806+
model_id = get_vertex_model_id_from_url(url)
807+
assert model_id == "gemini-pro"
808+
809+
# Test with invalid URL
810+
url = "https://invalid-url.com"
811+
model_id = get_vertex_model_id_from_url(url)
812+
assert model_id is None
813+
814+
815+
def test_construct_target_url_with_version_prefix():
816+
"""Test construct_target_url with version prefixes"""
817+
from litellm.llms.vertex_ai.common_utils import construct_target_url
818+
819+
# Test with /v1/ prefix
820+
url = "/v1/publishers/google/models/gemini-pro:streamGenerateContent"
821+
vertex_project = "test-project"
822+
vertex_location = "us-central1"
823+
base_url = "https://us-central1-aiplatform.googleapis.com"
824+
825+
target_url = construct_target_url(
826+
base_url=base_url,
827+
requested_route=url,
828+
vertex_project=vertex_project,
829+
vertex_location=vertex_location,
830+
)
831+
832+
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
833+
assert str(target_url) == expected_url
834+
835+
# Test with /v1beta1/ prefix
836+
url = "/v1beta1/publishers/google/models/gemini-pro:streamGenerateContent"
837+
838+
target_url = construct_target_url(
839+
base_url=base_url,
840+
requested_route=url,
841+
vertex_project=vertex_project,
842+
vertex_location=vertex_location,
843+
)
844+
845+
expected_url = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
846+
assert str(target_url) == expected_url
804847

805848

806849
def test_fix_enum_types():
@@ -862,7 +905,7 @@ def test_fix_enum_types():
862905
"truncateMode": {
863906
"enum": ["auto", "none", "start", "end"], # Kept - string type
864907
"type": "string",
865-
"description": "How to truncate content"
908+
"description": "How to truncate content",
866909
},
867910
"maxLength": { # enum removed
868911
"type": "integer",

0 commit comments

Comments
 (0)