From 9a0da7c5cb9fc84efbb7c8b4c4c436a1deb155a3 Mon Sep 17 00:00:00 2001 From: Ljupche Milosheski Date: Wed, 17 Dec 2025 14:06:49 +0100 Subject: [PATCH 1/2] Fix typo --- .../src/pyagentspec/adapters/crewai/agentspecexporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyagentspec/src/pyagentspec/adapters/crewai/agentspecexporter.py b/pyagentspec/src/pyagentspec/adapters/crewai/agentspecexporter.py index b5c7a7a1..6c074945 100644 --- a/pyagentspec/src/pyagentspec/adapters/crewai/agentspecexporter.py +++ b/pyagentspec/src/pyagentspec/adapters/crewai/agentspecexporter.py @@ -51,6 +51,6 @@ def to_component(self, crewai_component: CrewAIComponent) -> Component: """ if not isinstance(crewai_component, (CrewAIAgent, CrewAIFlow)): raise TypeError( - f"Expected an Agent of Flow, but got '{type(crewai_component)}' instead" + f"Expected an Agent or Flow, but got '{type(crewai_component)}' instead" ) return CrewAIToAgentSpecConverter().convert(crewai_component) From 87b1cc82e52fbbc679eb2a135fc1dee7d883b5be Mon Sep 17 00:00:00 2001 From: Ljupche Milosheski Date: Wed, 17 Dec 2025 16:23:13 +0100 Subject: [PATCH 2/2] Add support for MCP tools to CrewAI adapter --- pyagentspec/constraints/constraints.txt | 1 + pyagentspec/requirements-dev.txt | 1 + .../adapters/crewai/_agentspecconverter.py | 79 ++++- .../adapters/crewai/_crewaiconverter.py | 77 +++++ .../src/pyagentspec/adapters/crewai/_types.py | 18 ++ pyagentspec/tests/adapters/conftest.py | 132 ++++++++- .../tests/adapters/crewai/test_mcp_tools.py | 162 ++++++++++ pyagentspec/tests/adapters/encryption.py | 280 ++++++++++++++++++ .../tests/adapters/start_mcp_server.py | 164 ++++++++++ pyagentspec/tests/adapters/utils.py | 113 ++++++- pyagentspec/tests/conftest.py | 79 +++-- 11 files changed, 1072 insertions(+), 34 deletions(-) create mode 100644 pyagentspec/tests/adapters/crewai/test_mcp_tools.py create mode 100644 pyagentspec/tests/adapters/encryption.py create mode 100644 pyagentspec/tests/adapters/start_mcp_server.py diff --git a/pyagentspec/constraints/constraints.txt b/pyagentspec/constraints/constraints.txt index 5f065e84..de2b8c5d 100644 --- a/pyagentspec/constraints/constraints.txt +++ b/pyagentspec/constraints/constraints.txt @@ -5,6 +5,7 @@ httpx==0.28.1 # CrewAI adapter crewai==1.6.1 +mcp==1.22.0 # AutoGen adapter autogen-core==0.7.4 diff --git a/pyagentspec/requirements-dev.txt b/pyagentspec/requirements-dev.txt index f1ee3eb8..cb8fe95b 100644 --- a/pyagentspec/requirements-dev.txt +++ b/pyagentspec/requirements-dev.txt @@ -39,3 +39,4 @@ sphinx_design==0.6.1 # For remote tool tests fastapi[standard-no-fastapi-cloud-cli] +cryptography diff --git a/pyagentspec/src/pyagentspec/adapters/crewai/_agentspecconverter.py b/pyagentspec/src/pyagentspec/adapters/crewai/_agentspecconverter.py index 31b3d6f7..aee2f263 100644 --- a/pyagentspec/src/pyagentspec/adapters/crewai/_agentspecconverter.py +++ b/pyagentspec/src/pyagentspec/adapters/crewai/_agentspecconverter.py @@ -25,7 +25,12 @@ from pyagentspec.adapters.crewai._types import ( CrewAIAgent, CrewAIBaseTool, + CrewAIHTTPTransport, CrewAILlm, + CrewAIMCPClient, + CrewAIMCPNativeTool, + CrewAISSETransport, + CrewAIStdioTransport, CrewAIStructuredTool, CrewAITool, ) @@ -39,6 +44,13 @@ ) from pyagentspec.llms.openaiconfig import OpenAiConfig as AgentSpecOpenAiConfig from pyagentspec.llms.vllmconfig import VllmConfig as AgentSpecVllmModel +from pyagentspec.mcp.clienttransport import ClientTransport as AgentSpecClientTransport +from pyagentspec.mcp.clienttransport import SSETransport as AgentSpecSSETransport +from pyagentspec.mcp.clienttransport import StdioTransport as AgentSpecStdioTransport +from pyagentspec.mcp.clienttransport import ( + StreamableHTTPTransport as AgentSpecStreamableHTTPTransport, +) +from pyagentspec.mcp.tools import MCPTool as AgentSpecMCPTool from pyagentspec.property import Property as AgentSpecProperty from pyagentspec.tools import ServerTool as AgentSpecServerTool from pyagentspec.tools import Tool as AgentSpecTool @@ -119,6 +131,10 @@ def convert( agentspec_component = self._agent_convert_to_agentspec( crewai_component, referenced_objects ) + elif isinstance(crewai_component, CrewAIMCPClient): + agentspec_component = self._mcp_client_convert_to_agentspec( + crewai_component, referenced_objects + ) elif isinstance(crewai_component, CrewAIBaseTool): agentspec_component = self._tool_convert_to_agentspec( crewai_component, referenced_objects @@ -175,6 +191,35 @@ def _llm_convert_to_agentspec( raise ValueError(f"Unsupported type of LLM in Agent Spec: {model_provider}") + def _mcp_client_convert_to_agentspec( + self, crewai_mcp_client: CrewAIMCPClient, referenced_objects: Dict[str, Any] + ) -> AgentSpecClientTransport: + crewai_transport = crewai_mcp_client.transport + server_name, server_url, _ = crewai_mcp_client._get_server_info() + if isinstance(crewai_transport, CrewAIStdioTransport): + return AgentSpecStdioTransport( + name=server_name, + command=crewai_transport.command, + args=crewai_transport.args, + env=crewai_transport.env, + ) + elif isinstance(crewai_transport, CrewAIHTTPTransport): + return AgentSpecStreamableHTTPTransport( + name=server_name, + url=server_url or "", + headers=crewai_transport.headers, + ) + elif isinstance(crewai_transport, CrewAISSETransport): + return AgentSpecSSETransport( + name=server_name, + url=server_url or "", + headers=crewai_transport.headers, + ) + + raise ValueError( + f"Transports of type {type(crewai_transport)} are not yet supported for translation to AgentSpec" + ) + def _tool_convert_to_agentspec( self, crewai_tool: CrewAIBaseTool, referenced_objects: Dict[str, Any] ) -> AgentSpecTool: @@ -189,17 +234,35 @@ def _tool_convert_to_agentspec( output_json_schema = _get_return_type_json_schema_from_function_reference( crewai_tool._run ) - # There seem to be no counterparts for client tools and remote tools in CrewAI at the moment - return AgentSpecServerTool( - name=crewai_tool.name, - description=crewai_tool.description, - inputs=_pydantic_model_to_properties_list(crewai_tool.args_schema), - outputs=[AgentSpecProperty(title="result", json_schema=output_json_schema)], - ) + if isinstance(crewai_tool, CrewAIMCPNativeTool): + return AgentSpecMCPTool( + name=crewai_tool.original_tool_name, + description=crewai_tool.description.split("Tool Description: ")[1], + inputs=_pydantic_model_to_properties_list(crewai_tool.args_schema), + outputs=[AgentSpecProperty(title="result", json_schema=output_json_schema)], + client_transport=cast( + AgentSpecClientTransport, + self.convert( + crewai_tool.mcp_client, + referenced_objects=referenced_objects, + ), + ), + ) + else: + # There seem to be no counterparts for client tools and remote tools in CrewAI at the moment + return AgentSpecServerTool( + name=crewai_tool.name, + description=crewai_tool.description, + inputs=_pydantic_model_to_properties_list(crewai_tool.args_schema), + outputs=[AgentSpecProperty(title="result", json_schema=output_json_schema)], + ) def _agent_convert_to_agentspec( self, crewai_agent: CrewAIAgent, referenced_objects: Dict[str, Any] ) -> AgentSpecAgent: + tools = crewai_agent.tools or [] + if crewai_agent.mcps: + tools += crewai_agent.get_mcp_tools(crewai_agent.mcps) return AgentSpecAgent( id=str(crewai_agent.id), name=crewai_agent.role, @@ -214,6 +277,6 @@ def _agent_convert_to_agentspec( ), tools=[ cast(AgentSpecTool, self.convert(tool, referenced_objects=referenced_objects)) - for tool in (crewai_agent.tools or []) + for tool in tools ], ) diff --git a/pyagentspec/src/pyagentspec/adapters/crewai/_crewaiconverter.py b/pyagentspec/src/pyagentspec/adapters/crewai/_crewaiconverter.py index 9408ec29..ee11a499 100644 --- a/pyagentspec/src/pyagentspec/adapters/crewai/_crewaiconverter.py +++ b/pyagentspec/src/pyagentspec/adapters/crewai/_crewaiconverter.py @@ -13,9 +13,15 @@ from pyagentspec.adapters.crewai._types import ( CrewAIAgent, CrewAIBaseTool, + CrewAIHTTPTransport, CrewAILlm, + CrewAIMCPClient, + CrewAIMCPNativeTool, CrewAIServerToolType, + CrewAISSETransport, + CrewAIStdioTransport, CrewAITool, + CrewAITransport, ) from pyagentspec.adapters.crewai.tracing import CrewAIAgentWithTracing from pyagentspec.agent import Agent as AgentSpecAgent @@ -27,6 +33,14 @@ ) from pyagentspec.llms.openaiconfig import OpenAiConfig as AgentSpecOpenAiConfig from pyagentspec.llms.vllmconfig import VllmConfig as AgentSpecVllmModel +from pyagentspec.mcp.clienttransport import ClientTransport as AgentSpecClientTransport +from pyagentspec.mcp.clienttransport import SSETransport as AgentSpecSSETransport +from pyagentspec.mcp.clienttransport import StdioTransport as AgentSpecStdioTransport +from pyagentspec.mcp.clienttransport import ( + StreamableHTTPTransport as AgentSpecStreamableHTTPTransport, +) +from pyagentspec.mcp.tools import MCPTool as AgentSpecMCPTool +from pyagentspec.mcp.tools import MCPToolBox as AgentSpecMCPToolBox from pyagentspec.property import Property as AgentSpecProperty from pyagentspec.property import _empty_default as _agentspec_empty_default from pyagentspec.tools import Tool as AgentSpecTool @@ -119,6 +133,10 @@ def convert( crewai_component = self._tool_convert_to_crewai( agentspec_component, tool_registry, converted_components ) + elif isinstance(agentspec_component, AgentSpecClientTransport): + crewai_component = self._client_transport_convert_to_crewai( + agentspec_component, tool_registry, converted_components + ) elif isinstance(agentspec_component, AgentSpecComponent): raise NotImplementedError( f"The AgentSpec Component type '{agentspec_component.__class__.__name__}' is not yet supported " @@ -243,6 +261,14 @@ def client_tool(**kwargs: Any) -> Any: ) elif isinstance(agentspec_tool, AgentSpecRemoteTool): return self._remote_tool_convert_to_crewai(agentspec_tool) + elif isinstance(agentspec_tool, AgentSpecMCPTool): + return self._mcp_tool_convert_to_crewai( + agentspec_tool, tool_registry, converted_components + ) + elif isinstance(agentspec_tool, AgentSpecMCPToolBox): + raise NotImplementedError( + "Conversion of AgentSpec MCPToolBox objects is not yet implemented" + ) raise ValueError( f"Tools of type {type(agentspec_tool)} are not yet supported for translation to CrewAI" ) @@ -277,6 +303,57 @@ def _remote_tool(**kwargs: Any) -> Any: func=_remote_tool, ) + def _mcp_tool_convert_to_crewai( + self, + mcp_tool: AgentSpecMCPTool, + tool_registry: Dict[str, CrewAIServerToolType], + converted_components: Optional[Dict[str, Any]] = None, + ) -> CrewAIMCPNativeTool: + return CrewAIMCPNativeTool( + mcp_client=self.convert(mcp_tool.client_transport, tool_registry, converted_components), + tool_name=mcp_tool.name, + tool_schema={ + "description": mcp_tool.description or "", + "args_schema": _create_pydantic_model_from_properties( + mcp_tool.name.title() + "InputSchema", mcp_tool.inputs or [] + ), + }, + server_name=mcp_tool.client_transport.name, + ) + + def _client_transport_convert_to_crewai( + self, + agentspec_transport: AgentSpecClientTransport, + tool_registry: Dict[str, CrewAIServerToolType], + converted_components: Optional[Dict[str, Any]] = None, + ) -> CrewAIMCPClient: + transport: Optional[CrewAITransport] = None + if isinstance(agentspec_transport, AgentSpecStdioTransport): + transport = CrewAIStdioTransport( + command=agentspec_transport.command, + args=agentspec_transport.args, + env=agentspec_transport.env, + ) + elif isinstance(agentspec_transport, AgentSpecSSETransport): + transport = CrewAISSETransport( + url=agentspec_transport.url, + headers=agentspec_transport.headers, + ) + elif isinstance(agentspec_transport, AgentSpecStreamableHTTPTransport): + transport = CrewAIHTTPTransport( + url=agentspec_transport.url, + headers=agentspec_transport.headers, + streamable=True, + ) + else: + raise ValueError( + f"Transports of type {type(agentspec_transport)} are not yet supported for translation to CrewAI" + ) + return CrewAIMCPClient( + transport=transport, + cache_tools_list=True, + ) + def _agent_convert_to_crewai( self, agentspec_agent: AgentSpecAgent, diff --git a/pyagentspec/src/pyagentspec/adapters/crewai/_types.py b/pyagentspec/src/pyagentspec/adapters/crewai/_types.py index 38202371..f172aa0c 100644 --- a/pyagentspec/src/pyagentspec/adapters/crewai/_types.py +++ b/pyagentspec/src/pyagentspec/adapters/crewai/_types.py @@ -40,8 +40,14 @@ from crewai.events.types.tool_usage_events import ( ToolUsageStartedEvent as CrewAIToolUsageStartedEvent, ) + from crewai.mcp.client import MCPClient as CrewAIMCPClient + from crewai.mcp.transports import BaseTransport as CrewAITransport + from crewai.mcp.transports import HTTPTransport as CrewAIHTTPTransport + from crewai.mcp.transports import SSETransport as CrewAISSETransport + from crewai.mcp.transports import StdioTransport as CrewAIStdioTransport from crewai.tools import BaseTool as CrewAIBaseTool from crewai.tools.base_tool import Tool as CrewAITool + from crewai.tools.mcp_native_tool import MCPNativeTool as CrewAIMCPNativeTool from crewai.tools.structured_tool import CrewStructuredTool as CrewAIStructuredTool else: crewai = LazyLoader("crewai") @@ -49,8 +55,14 @@ CrewAILlm = crewai.LLM CrewAIAgent = crewai.Agent CrewAIFlow = crewai.Flow + CrewAIMCPClient = crewai.mcp.client.MCPClient + CrewAITransport = crewai.mcp.transports.BaseTransport + CrewAIHTTPTransport = crewai.mcp.transports.HTTPTransport + CrewAISSETransport = crewai.mcp.transports.SSETransport + CrewAIStdioTransport = crewai.mcp.transports.StdioTransport CrewAIBaseTool = LazyLoader("crewai.tools").BaseTool CrewAITool = LazyLoader("crewai.tools.base_tool").Tool + CrewAIMCPNativeTool = LazyLoader("crewai.tools.mcp_native_tool").MCPNativeTool CrewAIStructuredTool = LazyLoader("crewai.tools.structured_tool").CrewStructuredTool CrewAIBaseEventListener = LazyLoader("crewai.events.base_event_listener").BaseEventListener CrewAIEventsBus = LazyLoader("crewai.events.event_bus").CrewAIEventsBus @@ -88,10 +100,16 @@ "CrewAILlm", "CrewAIAgent", "CrewAIFlow", + "CrewAIMCPClient", + "CrewAITransport", + "CrewAIHTTPTransport", + "CrewAISSETransport", + "CrewAIStdioTransport", "CrewAIBaseTool", "CrewAITool", "CrewAIStructuredTool", "CrewAIComponent", + "CrewAIMCPNativeTool", "CrewAIServerToolType", "CrewAIBaseEvent", "CrewAIBaseEventListener", diff --git a/pyagentspec/tests/adapters/conftest.py b/pyagentspec/tests/adapters/conftest.py index fd49fdbc..ebb50dc8 100644 --- a/pyagentspec/tests/adapters/conftest.py +++ b/pyagentspec/tests/adapters/conftest.py @@ -5,13 +5,26 @@ # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. import os +import ssl from pathlib import Path from typing import Any from unittest.mock import patch import pytest -from .utils import get_available_port, start_uvicorn_server, terminate_process_tree +from .encryption import ( + create_client_key_and_csr, + create_root_ca, + create_server_key_and_csr, + issue_client_cert, + issue_server_cert, +) +from .utils import ( + get_available_port, + register_mcp_server_fixture, + start_uvicorn_server, + terminate_process_tree, +) SKIP_LLM_TESTS_ENV_VAR = "SKIP_LLM_TESTS" @@ -172,3 +185,120 @@ def quickstart_agent_json() -> str: ) return AgentSpecSerializer().to_json(agent) + + +@pytest.fixture(scope="session") +def root_ca(session_tmp_path): + return create_root_ca(common_name="TestRootCA", days=3650, tmpdir=session_tmp_path) + + +@pytest.fixture(scope="session") +def ca_key(root_ca): + return root_ca[0] + + +@pytest.fixture(scope="session") +def ca_cert(root_ca): + return root_ca[1] + + +@pytest.fixture(scope="session") +def ca_cert_path(root_ca) -> str: + return root_ca[2] + + +@pytest.fixture(scope="session") +def server_key_and_csr(session_tmp_path): + return create_server_key_and_csr(cn="localhost", tmpdir=session_tmp_path) + + +@pytest.fixture(scope="session") +def server_key_path(server_key_and_csr): + return server_key_and_csr[2] + + +@pytest.fixture(scope="session") +def server_csr(server_key_and_csr): + return server_key_and_csr[1] + + +@pytest.fixture(scope="session") +def server_cert_path(ca_key, ca_cert, server_csr, session_tmp_path) -> str: + return issue_server_cert(ca_key, ca_cert, server_csr, days=365, tmpdir=session_tmp_path)[1] + + +@pytest.fixture(scope="session") +def client_key_and_csr(session_tmp_path): + return create_client_key_and_csr(cn="mtls-client", tmpdir=session_tmp_path) + + +@pytest.fixture(scope="session") +def client_key_path(client_key_and_csr): + return client_key_and_csr[2] + + +@pytest.fixture(scope="session") +def client_csr(client_key_and_csr): + return client_key_and_csr[1] + + +@pytest.fixture(scope="session") +def client_cert_path(ca_key, ca_cert, client_csr, session_tmp_path) -> str: + return issue_client_cert(ca_key, ca_cert, client_csr, days=365, tmpdir=session_tmp_path)[1] + + +_MCP_SERVER_FIXTURE_DEPS = ( + "server_key_path", + "server_cert_path", + "ca_cert_path", + "client_key_path", + "client_cert_path", +) + +sse_mcp_server_http = register_mcp_server_fixture( + name="sse_mcp_server_http", + url_suffix="sse", + start_kwargs=dict( + host="localhost", + port=get_available_port(), + mode="sse", + ssl_cert_reqs=0, + ), + deps=(), +) + +streamablehttp_mcp_server_http = register_mcp_server_fixture( + name="streamablehttp_mcp_server_http", + url_suffix="mcp", + start_kwargs=dict( + host="localhost", + port=get_available_port(), + mode="streamable-http", + ssl_cert_reqs=int(ssl.CERT_NONE), + ), + deps=(), +) + +sse_mcp_server_https = register_mcp_server_fixture( + name="sse_mcp_server_https", + url_suffix="sse", + start_kwargs=dict( + host="localhost", + port=get_available_port(), + mode="sse", + ssl_cert_reqs=int(ssl.CERT_NONE), + ), + deps=_MCP_SERVER_FIXTURE_DEPS, +) + +streamablehttp_mcp_server_https = register_mcp_server_fixture( + name="streamablehttp_mcp_server_https", + url_suffix="mcp", + start_kwargs=dict( + host="localhost", + port=get_available_port(), + mode="streamable-http", + ssl_cert_reqs=int(ssl.CERT_NONE), + ), + deps=_MCP_SERVER_FIXTURE_DEPS, +) diff --git a/pyagentspec/tests/adapters/crewai/test_mcp_tools.py b/pyagentspec/tests/adapters/crewai/test_mcp_tools.py new file mode 100644 index 00000000..a0ab4d26 --- /dev/null +++ b/pyagentspec/tests/adapters/crewai/test_mcp_tools.py @@ -0,0 +1,162 @@ +# Copyright © 2025 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from pathlib import Path + +import pytest + +from ..conftest import llama70bv33_api_url + +CONFIGS = Path(__file__).parent / "configs" + + +@pytest.fixture +def sse_client_transport(sse_mcp_server_http): + from pyagentspec.mcp import SSETransport + + return SSETransport(name="SSE HTTP", url=sse_mcp_server_http) + + +@pytest.fixture +def sse_client_transport_https(sse_mcp_server_https): + from pyagentspec.mcp import SSETransport + + return SSETransport(name="SSE HTTPS", url=sse_mcp_server_https) + + +@pytest.fixture +def streamablehttp_client_transport(streamablehttp_mcp_server_http): + from pyagentspec.mcp import StreamableHTTPTransport + + return StreamableHTTPTransport(name="Streamable HTTP", url=streamablehttp_mcp_server_http) + + +@pytest.fixture +def streamablehttp_client_transport_https(streamablehttp_mcp_server_https): + from pyagentspec.mcp import StreamableHTTPTransport + + return StreamableHTTPTransport(name="Streamable HTTPS", url=streamablehttp_mcp_server_https) + + +def convert_and_run_agentspec_agent_with_mcp_tools(client_transport): + from crewai import Crew, Task + + from pyagentspec.adapters.crewai import AgentSpecLoader + from pyagentspec.agent import Agent + from pyagentspec.llms import VllmConfig + from pyagentspec.mcp import MCPTool + from pyagentspec.property import IntegerProperty + + get_user_session_tool = MCPTool( + client_transport=client_transport, + description="Return session details for the current user", + name="get_user_session", + ) + get_payslips_tool = MCPTool( + client_transport=client_transport, + description="Return payslip details for a given PersonId", + name="get_payslips", + inputs=[ + IntegerProperty( + title="PersonId", + description="Specifies ID of the person whose invoices will be fetched", + ) + ], + ) + llm_config = VllmConfig( + name="Llama-3.3-70B-Instruct", + model_id="/storage/models/Llama-3.3-70B-Instruct", + url=llama70bv33_api_url, + ) + agent = Agent( + name="Agent using MCP", + llm_config=llm_config, + system_prompt="Use tools at your disposal to solve the specific task", + tools=[get_user_session_tool, get_payslips_tool], + ) + + crewai_agent = AgentSpecLoader().load_component(agent) + task = Task( + description="Find the date of the last invoice of the current user. Hint: Use the tools are your disposal. First, get the current user using `get_user_session` tool. Next, get all invoices of the current user using `get_payslips` tool. Finally, find the invoice with the highest / most recent date using the `PaymentDate` attribute.", + expected_output="The output must solely contain the date of the last invoice in format YYYY-MM-DD. Example output: 2025-12-25", + agent=crewai_agent, + ) + crew = Crew(agents=[crewai_agent], tasks=[task]) + + assert len(crewai_agent.tools) == 2 + assert any("get_user_session" in tool.description for tool in crewai_agent.tools) + assert any("get_payslips" in tool.description for tool in crewai_agent.tools) + + # running the CrewAI agent is currently disabled due to flaky tests + # result = crew.kickoff() + # assert "2024-05-15" in result.raw + + +@pytest.mark.parametrize( + "client_transport_name", + [ + "sse_client_transport", + "sse_client_transport_https", + "streamablehttp_client_transport", + "streamablehttp_client_transport_https", + ], +) +def test_agentspec_agent_with_mcp_tools_conversion_to_crewai_agent(client_transport_name, request): + client_transport = request.getfixturevalue(client_transport_name) + convert_and_run_agentspec_agent_with_mcp_tools(client_transport) + + +@pytest.fixture +def mcp_server_sse(sse_mcp_server_http): + from crewai.mcp import MCPServerSSE + + return MCPServerSSE(url=sse_mcp_server_http, streamable=False) + + +@pytest.fixture +def mcp_server_http(streamablehttp_mcp_server_http): + from crewai.mcp import MCPServerHTTP + + return MCPServerHTTP(url=streamablehttp_mcp_server_http) + + +def convert_crewai_agent_with_mcp_tools(mcp_server): + from crewai import LLM, Agent, Task + + from pyagentspec.adapters.crewai import AgentSpecExporter + + llm = LLM( + model="hosted_vllm//storage/models/Llama-3.3-70B-Instruct", + api_base=llama70bv33_api_url, + ) + agent = Agent( + llm=llm, + role="Financial Analyst", + goal="Use tools at your disposal to solve the specific task", + backstory="Expert finance analyst with advanced tool access", + mcps=[mcp_server], + ) + task = Task( + description="Find the date of the last invoice of the current user. Hint: Use the tools are your disposal. First, get the current user using `get_user_session` tool. Next, get all invoices of the current user using `get_payslips` tool. Finally, find the invoice with the highest / most recent date using the `PaymentDate` attribute.", + expected_output="The output must solely contain the date of the last invoice in format YYYY-MM-DD. Example output: 2025-12-25", + agent=agent, + ) + agentspec_agent = AgentSpecExporter().to_component(agent) + assert len(agentspec_agent.tools) == 2 + assert any(tool.name == "get_user_session" for tool in agentspec_agent.tools) + assert any(tool.name == "get_payslips" for tool in agentspec_agent.tools) + + +@pytest.mark.parametrize( + "mcp_server", + [ + "mcp_server_sse", + "mcp_server_http", + ], +) +def test_crewai_agent_with_mcp_tools_conversion_to_agentspec_agent(mcp_server, request): + mcp_server = request.getfixturevalue(mcp_server) + convert_crewai_agent_with_mcp_tools(mcp_server) diff --git a/pyagentspec/tests/adapters/encryption.py b/pyagentspec/tests/adapters/encryption.py new file mode 100644 index 00000000..eb979185 --- /dev/null +++ b/pyagentspec/tests/adapters/encryption.py @@ -0,0 +1,280 @@ +# Copyright © 2025 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import datetime as dt +import ipaddress +import logging +import os + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import ( + BestAvailableEncryption, + Encoding, + NoEncryption, + PrivateFormat, +) +from cryptography.x509 import Name, NameAttribute +from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID + +logger = logging.getLogger(__name__) + + +def now_utc(): + return dt.datetime.now(dt.timezone.utc) + + +def write_pem(path: str, data: bytes): + with open(path, "wb") as f: + f.write(data) + logger.info(f"Wrote {path}") + + +def save_private_key(path: str, key: rsa.RSAPrivateKey, password: bytes | None = None): + if password: + enc = BestAvailableEncryption(password) + else: + enc = NoEncryption() + pem = key.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=enc, + ) + write_pem(path, pem) + + +def save_cert(path: str, cert: x509.Certificate): + write_pem(path, cert.public_bytes(Encoding.PEM)) + + +def save_csr(path: str, csr: x509.CertificateSigningRequest): + write_pem(path, csr.public_bytes(Encoding.PEM)) + + +def create_root_ca(common_name: str = "TestRootCA", days=3650, tmpdir: str = ""): + ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + save_private_key(os.path.join(str(tmpdir), "rootCA.key"), ca_key) + + # Subject/Issuer (self-signed) + name = Name([NameAttribute(NameOID.COMMON_NAME, common_name)]) + builder = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(name) + .public_key(ca_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now_utc() - dt.timedelta(minutes=1)) + .not_valid_after(now_utc() + dt.timedelta(days=days)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=0), + critical=True, + ) + .add_extension( + x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension( + x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()), + critical=False, + ) + ) + + ca_cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA256()) + ca_cert_path = os.path.join(str(tmpdir), "rootCA.pem") + save_cert(ca_cert_path, ca_cert) + return ca_key, ca_cert, ca_cert_path + + +def create_server_key_and_csr(cn: str = "localhost", tmpdir: str = ""): + server_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + server_key_path = os.path.join(str(tmpdir), "server.key") + save_private_key(server_key_path, server_key) + + csr = ( + x509.CertificateSigningRequestBuilder() + .subject_name(Name([NameAttribute(NameOID.COMMON_NAME, cn)])) + .sign(server_key, hashes.SHA256()) + ) + save_csr(os.path.join(str(tmpdir), "server.csr"), csr) + return server_key, csr, server_key_path + + +def issue_server_cert( + ca_key: rsa.RSAPrivateKey, + ca_cert: x509.Certificate, + csr: x509.CertificateSigningRequest, + days: int = 365, + tmpdir: str = "", +): + san = x509.SubjectAlternativeName( + [ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ] + ) + + builder = ( + x509.CertificateBuilder() + .subject_name(csr.subject) + .issuer_name(ca_cert.subject) + .public_key(csr.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now_utc() - dt.timedelta(minutes=1)) + .not_valid_after(now_utc() + dt.timedelta(days=days)) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=True, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), + critical=False, + ) + .add_extension(san, critical=False) + .add_extension( + x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ) + .add_extension( + x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()), + critical=False, + ) + ) + + cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA256()) + server_cert_path = os.path.join(str(tmpdir), "server.crt") + save_cert(server_cert_path, cert) + return cert, server_cert_path + + +def create_client_key_and_csr(cn: str = "mtls-client", tmpdir: str = ""): + client_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + client_key_path = os.path.join(str(tmpdir), "client.key") + save_private_key(client_key_path, client_key) + + csr = ( + x509.CertificateSigningRequestBuilder() + .subject_name(Name([NameAttribute(NameOID.COMMON_NAME, cn)])) + .sign(client_key, hashes.SHA256()) + ) + save_csr(os.path.join(str(tmpdir), "client.csr"), csr) + return client_key, csr, client_key_path + + +def issue_client_cert( + ca_key: rsa.RSAPrivateKey, + ca_cert: x509.Certificate, + csr: x509.CertificateSigningRequest, + days: int = 365, + tmpdir: str = "", +): + builder = ( + x509.CertificateBuilder() + .subject_name(csr.subject) + .issuer_name(ca_cert.subject) + .public_key(csr.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now_utc() - dt.timedelta(minutes=1)) + .not_valid_after(now_utc() + dt.timedelta(days=days)) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), + critical=False, + ) + .add_extension( + x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ) + .add_extension( + x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()), + critical=False, + ) + ) + + cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA256()) + client_cert_path = os.path.join(str(tmpdir), "client.crt") + save_cert(client_cert_path, cert) + return cert, client_cert_path + + +def print_quick_checks(server_cert: x509.Certificate, client_cert: x509.Certificate): + # Server SANs + try: + san = server_cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + entries = [] + for d in san: + if isinstance(d, x509.DNSName): + entries.append(f"DNS:{d.value}") + elif isinstance(d, x509.IPAddress): + entries.append(f"IP:{d.value.exploded}") + logger.info("Server SAN ->", ", ".join(entries)) + except x509.ExtensionNotFound: + logger.info("Server SAN -> (none)") + + # Client EKU + try: + eku = client_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value + eku_names = [] + for oid in eku: + if oid == ExtendedKeyUsageOID.CLIENT_AUTH: + eku_names.append("TLS Web Client Authentication") + elif oid == ExtendedKeyUsageOID.SERVER_AUTH: + eku_names.append("TLS Web Server Authentication") + else: + eku_names.append(oid.dotted_string) + logger.info("Client EKU ->", ", ".join(eku_names)) + except x509.ExtensionNotFound: + logger.info("Client EKU -> (none)") + + +if __name__ == "__main__": + # 1. Root CA + ca_key, ca_cert, ca_cert_path = create_root_ca(common_name="TestRootCA", days=3650) + # 2. Server key + CSR + server_key, server_csr, server_key_path = create_server_key_and_csr(cn="localhost") + # 3. Issue server cert (SAN: localhost + 127.0.0.1) + server_cert, server_cert_path = issue_server_cert(ca_key, ca_cert, server_csr, days=365) + # 4. Client key + CSR + client_key, client_csr, client_key_path = create_client_key_and_csr(cn="mtls-client") + # 5. Issue client cert (EKU: clientAuth) + client_cert, client_cert_path = issue_client_cert(ca_key, ca_cert, client_csr, days=365) + # 6. Quick checks + print_quick_checks(server_cert, client_cert) diff --git a/pyagentspec/tests/adapters/start_mcp_server.py b/pyagentspec/tests/adapters/start_mcp_server.py new file mode 100644 index 00000000..14ffb087 --- /dev/null +++ b/pyagentspec/tests/adapters/start_mcp_server.py @@ -0,0 +1,164 @@ +# Copyright © 2025 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. +import argparse +from contextvars import ContextVar +from os import PathLike +from typing import Literal, Optional + +from mcp.server.fastmcp import FastMCP as BaseFastMCP +from starlette.applications import Starlette +from typing_extensions import TypedDict + +UvicornExtraConfig = TypedDict( + "UvicornExtraConfig", + { + "ssl_keyfile": str | PathLike[str] | None, + "ssl_certfile": str | PathLike[str] | None, + "ssl_ca_certs": str | None, + "ssl_cert_reqs": int, + }, + total=False, +) + +_EXTRA_CONFIG: ContextVar[Optional[UvicornExtraConfig]] = ContextVar("_EXTRA_CONFIG", default=None) + +PAYSLIPS = [ + { + "Amount": 7612, + "Currency": "USD", + "PeriodStartDate": "2025/05/15", + "PeriodEndDate": "2025/06/15", + "PaymentDate": "", + "DocumentId": 2, + "PersonId": 2, + }, + { + "Amount": 5000, + "Currency": "CHF", + "PeriodStartDate": "2024/05/01", + "PeriodEndDate": "2024/06/01", + "PaymentDate": "2024/05/15", + "DocumentId": 1, + "PersonId": 1, + }, + { + "Amount": 10000, + "Currency": "EUR", + "PeriodStartDate": "2025/06/15", + "PeriodEndDate": "2025/10/15", + "PaymentDate": "", + "DocumentsId": 3, + "PersonId": 3, + }, +] + + +class FastMCP(BaseFastMCP): + async def _start_server(self, starlette_app: Starlette) -> None: + import uvicorn + + extra_config = _EXTRA_CONFIG.get() + + config = uvicorn.Config( + starlette_app, + host=self.settings.host, + port=self.settings.port, + log_level=self.settings.log_level.lower(), + **extra_config, + ) + server = uvicorn.Server(config) + await server.serve() + + async def run_sse_async(self, mount_path: str | None = None) -> None: + """Run the server using SSE transport.""" + starlette_app = self.sse_app(mount_path) + await self._start_server(starlette_app) + + async def run_streamable_http_async(self) -> None: + """Run the server using StreamableHTTP transport.""" + starlette_app = self.streamable_http_app() + await self._start_server(starlette_app) + + +def create_server(host: str, port: int): + """Create and configure the MCP server""" + server = FastMCP( + name="Example MCP Server", + instructions="A MCP Server.", + host=host, + port=port, + ) + + @server.tool(description="Return session details for the current user") + def get_user_session(): + return { + "PersonId": "1", + "Username": "Bob.b", + "DisplayName": "Bob B", + } + + @server.tool(description="Return payslip details for a given PersonId") + def get_payslips(PersonId: int): + return [payslip for payslip in PAYSLIPS if payslip["PersonId"] == int(PersonId)] + + return server + + +def main( + host: str, + port: int, + mode: Literal["sse", "streamable-http"], + ssl_keyfile: str | None, + ssl_certfile: str | None, + ssl_ca_certs: str | None, + ssl_cert_reqs: int, +): + _EXTRA_CONFIG.set( + dict( + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_ca_certs=ssl_ca_certs, + ssl_cert_reqs=ssl_cert_reqs, + ) + ) + server = create_server(host=host, port=port) + server.run(transport=mode) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process host, port, and mode.") + + parser.add_argument( + "--host", type=str, help='The host address (e.g., "localhost" or "127.0.0.1")' + ) + parser.add_argument("--port", type=int, help="The port number (e.g., 8080)") + parser.add_argument( + "--mode", type=str, choices=["sse", "streamable-http"], help="The mode for the application" + ) + parser.add_argument( + "--ssl_keyfile", type=str, help="Path to the server private key file (PEM format)." + ) + parser.add_argument( + "--ssl_certfile", type=str, help="Path to the server certificate chain file (PEM format)." + ) + parser.add_argument( + "--ssl_ca_certs", type=str, help="Path to the trusted CA certificate file (PEM format)." + ) + parser.add_argument( + "--ssl_cert_reqs", type=int, help="Server certificate verify mode (0=None or 2=Required)." + ) + + args = parser.parse_args() + + main( + host=args.host, + port=args.port, + mode=args.mode, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + ) diff --git a/pyagentspec/tests/adapters/utils.py b/pyagentspec/tests/adapters/utils.py index 032276fc..bc0070a1 100644 --- a/pyagentspec/tests/adapters/utils.py +++ b/pyagentspec/tests/adapters/utils.py @@ -18,6 +18,7 @@ from typing import Optional import httpx +import pytest class LogTee: @@ -96,9 +97,19 @@ def terminate_process_tree(process: subprocess.Popen, timeout: float = 5.0) -> N process.stdout.close() -def start_uvicorn_server( - server_path: str | pathlib.Path, host: str, port: int, ready_timeout_s: float = 20.0 -) -> tuple[subprocess.Popen, str]: +def _start_server( + server_path: str | pathlib.Path, + host: str, + port: int, + mode: Optional[str] = None, + server_key_path: Optional[str] = None, + server_cert_path: Optional[str] = None, + ca_cert_path: Optional[str] = None, + ssl_cert_reqs: Optional[int] = None, # ssl.CERT_NONE + client_key_path: Optional[str] = None, + client_cert_path: Optional[str] = None, + ready_timeout_s: float = 10.0, +) -> tuple[subprocess.Popen, str, LogTee]: process_args = [ "python", "-u", # unbuffered output @@ -108,8 +119,34 @@ def start_uvicorn_server( "--port", str(port), ] - - url = f"http://{host}:{port}" + if mode: + process_args.extend( + [ + "--mode", + mode, + ] + ) + if ssl_cert_reqs is not None: + process_args.extend( + [ + "--ssl_cert_reqs", + str(ssl_cert_reqs), + ] + ) + if server_key_path and server_cert_path and ca_cert_path: # using https + process_args.extend( + [ + "--ssl_keyfile", + server_key_path, + "--ssl_certfile", + server_cert_path, + "--ssl_ca_certs", + ca_cert_path, + ] + ) + url = f"https://{host}:{port}" + else: + url = f"http://{host}:{port}" env = os.environ.copy() env.setdefault("PYTHONUNBUFFERED", "1") @@ -140,9 +177,11 @@ def start_uvicorn_server( if rc is not None: raise RuntimeError(f"Uvicorn exited early with code {rc}.\nLogs:\n{tee.dump()}") - if check_server_is_up(url, timeout_s=0.5): + if check_server_is_up( + url, client_key_path, client_cert_path, ca_cert_path, timeout_s=0.5 + ): print("Server is up.", flush=True) - return process, url + return process, url, tee time.sleep(0.2) # Timed out @@ -157,6 +196,43 @@ def start_uvicorn_server( tee.stop() +def start_uvicorn_server( + server_path: str | pathlib.Path, host: str, port: int, ready_timeout_s: float = 20.0 +) -> tuple[subprocess.Popen, str]: + process, url, _ = _start_server(server_path, host, port, ready_timeout_s=ready_timeout_s) + return process, url + + +def start_mcp_server( + host: str, + port: int, + mode: Optional[str] = None, + server_key_path: Optional[str] = None, + server_cert_path: Optional[str] = None, + ca_cert_path: Optional[str] = None, + ssl_cert_reqs: int = 0, # ssl.CERT_NONE + client_key_path: Optional[str] = None, + client_cert_path: Optional[str] = None, + ready_timeout_s: float = 10.0, +) -> tuple[subprocess.Popen, str, LogTee]: + start_mcp_server_file_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "start_mcp_server.py" + ) + return _start_server( + start_mcp_server_file_path, + host, + port, + mode=mode, + server_key_path=server_key_path, + server_cert_path=server_cert_path, + ca_cert_path=ca_cert_path, + ssl_cert_reqs=ssl_cert_reqs, + client_key_path=client_key_path, + client_cert_path=client_cert_path, + ready_timeout_s=ready_timeout_s, + ) + + def check_server_is_up( url: str, client_key_path: Optional[str] = None, @@ -196,3 +272,26 @@ def get_available_port(): s.bind(("", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] + + +def register_mcp_server_fixture( + name: str, url_suffix: str, start_kwargs: dict, deps: tuple[str, ...] = () +): + def _fixture(request): + # Resolve any dependent fixtures by name and merge into kwargs + resolved = {name: request.getfixturevalue(name) for name in deps} + kwargs = {**start_kwargs, **resolved} + + process, url, tee = start_mcp_server(**kwargs) + try: + yield f"{url}/{url_suffix.strip('/')}" + finally: + # ^ The MCP sessions are closed before the servers are + # stopped to avoid sse_reader issues (solves the error: + # `peer closed connection without sending complete message body + # (incomplete chunked read)`) + terminate_process_tree(process, timeout=5.0) + tee.stop() # this needs to be stopped after the + # MCP server so that the stdout is closed. + + return pytest.fixture(scope="session", name=name)(_fixture) diff --git a/pyagentspec/tests/conftest.py b/pyagentspec/tests/conftest.py index 7f6c4e4c..6e6ff382 100644 --- a/pyagentspec/tests/conftest.py +++ b/pyagentspec/tests/conftest.py @@ -5,7 +5,9 @@ # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. import os +import shutil import stat +import tempfile from contextlib import contextmanager from distutils.sysconfig import get_python_lib from pathlib import Path @@ -32,6 +34,14 @@ def should_skip_llm_test() -> bool: ] +@pytest.fixture(scope="session") +def session_tmp_path(): + """Session-scoped temp path""" + dirpath = tempfile.mkdtemp() + yield dirpath + shutil.rmtree(dirpath) + + @pytest.fixture(scope="session", autouse=True) def skip_test_fixture(): """ @@ -132,15 +142,16 @@ def check_file_permissions(path: Any) -> None: assert not (st_mode & (stat.S_IRWXG | stat.S_IRWXO)) -def get_directory_allowlist_write(tmp_path: str) -> List[Union[str, Path]]: +def get_directory_allowlist_write(tmp_path: str, session_tmp_path: str) -> List[Union[str, Path]]: return [ get_python_lib(), # Allow packages to r/w their pycache tmp_path, + session_tmp_path, "/dev/null", ] -def get_directory_allowlist_read(tmp_path: str) -> List[Union[str, Path]]: +def get_directory_allowlist_read(tmp_path: str, session_tmp_path: str) -> List[Union[str, Path]]: try: # Crew AI sometimes attempts to read in some folders, we need to take that into account from crewai.cli.shared.token_manager import TokenManager @@ -152,7 +163,7 @@ def get_directory_allowlist_read(tmp_path: str) -> List[Union[str, Path]]: except ImportError: crewai_read_dirs = [] return ( - get_directory_allowlist_write(tmp_path) + get_directory_allowlist_write(tmp_path, session_tmp_path) + [ CONFIGS_DIR, # Docs path @@ -168,27 +179,33 @@ def get_directory_allowlist_read(tmp_path: str) -> List[Union[str, Path]]: ) -def check_allowed_filewrite(path: Union[str, Path], tmp_path: str, mode: str) -> None: +def check_allowed_filewrite( + path: Union[str, Path], tmp_path: str, session_tmp_path: str, mode: str +) -> None: path = os.path.abspath(path) if mode == "r" or mode == "rb": assert any( [ Path(dir) in Path(path).parents or Path(dir) == Path(path) - for dir in get_directory_allowlist_read(tmp_path=tmp_path) + for dir in get_directory_allowlist_read( + tmp_path=tmp_path, session_tmp_path=session_tmp_path + ) ] ), f"Reading outside of allowed directories! {path}" else: assert any( [ Path(dir) in Path(path).parents or Path(dir) == Path(path) - for dir in get_directory_allowlist_write(tmp_path=tmp_path) + for dir in get_directory_allowlist_write( + tmp_path=tmp_path, session_tmp_path=session_tmp_path + ) ] ), f"Writing outside of allowed directories! {path}" @contextmanager def limit_filewrites( - monkeypatch: Any, tmp_path: str, allowed_access_enabled: bool = True + monkeypatch: Any, tmp_path: str, session_tmp_path: str, allowed_access_enabled: bool = True ) -> Iterator[bool]: import builtins @@ -208,7 +225,9 @@ def patched_open(name: Any, *args: Any, **kwargs: Any) -> Any: # Mode can be either in *args or **kwargs, if it's not, the default is "r" mode = "w" if "w" in args else "r" mode = kwargs.get("mode", mode) - check_allowed_filewrite(name, tmp_path=tmp_path, mode=mode) + check_allowed_filewrite( + name, tmp_path=tmp_path, session_tmp_path=session_tmp_path, mode=mode + ) return _open(name, *args, **kwargs) with monkeypatch.context() as m: @@ -217,24 +236,34 @@ def patched_open(name: Any, *args: Any, **kwargs: Any) -> Any: @pytest.fixture(scope="function", autouse=True) -def guard_filewrites(monkeypatch: Any, tmp_path: str) -> Iterator[bool]: +def guard_filewrites(monkeypatch: Any, tmp_path: str, session_tmp_path: str) -> Iterator[bool]: """Fixture which raises an exception if the filesystem is accessed outside of a limited set of allowed directories """ - with limit_filewrites(monkeypatch, tmp_path=tmp_path, allowed_access_enabled=True) as x: + with limit_filewrites( + monkeypatch, + tmp_path=tmp_path, + session_tmp_path=session_tmp_path, + allowed_access_enabled=True, + ) as x: yield x @pytest.fixture(scope="function") -def guard_all_filewrites(monkeypatch: Any, tmp_path: str) -> Iterator[bool]: +def guard_all_filewrites(monkeypatch: Any, tmp_path: str, session_tmp_path: str) -> Iterator[bool]: """Fixture which raises an exception if the filesystem is accessed.""" - with limit_filewrites(monkeypatch, tmp_path=tmp_path, allowed_access_enabled=False) as x: + with limit_filewrites( + monkeypatch, + tmp_path=tmp_path, + session_tmp_path=session_tmp_path, + allowed_access_enabled=False, + ) as x: yield x @contextmanager def suppress_network( - monkeypatch: Any, tmp_path: str, allowed_access_enabled: bool = True + monkeypatch: Any, tmp_path: str, session_tmp_path: str, allowed_access_enabled: bool = True ) -> Iterator[bool]: """ Context manager which raises an exception if network connection is requested. @@ -268,7 +297,9 @@ def guard_connect(*args: Any) -> Any: assert allowed_access_enabled, "Code is accessing network when it shouldn't have" addr = args[1] if isinstance(addr, str) or addr[0] == "127.0.0.1": - check_allowed_filewrite(addr, tmp_path=tmp_path, mode="w") + check_allowed_filewrite( + addr, tmp_path=tmp_path, session_tmp_path=session_tmp_path, mode="w" + ) return orig_fn(*args) # We must raise OSError (not Exception) similar to that raised # by socket.connect to support libraries that rely on this @@ -281,7 +312,7 @@ def guard_connect(*args: Any) -> Any: @pytest.fixture(scope="function") -def guard_network(monkeypatch: Any, tmp_path: str) -> Iterator[bool]: +def guard_network(monkeypatch: Any, tmp_path: str, session_tmp_path: str) -> Iterator[bool]: """ Fixture which raises an exception if the network is accessed. It will not raise an exception for localhost, use guard_all_network_access @@ -290,16 +321,28 @@ def guard_network(monkeypatch: Any, tmp_path: str) -> Iterator[bool]: Unit tests should not touch the network so this fixture helps guard against accidental network use. """ - with suppress_network(monkeypatch, tmp_path=tmp_path, allowed_access_enabled=True) as x: + with suppress_network( + monkeypatch, + tmp_path=tmp_path, + session_tmp_path=session_tmp_path, + allowed_access_enabled=True, + ) as x: yield x @pytest.fixture(scope="function") -def guard_all_network_access(monkeypatch: Any, tmp_path: str) -> Iterator[bool]: +def guard_all_network_access( + monkeypatch: Any, tmp_path: str, session_tmp_path: str +) -> Iterator[bool]: """Fixture which raises an exception if the network is accessed. Unit tests should not touch the network so this fixture helps guard against accidental network use. """ - with suppress_network(monkeypatch, tmp_path=tmp_path, allowed_access_enabled=False) as x: + with suppress_network( + monkeypatch, + tmp_path=tmp_path, + session_tmp_path=session_tmp_path, + allowed_access_enabled=False, + ) as x: yield x