diff --git a/domaintools/api.py b/domaintools/api.py index 8a00f01..4111a80 100644 --- a/domaintools/api.py +++ b/domaintools/api.py @@ -671,7 +671,6 @@ def iris_investigate( data_updated_after=None, expiration_date=None, create_date=None, - active=None, search_hash=None, risk_score=None, younger_than_date=None, @@ -701,9 +700,6 @@ def iris_investigate( if search_hash: kwargs["search_hash"] = search_hash - if not (kwargs or domains): - raise ValueError("Need to define investigation using kwarg filters or domains") - if isinstance(domains, (list, tuple)): domains = ",".join(domains) if hasattr(data_updated_after, "strftime"): @@ -712,8 +708,6 @@ def iris_investigate( expiration_date = expiration_date.strftime("%Y-%m-%d") if hasattr(create_date, "strftime"): create_date = create_date.strftime("%Y-%m-%d") - if isinstance(active, bool): - kwargs["active"] = str(active).lower() results = self._results( "iris-investigate", diff --git a/domaintools/decorators.py b/domaintools/decorators.py index 4640a8e..3c799b1 100644 --- a/domaintools/decorators.py +++ b/domaintools/decorators.py @@ -1,39 +1,81 @@ import functools +import inspect from typing import List, Union from domaintools.docstring_patcher import DocstringPatcher +from domaintools.request_validator import RequestValidator def api_endpoint(spec_name: str, path: str, methods: Union[str, List[str]]): """ - Decorator to tag a method as an API endpoint. + Decorator to tag a method as an API endpoint AND validate inputs. Args: spec_name: The key for the spec in api_instance.specs path: The API path (e.g., "/users") methods: A single method ("get") or list of methods (["get", "post"]) - that this function handles. """ def decorator(func): func._api_spec_name = spec_name func._api_path = path - # Always store the methods as a list - if isinstance(methods, str): - func._api_methods = [methods] - else: - func._api_methods = methods + # Normalize methods to a list + normalized_methods = [methods] if isinstance(methods, str) else methods + func._api_methods = normalized_methods + + # Get the signature of the original function ONCE + sig = inspect.signature(func) @functools.wraps(func) def wrapper(self, *args, **kwargs): + + try: + bound_args = sig.bind(*args, **kwargs) + except TypeError: + # If arguments don't match signature, let the actual func raise the error + return func(*args, **kwargs) + + arguments = bound_args.arguments + + # Robustly find 'self' (it's usually the first argument in bound_args) + # We look for the first value in arguments, or try to get 'self' explicitly. + instance = arguments.pop("self", None) + if not instance and args: + instance = args[0] + + # Retrieve the Spec from the instance + # We assume 'self' has a .specs attribute (like DocstringPatcher expects) + spec = getattr(self, "specs", {}).get(spec_name) + + if spec: + # Determine which HTTP method is currently being executed. + # If the function allows dynamic methods (e.g. method="POST"), use that. + # Otherwise, default to the first method defined in the decorator. + current_method = kwargs.get("method", normalized_methods[0]) + + # Run Validation + # This will raise a ValueError and stop execution if validation fails. + try: + RequestValidator.validate( + spec=spec, + path=path, + method=current_method, + parameters=arguments, + ) + except ValueError as e: + print(f"[Validation Error] {e}") + raise e + + # Proceed with the original function call return func(*args, **kwargs) - # Copy all tags to the wrapper + # Copy tags to wrapper for the DocstringPatcher to find wrapper._api_spec_name = func._api_spec_name wrapper._api_path = func._api_path wrapper._api_methods = func._api_methods + return wrapper return decorator diff --git a/domaintools/docstring_patcher.py b/domaintools/docstring_patcher.py index 45f4f54..a798adc 100644 --- a/domaintools/docstring_patcher.py +++ b/domaintools/docstring_patcher.py @@ -7,13 +7,6 @@ class DocstringPatcher: """ Patches docstrings for methods decorated with @api_endpoint. - - Uses the 'methods' list provided by the decorator. - - Finds non-standard parameters inside the 'requestBody' object. - - Displays Query Params, Request Body, and Result Body (Responses) - for all operations. - - Unpacks and displays properties of request body schemas. - - Searches components.parameters for request body properties - that match by name. """ def patch(self, api_instance): @@ -44,6 +37,7 @@ def patch(self, api_instance): path_item = spec_to_use.get("paths", {}).get(path, {}) for http_method in http_methods_to_check: if http_method.lower() in path_item: + # Helper is called via self, but it's an instance method calling a static method internally api_doc = self._generate_api_doc_string(spec_to_use, path, http_method) all_doc_sections.append(api_doc) @@ -66,77 +60,18 @@ def method_wrapper(*args, _orig_meth=original_method, **kwargs): method_wrapper.__get__(api_instance, api_instance.__class__), ) - def _generate_api_doc_string(self, spec: dict, path: str, method: str) -> str: - """Creates the formatted API docstring section for ONE operation.""" - - details = self._get_operation_details(spec, path, method) - lines = [f"--- Operation: {method.upper()} {path} ---"] - - lines.append(f"\n Summary: {details.get('summary', 'N/A')}") - lines.append(f" Description: {details.get('description', 'N/A')}") - lines.append(f" External Doc: {details.get('external_doc', 'N/A')}") - - # 1. Always display Query Parameters - lines.append("\n Query Parameters:") - if not details["query_params"]: - lines.append(" (No query parameters)") - else: - for param in details["query_params"]: - lines.append(f"\n **{param['name']}** ({param['type']})") - lines.append(f" Required: {param['required']}") - lines.append(f" Description: {param['description']}") - - # 2. Always display Request Body - lines.append("\n Request Body:") - if not details["request_body"]: - lines.append(" (No request body)") - else: - body = details["request_body"] - lines.append(f"\n **{body['type']}**") - lines.append(f" Required: {body['required']}") - lines.append(f" Description: {body['description']}") - - if body.get("properties"): - lines.append(f" Properties:") - for prop in body["properties"]: - lines.append(f"\n **{prop['name']}** ({prop['type']})") - lines.append(f" Description: {prop['description']}") - - if body.get("parameters"): - lines.append(f" Parameters (associated with this body):") - for param in body["parameters"]: - param_in = param.get("in", "N/A") - lines.append( - f"\n **{param['name']}** ({param['type']}) [in: {param_in}]" - ) - lines.append(f" Required: {param['required']}") - lines.append(f" Description: {param['description']}") - - # 3. Always display Result Body (Responses) - lines.append("\n Result Body (Responses):") - if not details["responses"]: - lines.append(" (No responses defined in spec)") - else: - for resp in details["responses"]: - lines.append(f"\n **{resp['status_code']}**: ({resp['type']})") - lines.append(f" Description: {resp['description']}") - - return "\n".join(lines) - - def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: + @staticmethod + def get_operation_details(spec: dict, path: str, method: str) -> dict: """ - Gets all details. Includes: - - Logic to find non-standard 'parameters' in 'requestBody' - - Logic to parse requestBody schema properties - - Logic to parse responses - - **NEW**: Logic to match requestBody properties to components/parameters + Gets all details for a specific operation. + Static method: Can be used without instantiating the class. + Usage: DocstringPatcher.get_operation_details(spec, '/users', 'post') """ details = {"query_params": [], "request_body": None, "responses": []} if not spec: return details try: - # --- Get component parameters for lookup --- components = spec.get("components", {}) all_component_params = components.get("parameters", {}) @@ -154,23 +89,22 @@ def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: resolved_body_def = {} if body_def: if "$ref" in body_def: - resolved_body_def = self._resolve_ref(spec, body_def["$ref"]) + resolved_body_def = DocstringPatcher._resolve_ref(spec, body_def["$ref"]) else: resolved_body_def = body_def body_level_params = resolved_body_def.get("parameters", []) all_param_defs = path_level_params + operation_level_params - # --- End Parameter Logic --- details["summary"] = operation.get("summary") details["description"] = operation.get("description") details["external_doc"] = operation.get("externalDocs", {}).get("url", "N/A") - # --- Query Param Processing (from path/operation only) --- + # --- Query Params --- resolved_params = [] for param_def in all_param_defs: if "$ref" in param_def: - resolved_params.append(self._resolve_ref(spec, param_def["$ref"])) + resolved_params.append(DocstringPatcher._resolve_ref(spec, param_def["$ref"])) else: resolved_params.append(param_def) @@ -180,12 +114,11 @@ def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: "name": p.get("name"), "required": p.get("required", False), "description": p.get("description", "N/A"), - "type": self._get_param_type(spec, p.get("schema")), + "type": DocstringPatcher._get_param_type(spec, p.get("schema")), } ) - # --- End Query Param Processing --- - # --- Request Body Processing --- + # --- Request Body --- if body_def: content = resolved_body_def.get("content", {}) media_type = next(iter(content.values()), None) @@ -194,7 +127,7 @@ def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: if media_type and "schema" in media_type: schema = media_type["schema"] - schema_type = self._get_param_type(spec, schema) + schema_type = DocstringPatcher._get_param_type(spec, schema) details["request_body"] = { "required": resolved_body_def.get("required", False), @@ -204,23 +137,27 @@ def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: "properties": [], } - # --- Process schema properties with new lookup logic --- + # --- Schema Properties --- resolved_schema = {} - if "$ref" in schema: - resolved_schema = self._resolve_ref(spec, schema["$ref"]) - elif schema.get("type") == "object": - resolved_schema = schema + current_schema_for_props = schema + while "$ref" in current_schema_for_props or "$ref:" in current_schema_for_props: + ref = current_schema_for_props.get("$ref") or current_schema_for_props.get( + "$ref:" + ) + if not ref: + break + current_schema_for_props = DocstringPatcher._resolve_ref(spec, ref) - if resolved_schema.get("type") == "object" and "properties" in resolved_schema: - for prop_name, prop_def in resolved_schema["properties"].items(): + if ( + current_schema_for_props.get("type") == "object" + and "properties" in current_schema_for_props + ): + for prop_name, prop_def in current_schema_for_props["properties"].items(): found_param_match = False - # --- Try to find a match in components/parameters --- - # (Iterate over values, e.g., the LimitParam object) for component_param_def in all_component_params.values(): if component_param_def.get("name") == prop_name: - # Found a match! Use its details. - prop_type = self._get_param_type( + prop_type = DocstringPatcher._get_param_type( spec, component_param_def.get("schema") ) prop_desc = component_param_def.get("description", "N/A") @@ -231,18 +168,19 @@ def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: break if not found_param_match: - # No match, process as a normal schema property - prop_type = self._get_param_type(spec, prop_def) + prop_type = DocstringPatcher._get_param_type(spec, prop_def) prop_desc = prop_def.get("description", "N/A") details["request_body"]["properties"].append( {"name": prop_name, "type": prop_type, "description": prop_desc} ) - # --- Body Parameter Processing (for non-standard spec) --- + # --- Body Parameters --- resolved_body_params = [] for param_def in body_level_params: if "$ref" in param_def: - resolved_body_params.append(self._resolve_ref(spec, param_def["$ref"])) + resolved_body_params.append( + DocstringPatcher._resolve_ref(spec, param_def["$ref"]) + ) else: resolved_body_params.append(param_def) @@ -253,17 +191,16 @@ def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: "in": p.get("in"), "required": p.get("required", False), "description": p.get("description", "N/A"), - "type": self._get_param_type(spec, p.get("schema")), + "type": DocstringPatcher._get_param_type(spec, p.get("schema")), } ) - # --- End Request Body Processing --- - # --- Response Processing Logic --- + # --- Responses --- responses_def = operation.get("responses", {}) for status_code, resp_def in responses_def.items(): resolved_resp = {} if "$ref" in resp_def: - resolved_resp = self._resolve_ref(spec, resp_def["$ref"]) + resolved_resp = DocstringPatcher._resolve_ref(spec, resp_def["$ref"]) else: resolved_resp = resp_def @@ -274,7 +211,7 @@ def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: if media_type and "schema" in media_type: schema = media_type["schema"] - resp_type = self._get_param_type(spec, schema) + resp_type = DocstringPatcher._get_param_type(spec, schema) details["responses"].append( { @@ -283,14 +220,14 @@ def _get_operation_details(self, spec: dict, path: str, method: str) -> dict: "type": resp_type, } ) - # --- END: Response Processing Logic --- return details except Exception as e: logging.warning(f"Error parsing spec for {method.upper()} {path}: {e}", exc_info=True) return details - def _resolve_ref(self, spec: dict, ref: str): + @staticmethod + def _resolve_ref(spec: dict, ref: str): """Resolves a JSON schema $ref string.""" if not spec or not ref.startswith("#/"): return {} @@ -310,25 +247,99 @@ def _resolve_ref(self, spec: dict, ref: str): return {} return current_obj - def _get_param_type(self, spec: dict, schema: dict) -> str: - """Gets a display-friendly type name from a schema object.""" + @staticmethod + def _get_param_type(spec: dict, schema: dict) -> str: + """ + Gets the type name. Handles recursion and arrays. + """ if not schema: return "N/A" - # Check for malformed refs (like in your example spec) - schema_ref = schema.get("$ref") - if not schema_ref: - # Handle user's typo: "$ref:" - schema_ref = schema.get("$ref:") + current_schema = schema + ref_name = None + + while True: + ref_string = current_schema.get("$ref") or current_schema.get("$ref:") + + if not ref_string: + break + + ref_name = ref_string.split("/")[-1] - if schema_ref: - return schema_ref.split("/")[-1] + resolved = DocstringPatcher._resolve_ref(spec, ref_string) - schema_type = schema.get("type", "N/A") + if not resolved: + return ref_name or "N/A" + + if "schema" in resolved: + current_schema = resolved["schema"] + else: + current_schema = resolved + + schema_type = current_schema.get("type", "N/A") if schema_type == "array": - items_schema = schema.get("items", {}) - items_type = self._get_param_type(spec, items_schema) + items_schema = current_schema.get("items", {}) + items_type = DocstringPatcher._get_param_type(spec, items_schema) return f"array[{items_type}]" + if schema_type == "object" and ref_name: + return ref_name + return schema_type + + def _generate_api_doc_string(self, spec: dict, path: str, method: str) -> str: + """Creates the formatted API docstring section for ONE operation.""" + + # Call static method + details = self.get_operation_details(spec, path, method) + + lines = [f"--- Operation: {method.upper()} {path} ---"] + + lines.append(f"\n Summary: {details.get('summary', 'N/A')}") + lines.append(f" Description: {details.get('description', 'N/A')}") + lines.append(f" External Doc: {details.get('external_doc', 'N/A')}") + + lines.append("\n Query Parameters:") + if not details["query_params"]: + lines.append(" (No query parameters)") + else: + for param in details["query_params"]: + lines.append(f"\n **{param['name']}** ({param['type']})") + lines.append(f" Required: {param['required']}") + lines.append(f" Description: {param['description']}") + + lines.append("\n Request Body:") + if not details["request_body"]: + lines.append(" (No request body)") + else: + body = details["request_body"] + lines.append(f"\n **{body['type']}**") + lines.append(f" Required: {body['required']}") + lines.append(f" Description: {body['description']}") + + if body.get("properties"): + lines.append(f" Properties:") + for prop in body["properties"]: + lines.append(f"\n **{prop['name']}** ({prop['type']})") + lines.append(f" Description: {prop['description']}") + + if body.get("parameters"): + lines.append(f" Parameters (associated with this body):") + for param in body["parameters"]: + param_in = param.get("in", "N/A") + lines.append( + f"\n **{param['name']}** ({param['type']}) [in: {param_in}]" + ) + lines.append(f" Required: {param['required']}") + lines.append(f" Description: {param['description']}") + + lines.append("\n Result Body (Responses):") + if not details["responses"]: + lines.append(" (No responses defined in spec)") + else: + for resp in details["responses"]: + lines.append(f"\n **{resp['status_code']}**: ({resp['type']})") + lines.append(f" Description: {resp['description']}") + + return "\n".join(lines) diff --git a/domaintools/request_validator.py b/domaintools/request_validator.py new file mode 100644 index 0000000..88286b1 --- /dev/null +++ b/domaintools/request_validator.py @@ -0,0 +1,118 @@ +from domaintools.docstring_patcher import DocstringPatcher + + +class RequestValidator: + """ + Validates user input against the OpenAPI spec using DocstringPatcher. + Separates validation logic based on HTTP verbs. + """ + + TYPE_MAP = { + "string": str, + "integer": int, + "number": (int, float), + "boolean": bool, + "array": list, + "object": dict, + } + + @staticmethod + def validate( + spec: dict, + path: str, + method: str, + parameters: dict = None, + ): + """ + Orchestrator: Decides which validation to run based on the HTTP method. + """ + method = method.upper() + + # GET requests: Validate Query Parameters only + if method == "GET": + RequestValidator.validate_query_params(spec, path, method, parameters) + + # POST/PUT/PATCH: Validate Request Body + elif method in ["POST", "PUT", "PATCH"]: + RequestValidator.validate_body(spec, path, method, parameters) + + return True + + @staticmethod + def validate_query_params(spec: dict, path: str, method: str, q_params: dict): + """ + Validates ONLY the query parameters. + """ + q_params = q_params or {} + details = DocstringPatcher.get_operation_details(spec, path, method) + errors = [] + + for param in details["query_params"]: + param_name = param["name"] + is_required = param["required"] + param_type = param["type"] + + # Check existence + if is_required and param_name not in q_params: + errors.append(f"Missing required query parameter: '{param_name}'") + continue + + # Check Type (only if present) + if param_name in q_params: + val = q_params[param_name] + RequestValidator._check_type(val, param_type, f"query.{param_name}", errors) + + if errors: + raise ValueError("Query Parameter Validation Failed:\n - " + "\n - ".join(errors)) + + @staticmethod + def validate_body(spec: dict, path: str, method: str, body_data: dict): + """ + Validates ONLY the request body. + """ + details = DocstringPatcher.get_operation_details(spec, path, method) + errors = [] + + if not details["request_body"]: + # If spec has no body defined, but user sent one, you might want to warn + # or simply ignore. We will ignore here. + return + + body_rules = details["request_body"] + + # Check Body Existence + if not body_data: + raise ValueError("Validation Failed: Missing required request body.") + + # Check Body Properties + # Only proceed if we have data and the spec defines properties + if body_data and body_rules.get("properties"): + for prop in body_rules["properties"]: + p_name = prop["name"] + p_type = prop["type"] + + # Check Type if the property exists in the user input + if p_name in body_data: + RequestValidator._check_type( + body_data[p_name], p_type, f"body.{p_name}", errors + ) + + if errors: + raise ValueError("Body Validation Failed:\n - " + "\n - ".join(errors)) + + @staticmethod + def _check_type(value, openapi_type, field_name, errors): + """Helper to check python types against OpenAPI string types.""" + simple_type = openapi_type + + if "array" in openapi_type: + simple_type = "array" + elif openapi_type not in RequestValidator.TYPE_MAP: + return + + expected_type = RequestValidator.TYPE_MAP.get(simple_type) + + if expected_type and not isinstance(value, expected_type): + errors.append( + f"Invalid type for '{field_name}'. Expected {simple_type}, got {type(value).__name__}." + ) diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 0000000..43a05c5 --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,219 @@ +import pytest +from unittest.mock import Mock, patch +from domaintools.decorators import api_endpoint + + +@pytest.fixture +def api_specs(): + """ + A fixture that acts as a central registry for all test OpenAPI specs. + """ + return { + # --- Spec V1: Standard --- + "v1": { + "openapi": "3.0.0", + "info": {"title": "Standard Spec", "version": "1.0.0"}, + "paths": { + "/users": { + "get": { + "parameters": [ + { + "name": "status", + "in": "query", + "required": True, + "schema": {"type": "string"}, + } + ] + }, + "post": { + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + } + }, + } + }, + } + }, + }, + # --- Spec V3: Complex Lookup (Matches parameter name in components) --- + "v3_complex": { + "openapi": "3.0.0", + "info": {"title": "Complex Lookup Spec", "version": "3.0.0"}, + "components": { + "parameters": { + "LimitParam": { + "name": "limit", + "in": "query", + "description": "Max number of items.", + "schema": {"type": "integer"}, + } + }, + "schemas": { + "UserRequestParameters": { + "type": "object", + "properties": { + "limit": { + # The validator/patcher should match this name to 'LimitParam' above + "$ref:": "#/components/schemas/IgnoredRef", + }, + }, + }, + }, + "requestBodies": { + "UserBody": { + "required": True, + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/UserRequestParameters"}, + }, + }, + } + }, + }, + "paths": { + "/users": { + "post": { + "requestBody": {"$ref": "#/components/requestBodies/UserBody"}, + }, + }, + }, + }, + } + + +@pytest.fixture +def mock_client(api_specs): + """ + Creates a mock API client that is pre-patched with the specs defined above. + """ + client = Mock() + client.specs = api_specs + return client + + +class TestApiEndpointDecorator: + + def test_metadata_preservation(self, mock_client): + """ + Ensure decorator copies metadata for DocstringPatcher. + """ + + @api_endpoint(spec_name="v1", path="/users", methods="GET") + def get_users(): + """Original Docstring""" + pass + + bound_method = get_users.__get__(mock_client, Mock) + + assert bound_method._api_spec_name == "v1" + assert bound_method._api_path == "/users" + assert bound_method._api_methods == ["GET"] + assert bound_method.__doc__ == "Original Docstring" + + def test_valid_post_request(self, mock_client): + """ + Test a valid POST request against 'v1' spec. + """ + + @api_endpoint(spec_name="v1", path="/users", methods="POST") + def create_user(request_body=None): + return "Created" + + # Mocking validate to ensure arguments are passed correctly, + # but we could also let it run against the real logic if we wanted integration tests. + with patch("domaintools.request_validator.RequestValidator.validate") as mock_validate: + result = create_user(mock_client, request_body={"name": "Alice", "age": 30}) + + assert result == "Created" + + # Check arguments passed to validator + call_kwargs = mock_validate.call_args[1] + assert call_kwargs["spec"] == mock_client.specs["v1"] + assert call_kwargs.get("parameters", {}).get("request_body") == { + "name": "Alice", + "age": 30, + } + + def test_validation_failure_blocks_execution(self, mock_client): + """ + Test that if validation fails, the function doesn't run. + """ + inner_logic = Mock() + + @api_endpoint(spec_name="v1", path="/users", methods="POST") + def create_user(body=None): + inner_logic() + + # Simulate a validation error + with patch( + "domaintools.request_validator.RequestValidator.validate", + side_effect=ValueError("Bad Input"), + ): + with pytest.raises(ValueError, match="Bad Input"): + create_user(mock_client, body={"bad": "data"}) + + inner_logic.assert_not_called() + + def test_complex_spec_lookup_integration(self, mock_client): + """ + Test that the decorator works with the complex 'v3_complex' spec + we defined in the fixture. + """ + + @api_endpoint(spec_name="v3_complex", path="/users", methods="POST") + def create_user_complex(body=None): + return "Complex Success" + + with patch("domaintools.request_validator.RequestValidator.validate") as mock_validate: + create_user_complex(mock_client, body={"limit": 10}) + + # Verify the correct spec dictionary was retrieved and passed + call_kwargs = mock_validate.call_args[1] + assert call_kwargs["spec"]["info"]["title"] == "Complex Lookup Spec" + assert call_kwargs["path"] == "/users" + + def test_missing_spec_skips_validation(self, mock_client): + """ + If we ask for a spec name that isn't in the fixture, it should handle gracefully. + """ + + @api_endpoint(spec_name="non_existent_version", path="/users", methods="GET") + def get_users(): + return "Ran Safe" + + with patch("domaintools.request_validator.RequestValidator.validate") as mock_validate: + result = get_users(mock_client) + + assert result == "Ran Safe" + mock_validate.assert_not_called() + + def test_positional_arguments_are_mapped(self, mock_client): + """ + Test that passing arguments positionally (args) instead of via keywords (kwargs) + still triggers validation correctly. + """ + + # Define function with explicit parameter names + @api_endpoint(spec_name="v1", path="/users", methods="POST") + def create_user(name=None, body=None): + return "Success" + + with patch("domaintools.request_validator.RequestValidator.validate") as mock_validate: + # CALL POSITIONALLY: passing client and body as args + # (Note: we pass mock_client manually because create_user is just a function here) + create_user(mock_client, "test-name") + + # Verify validator received the data mapped to 'body_data' + mock_validate.assert_called_once() + call_kwargs = mock_validate.call_args[1] + + assert call_kwargs.get("parameters") == {"name": "test-name"} diff --git a/tests/test_request_validator.py b/tests/test_request_validator.py new file mode 100644 index 0000000..7106387 --- /dev/null +++ b/tests/test_request_validator.py @@ -0,0 +1,234 @@ +import pytest +from unittest.mock import patch, MagicMock + +# Assuming your class is in api_core.py +from domaintools.request_validator import RequestValidator + + +class TestRequestValidator: + + # ========================================================================= + # 1. QUERY PARAMETER TESTS (GET) + # ========================================================================= + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_validate_query_param_success(self, mock_get_details): + """Test that correct query parameters pass validation.""" + # Setup the "Rules" + mock_get_details.return_value = { + "query_params": [ + {"name": "page", "required": True, "type": "integer"}, + {"name": "sort", "required": False, "type": "string"}, + ], + "request_body": None, + "responses": [], + } + + # Case 1: All valid + RequestValidator.validate( + spec={}, path="/test", method="GET", parameters={"page": 1, "sort": "asc"} + ) + + # Case 2: Optional missing (should still pass) + RequestValidator.validate(spec={}, path="/test", method="GET", parameters={"page": 5}) + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_validate_query_param_missing_required(self, mock_get_details): + """Test failure when a required query param is missing.""" + mock_get_details.return_value = { + "query_params": [{"name": "id", "required": True, "type": "integer"}], + "request_body": None, + } + + with pytest.raises(ValueError) as exc: + RequestValidator.validate( + spec={}, path="/test", method="GET", parameters={} # 'id' is missing + ) + assert "Missing required query parameter: 'id'" in str(exc.value) + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_validate_query_param_wrong_type(self, mock_get_details): + """Test failure when query param has wrong python type.""" + mock_get_details.return_value = { + "query_params": [{"name": "limit", "required": True, "type": "integer"}], + "request_body": None, + } + + with pytest.raises(ValueError) as exc: + RequestValidator.validate( + spec={}, + path="/test", + method="GET", + parameters={"limit": "10"}, # String instead of Int + ) + assert "Invalid type for 'query.limit'" in str(exc.value) + assert "Expected integer" in str(exc.value) + + # ========================================================================= + # 2. REQUEST BODY TESTS (POST/PUT/PATCH) + # ========================================================================= + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_validate_body_success(self, mock_get_details): + """Test that a valid body passes.""" + mock_get_details.return_value = { + "query_params": [], + "request_body": { + "required": True, + "type": "UserRequest", + "properties": [ + {"name": "username", "type": "string"}, + {"name": "age", "type": "integer"}, + ], + }, + } + + RequestValidator.validate( + spec={}, path="/users", method="POST", parameters={"username": "alice", "age": 30} + ) + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_validate_body_missing_required_body(self, mock_get_details): + """Test failure when the entire body is required but missing.""" + mock_get_details.return_value = { + "query_params": [], + "request_body": {"required": True, "properties": []}, # <--- Body is required + } + + with pytest.raises(ValueError) as exc: + RequestValidator.validate( + spec={}, path="/users", method="POST", parameters=None # User sent nothing + ) + assert "Missing required request body" in str(exc.value) + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_validate_body_property_wrong_type(self, mock_get_details): + """Test failure when a specific body property has the wrong type.""" + mock_get_details.return_value = { + "query_params": [], + "request_body": { + "required": True, + "properties": [{"name": "is_active", "type": "boolean"}], + }, + } + + with pytest.raises(ValueError) as exc: + RequestValidator.validate( + spec={}, + path="/users", + method="POST", + parameters={"is_active": "yes"}, # String instead of Bool + ) + assert "Invalid type for 'body.is_active'" in str(exc.value) + assert "Expected boolean" in str(exc.value) + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_validate_body_extra_fields_allowed(self, mock_get_details): + """ + Test that extra fields in the body NOT defined in spec are ignored + (standard behavior unless additionalProperties: false is strictly enforced). + """ + mock_get_details.return_value = { + "query_params": [], + "request_body": {"required": True, "properties": [{"name": "name", "type": "string"}]}, + } + + # User sends 'name' AND 'extra_field' + RequestValidator.validate( + spec={}, path="/users", method="POST", parameters={"name": "Alice", "extra_field": 123} + ) + + # ========================================================================= + # 3. TYPE CHECKING EDGE CASES + # ========================================================================= + + @pytest.mark.parametrize( + "openapi_type, valid_value, invalid_value", + [ + ("integer", 10, "10"), + ("string", "hello", 123), + ("boolean", True, "True"), + ("number", 10.5, "10.5"), # number allows float or int + ("number", 10, "10"), + ("array", [1, 2], {"a": 1}), + ("object", {"a": 1}, [1, 2]), + ("array[string]", ["a", "b"], "a string"), # simplified array check + ], + ) + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_all_data_types(self, mock_get_details, openapi_type, valid_value, invalid_value): + """ + Parametrized test to cover all supported primitive types in _check_type. + """ + mock_get_details.return_value = { + "query_params": [], + "request_body": { + "required": True, + "properties": [{"name": "test_field", "type": openapi_type}], + }, + } + + # 1. Test Valid Value + RequestValidator.validate( + spec={}, path="/", method="POST", parameters={"test_field": valid_value} + ) + + # 2. Test Invalid Value + with pytest.raises(ValueError) as exc: + RequestValidator.validate( + spec={}, path="/", method="POST", parameters={"test_field": invalid_value} + ) + assert f"Invalid type for 'body.test_field'" in str(exc.value) + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_unknown_or_complex_types_skipped(self, mock_get_details): + """ + Test that if the type is complex (e.g., 'User') or unknown, + validation is skipped (does not raise error). + """ + mock_get_details.return_value = { + "query_params": [], + "request_body": { + "required": True, + "properties": [ + {"name": "complex_obj", "type": "UserDefinition"}, # Not a primitive + ], + }, + } + + # Should pass regardless of what we put in (since we can't validate "UserDefinition" easily) + RequestValidator.validate( + spec={}, path="/", method="POST", parameters={"complex_obj": {"any": "structure"}} + ) + + # ========================================================================= + # 4. METHOD ORCHESTRATION + # ========================================================================= + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_ignore_body_on_get(self, mock_get_details): + """Test that body data is ignored if method is GET.""" + # Setup: GET request defined, but no body info + mock_get_details.return_value = { + "query_params": [], + "request_body": None, # Spec says no body + } + + # We send body data anyway + RequestValidator.validate( + spec={}, path="/", method="GET", parameters={"should": "be_ignored"} + ) + # Should pass without error + + @patch("domaintools.docstring_patcher.DocstringPatcher.get_operation_details") + def test_ignore_query_on_post(self, mock_get_details): + """Test that query params are ignored (not validated) if method is POST.""" + # Spec says 'id' is required if we were looking at query params + mock_get_details.return_value = { + "query_params": [{"name": "id", "required": True, "type": "integer"}], + "request_body": None, + } + + # We call POST, passing NO query params. + # Since POST logic only checks body, this should NOT complain about missing 'id'. + RequestValidator.validate(spec={}, path="/", method="POST", parameters={}) diff --git a/tox.ini b/tox.ini index 17dd5ca..6129ba6 100644 --- a/tox.ini +++ b/tox.ini @@ -11,11 +11,11 @@ passenv = TEST_KEY deps = click==8.1.8 + vcrpy==7.0.0 pytest pytest-cov pytest-asyncio httpx - vcrpy rich typer .