141 changes: 78 additions & 63 deletions agent_runtimes/langchain_agent/agent_langchain.py
@@ -4,6 +4,7 @@
 import json
 import logging
 from collections.abc import AsyncGenerator
+from typing import Any, Callable, Dict

 # Third-Party
 from langchain.agents import AgentExecutor, create_openai_functions_agent
@@ -54,7 +55,7 @@ def create_llm(config: AgentConfig) -> BaseChatModel:
     provider = config.llm_provider.lower()

     # Common LLM arguments
-    common_args = {
+    common_args: Dict[str, Any] = {
         "temperature": config.temperature,
         "streaming": config.streaming_enabled,
     }
@@ -64,68 +65,89 @@ def create_llm(config: AgentConfig) -> BaseChatModel:
     if config.top_p:
         common_args["top_p"] = config.top_p

-    if provider == "openai":
-        if not config.openai_api_key:
-            raise ValueError("OPENAI_API_KEY is required for OpenAI provider")
+    # Provider factory functions
+    providers: Dict[str, Callable[[AgentConfig, Dict[str, Any]], BaseChatModel]] = {
+        "openai": _create_openai_llm,
+        "azure": _create_azure_llm,
+        "bedrock": _create_bedrock_llm,
+        "ollama": _create_ollama_llm,
+        "anthropic": _create_anthropic_llm,
+    }

-        openai_args = {"model": config.default_model, "api_key": config.openai_api_key, **common_args}
+    if provider not in providers:
+        raise ValueError(f"Unsupported LLM provider: {provider}. " f"Supported providers: {', '.join(providers.keys())}")

-        if config.openai_base_url:
-            openai_args["base_url"] = config.openai_base_url
-        if config.openai_organization:
-            openai_args["organization"] = config.openai_organization
+    return providers[provider](config, common_args)

-        return ChatOpenAI(**openai_args)

-    elif provider == "azure":
-        if not all([config.azure_openai_api_key, config.azure_openai_endpoint, config.azure_deployment_name]):
-            raise ValueError(
-                "Azure OpenAI requires AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_DEPLOYMENT_NAME"
-            )
+def _create_openai_llm(config: AgentConfig, common_args: Dict[str, Any]) -> BaseChatModel:
+    """Create OpenAI LLM instance."""

-        return AzureChatOpenAI(
-            api_key=config.azure_openai_api_key,
-            azure_endpoint=config.azure_openai_endpoint,
-            api_version=config.azure_openai_api_version,
-            azure_deployment=config.azure_deployment_name,
-            **common_args,
-        )

-    elif provider == "bedrock":
-        if BedrockChat is None:
-            raise ImportError("langchain-aws is required for Bedrock support. Install with: pip install langchain-aws")
-        if not all([config.aws_access_key_id, config.aws_secret_access_key, config.bedrock_model_id]):
-            raise ValueError("AWS Bedrock requires AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and BEDROCK_MODEL_ID")

-        return BedrockChat(
-            model_id=config.bedrock_model_id,
-            region_name=config.aws_region,
-            credentials_profile_name=None,  # Use environment variables
-            **common_args,
-        )

-    elif provider == "ollama":
-        if ChatOllama is None:
-            raise ImportError(
-                "langchain-community is required for OLLAMA support. Install with: pip install langchain-community"
-            )
-        if not config.ollama_model:
-            raise ValueError("OLLAMA_MODEL is required for OLLAMA provider")
+    if not config.openai_api_key:
+        raise ValueError("OPENAI_API_KEY is required for OpenAI provider")

-        return ChatOllama(model=config.ollama_model, base_url=config.ollama_base_url, **common_args)
+    openai_args = {"model": config.default_model, "api_key": config.openai_api_key, **common_args}

-    elif provider == "anthropic":
-        if ChatAnthropic is None:
-            raise ImportError(
-                "langchain-anthropic is required for Anthropic support. Install with: pip install langchain-anthropic"
-            )
-        if not config.anthropic_api_key:
-            raise ValueError("ANTHROPIC_API_KEY is required for Anthropic provider")
+    if config.openai_base_url:
+        openai_args["base_url"] = config.openai_base_url
+    if config.openai_organization:
+        openai_args["organization"] = config.openai_organization

+    return ChatOpenAI(**openai_args)


+def _create_azure_llm(config: AgentConfig, common_args: Dict[str, Any]) -> BaseChatModel:
+    """Create Azure OpenAI LLM instance."""

+    required_fields = [config.azure_openai_api_key, config.azure_openai_endpoint, config.azure_deployment_name]

+    if not all(required_fields):
+        raise ValueError("Azure OpenAI requires AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_DEPLOYMENT_NAME")

+    return AzureChatOpenAI(
+        api_key=config.azure_openai_api_key, azure_endpoint=config.azure_openai_endpoint, api_version=config.azure_openai_api_version, azure_deployment=config.azure_deployment_name, **common_args
+    )


+def _create_bedrock_llm(config: AgentConfig, common_args: Dict[str, Any]) -> BaseChatModel:
+    """Create AWS Bedrock LLM instance."""

-        return ChatAnthropic(model=config.default_model, api_key=config.anthropic_api_key, **common_args)
+    if BedrockChat is None:
+        raise ImportError("langchain-aws is required for Bedrock support. " "Install with: pip install langchain-aws")

-    else:
-        raise ValueError(f"Unsupported LLM provider: {provider}. Supported: openai, azure, bedrock, ollama, anthropic")
+    required_fields = [config.aws_access_key_id, config.aws_secret_access_key, config.bedrock_model_id]

+    if not all(required_fields):
+        raise ValueError("AWS Bedrock requires AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and BEDROCK_MODEL_ID")

+    return BedrockChat(
+        model_id=config.bedrock_model_id,
+        region_name=config.aws_region,
+        credentials_profile_name=None,  # Use environment variables
+        **common_args,
+    )


+def _create_ollama_llm(config: AgentConfig, common_args: Dict[str, Any]) -> BaseChatModel:
+    """Create OLLAMA LLM instance."""
+    if ChatOllama is None:
+        raise ImportError("langchain-community is required for OLLAMA support. " "Install with: pip install langchain-community")

+    if not config.ollama_model:
+        raise ValueError("OLLAMA_MODEL is required for OLLAMA provider")

+    return ChatOllama(model=config.ollama_model, base_url=config.ollama_base_url, **common_args)


+def _create_anthropic_llm(config: AgentConfig, common_args: Dict[str, Any]) -> BaseChatModel:
+    """Create Anthropic LLM instance."""
+    if ChatAnthropic is None:
+        raise ImportError("langchain-anthropic is required for Anthropic support. " "Install with: pip install langchain-anthropic")

+    if not config.anthropic_api_key:
+        raise ValueError("ANTHROPIC_API_KEY is required for Anthropic provider")

+    return ChatAnthropic(model=config.default_model, api_key=config.anthropic_api_key, **common_args)


 class MCPTool(BaseTool):
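A minimal usage sketch of the new dispatch table (not part of this diff; the AgentConfig constructor arguments below are assumed for illustration):

    # Hypothetical usage -- field names are inferred from the diff above.
    config = AgentConfig(llm_provider="openai", openai_api_key="sk-...")
    llm = create_llm(config)  # dispatches via the providers dict to _create_openai_llm

    # Any unregistered provider now fails fast with a single, uniform error
    # (assuming the config object is mutable):
    config.llm_provider = "gemini"
    create_llm(config)
    # ValueError: Unsupported LLM provider: gemini. Supported providers: openai, azure, bedrock, ollama, anthropic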
@@ -309,12 +331,7 @@ def is_initialized(self) -> bool:
     async def check_readiness(self) -> bool:
         """Check if agent is ready to handle requests"""
         try:
-            return (
-                self._initialized
-                and self.agent_executor is not None
-                and len(self.tools) >= 0  # Allow 0 tools for testing
-                and await self.test_gateway_connection()
-            )
+            return self._initialized and self.agent_executor is not None and len(self.tools) >= 0 and await self.test_gateway_connection()  # Allow 0 tools for testing
         except Exception:
             return False

@@ -366,9 +383,7 @@ async def run_async(
                 chat_history.append(SystemMessage(content=msg["content"]))

         # Run the agent
-        result = await self.agent_executor.ainvoke(
-            {"input": input_text, "chat_history": chat_history, "tool_names": [tool.name for tool in self.tools]}
-        )
+        result = await self.agent_executor.ainvoke({"input": input_text, "chat_history": chat_history, "tool_names": [tool.name for tool in self.tools]})

         return result["output"]
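A hedged driver sketch for the call above, assuming the enclosing class is named LangChainAgent and exposes the initialize(), check_readiness(), and run_async() methods used in this file (the constructor and the messages shape are assumptions, not code from this PR):

    # Hypothetical end-to-end invocation.
    import asyncio

    async def main() -> None:
        agent = LangChainAgent(config)   # assumed constructor taking an AgentConfig
        await agent.initialize()         # assumed setup that builds agent_executor and tools
        if await agent.check_readiness():
            output = await agent.run_async(
                input_text="List the available tools.",
                messages=[{"role": "system", "content": "You are a helpful agent."}],
            )
            print(output)  # run_async returns result["output"]

    asyncio.run(main())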
