diff --git a/create_hashes.py b/create_hashes.py new file mode 100644 index 00000000..f01f898c --- /dev/null +++ b/create_hashes.py @@ -0,0 +1,16 @@ +from passlib.context import CryptContext + +pwd_context = CryptContext(schemes=['argon2'], deprecated='auto') + + +users = { + 'testuser': 'testuser', +} + +env_content = '' +print('\nCopy the following lines into your .env file:\n') +for username, plain_password in users.items(): + hashed_password = pwd_context.hash(plain_password) + env_line = f"USER_{username.upper()}_HASH=\"{hashed_password}\"" + print(env_line) + env_content += env_line + '\n' diff --git a/docs/source/configuration/command_config.rst b/docs/source/configuration/command_config.rst index 231b4a5d..2d01a2e6 100644 --- a/docs/source/configuration/command_config.rst +++ b/docs/source/configuration/command_config.rst @@ -2,13 +2,15 @@ cmd_config ========== -Stores global variables for command options. These are settings for **all** commands. +Stores global variables for command options. +These are settings for **all** commands. .. code-block:: yaml ### cmd_config: loop_sleep: 5 + command_delay: 0 .. confval:: loop_sleep @@ -19,3 +21,11 @@ Stores global variables for command options. These are settings for **all** comm :type: int :default: 5 + +.. confval:: command_delay + + This delay in seconds is applied to all commands in the playbook. + It is not applied to debug, setvar and sleep commands. + + :type: float + :default: 0 diff --git a/docs/source/configuration/index.rst b/docs/source/configuration/index.rst index a40380ca..e2fe41f5 100644 --- a/docs/source/configuration/index.rst +++ b/docs/source/configuration/index.rst @@ -25,6 +25,7 @@ sliver and metasploit: ### cmd_config: loop_sleep: 5 + command_delay: 0 msf_config: password: securepassword diff --git a/docs/source/developing/architecture.rst b/docs/source/developing/architecture.rst new file mode 100644 index 00000000..5da12c27 --- /dev/null +++ b/docs/source/developing/architecture.rst @@ -0,0 +1,35 @@ +======================= +System Architecture (C4) +======================= + +This section presents the architecture of AttackMate using the +`C4 model `_, a visual framework for describing software architecture across different levels of detail. + +C1 – System Context Diagram +--------------------------- + +.. image:: ../images/AttackMate-C1.png + :width: 80% + :alt: System Context Diagram + +The System Context diagram shows how **AttackMate** fits into its environment. It illustrates the main user +(e.g., a pentester or researcher), the software systems it interacts with (e.g., vulnerable target systems, external +frameworks like Metasploit or Sliver), and the nature of those interactions. + + +C2 – Container Diagram +---------------------- + +This diagram shows how AttackMate is internally structured as a modular Python application. + +.. image:: ../images/AttackMate-C2.png + :alt: Container Diagram + +The system is centered around a core orchestration class that receives parsed playbook commands and delegates their +execution to appropriate components. It separates concerns between parsing, background task management, session handling, +and command execution, which makes it easy to extend with new command types or external tool integrations. + +Future diagrams (e.g., C3 or C4) could describe class-level and code-level structures if needed. + +.. note:: + The official C4 model site (https://c4model.com) provides detailed guidance if you're unfamiliar with this approach. 
diff --git a/docs/source/developing/baseexecutor.rst b/docs/source/developing/baseexecutor.rst index d6cc799c..647ae9fd 100644 --- a/docs/source/developing/baseexecutor.rst +++ b/docs/source/developing/baseexecutor.rst @@ -6,7 +6,7 @@ Adding a New Executor Base Executor ================ -The ``BaseExecutor`` is the core class from which all executors in AttackMate inherit. +The ``BaseExecutor`` is the core class from which all executors in AttackMate inherit. It provides a structured approach to implementing custom executors. Key Features @@ -64,8 +64,8 @@ Overridable Methods The following methods can be overridden in custom executors to modify behavior: -**Command Execution** - +**Command Execution** + .. code-block:: python def _exec_cmd(self, command: BaseCommand) -> Result: @@ -74,24 +74,28 @@ The following methods can be overridden in custom executors to modify behavior: This is the core execution function and must be implemented in subclasses. It should return a ``Result`` object containing the execution outcome. -.. note:: +.. note:: - The ``_exec_cmd()`` method **must** be implemented in any subclass of ``BaseExecutor``. - This method defines the core execution logic for the command and is responsible for returning a ``Result`` object. + The ``_exec_cmd()`` method **must** be implemented in any subclass of ``BaseExecutor``. + This method defines the core execution logic for the command and is responsible for returning a ``Result`` object. -**Logging Functions** +**Logging Functions** The methods ``log_command``, ``log_matadata`` and ``log_json`` log command execution details and can be overridden for custom logging formats. -**Command Execution Flow** +**Command Execution Flow** The ``run()`` method defines the high-level execution flow of a command. It includes condition checking, logging, and calling the actual execution logic. -**Output Handling** +**Output Handling** The ``save_output()`` function manages saving output to a file. It can be overridden to implement alternative storage methods. +executor __init__.py +-------------------- +.. note:: + Add the new executor to the ``__all__`` list in the ``__init__.py`` file of the ``attackmate.executors`` module. 
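Putting these pieces together, a minimal sketch of a custom executor might look as follows. This is an illustrative example only: the ``EchoExecutor`` class and its behaviour are hypothetical, and it assumes the import paths used elsewhere in this repository and the two-argument ``Result(stdout, returncode)`` constructor seen in ``attackmate.attackmate``.

.. code-block:: python

    from attackmate.executors.baseexecutor import BaseExecutor
    from attackmate.result import Result
    from attackmate.schemas.base import BaseCommand


    class EchoExecutor(BaseExecutor):
        """Hypothetical executor that simply echoes the command string."""

        def _exec_cmd(self, command: BaseCommand) -> Result:
            # Core execution logic: do the work, then wrap the outcome in a Result.
            output = f'echoed: {command.cmd}'
            return Result(output, 0)

The new class would then also be exported from ``attackmate/executors/__init__.py`` (the ``__all__`` entry mentioned in the note above).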
diff --git a/docs/source/images/AttackMate-C1.png b/docs/source/images/AttackMate-C1.png new file mode 100644 index 00000000..c586bc0a Binary files /dev/null and b/docs/source/images/AttackMate-C1.png differ diff --git a/docs/source/images/AttackMate-C2.png b/docs/source/images/AttackMate-C2.png new file mode 100644 index 00000000..f637ca66 Binary files /dev/null and b/docs/source/images/AttackMate-C2.png differ diff --git a/docs/source/images/AttackMate-C4.drawio b/docs/source/images/AttackMate-C4.drawio new file mode 100644 index 00000000..b1764cf0 --- /dev/null +++ b/docs/source/images/AttackMate-C4.drawio @@ -0,0 +1,267 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/index.rst b/docs/source/index.rst index ec5598b0..6bd20055 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,6 +37,7 @@ Welcome to AttackMate's documentation! :maxdepth: 1 :caption: Developing: + developing/architecture developing/command developing/baseexecutor developing/integration diff --git a/docs/source/playbook/commands/index.rst b/docs/source/playbook/commands/index.rst index bb1b51c9..bc1fa4fa 100644 --- a/docs/source/playbook/commands/index.rst +++ b/docs/source/playbook/commands/index.rst @@ -167,6 +167,8 @@ Every command, regardless of the type has the following general options: * MsfModuleCommand * IncludeCommand + * VncCommand + * BrowserCommand Background-Mode together with a session is currently not implemented for the following commands: diff --git a/pyproject.toml b/pyproject.toml index ed06bb79..4bc60793 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,13 @@ dependencies = [ "httpx[http2]", "vncdotool", "pytest-mock", - "playwright" + "fastapi", + "playwright", + "argon2-cffi", + "uvicorn", + "dotenv", + "passlib", + "python-multipart" ] dynamic = ["version"] diff --git a/remote_rest/README.md b/remote_rest/README.md new file mode 100644 index 00000000..507053b0 --- /dev/null +++ b/remote_rest/README.md @@ -0,0 +1,76 @@ +pip install fastapi uvicorn httpx PyYAML pydantic argon2_cffi + + +uvicorn remote_rest.main:app --host 0.0.0.0 --port 8000 --reload + +[] TODO sort out logs for different instances + +[] TODO return logs to caller + +[] TODO limit max. concurent instance number + +[] TODO concurrency for several instances? + +[] TODO add authentication + +[] TODO queue requests for instances + +[] TODO dynamic configuration of attackmate config + +[] TODO make logging (debug, json etc) configurable at runtime (endpoint or user query paramaters?) + +[] TODO ALLOWED_PLAYBOOK_DIR -> define in and load from configs + +[] TODO add swagger examples + +[] TODO generate/check OpenAPI schema + +[x] TODO seperate router modules? 
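For reference, the raw HTTP flow that `remote_rest/client.py` wraps looks roughly like the sketch below. It is based on the endpoints defined in `remote_rest/main.py`, `remote_rest/auth_utils.py` and `remote_rest/routers/instances.py`; the certificate path and credentials are placeholders, and the server is assumed to run with TLS on port 8443 (as in `main.py`'s `__main__` block) rather than via the plain-HTTP dev command above.

```python
import httpx

BASE_URL = 'https://localhost:8443'  # default base URL used by remote_rest/client.py
CERT = 'cert.pem'                    # placeholder: the self-signed server certificate

with httpx.Client(base_url=BASE_URL, verify=CERT, timeout=60.0) as client:
    # /login expects form-encoded credentials (OAuth2PasswordRequestForm)
    # and returns {"access_token": "...", "token_type": "bearer"}.
    resp = client.post('/login', data={'username': 'testuser', 'password': 'testuser'})
    resp.raise_for_status()
    token = resp.json()['access_token']

    # Every other endpoint authenticates via the X-Auth-Token header.
    state = client.get('/instances/state', headers={'X-Auth-Token': token})
    state.raise_for_status()
    print(state.json())
```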
+ + + + + +# Execute a playbook by sending its YAML content (uses a temporary instance) +python -m remote_rest.client playbook-yaml examples/playbook.yml + +# Request the server execute a playbook from its allowed directory + Ensure 'playbook.yml' exists in server's ALLOWED_PLAYBOOK_DIR + +python -m remote_rest.client playbook-file safe_playbook.yml + + +# Single Command Execution (on a persistent Instance) + +## Shell Command +```bash +python -m remote_rest.client command shell 'echo "Hello"' +``` + +### Run a command in the background or with metadata + ```bash + python -m remote_rest.client command shell 'echo hello' --metadata tactic=recon --metadata technique=TXXX + ``` + +## Run a Sleep Command (Background): + ```bash + python -m remote_rest.client command sleep --seconds 8 --background + ``` + # (Client returns immediately, server sleeps) + + +# Certificate generation +preliminary, automate later? +with open ssl + ```bash + openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes + ``` + +Common Name: localhost (or ip adress the server will be) + + +running client: + +```bash +python -m client --cacert login user user +``` diff --git a/remote_rest/__init__.py b/remote_rest/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/remote_rest/auth_utils.py b/remote_rest/auth_utils.py new file mode 100644 index 00000000..cba7d8f0 --- /dev/null +++ b/remote_rest/auth_utils.py @@ -0,0 +1,103 @@ +import logging +import os +import secrets +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional + +from dotenv import load_dotenv +from fastapi import Depends, HTTPException, status +from fastapi.security import APIKeyHeader +from passlib.context import CryptContext + +load_dotenv() + + +TOKEN_EXPIRE_MINUTES = int(os.getenv('TOKEN_EXPIRE_MINUTES', 30)) +API_KEY_HEADER_NAME = 'X-Auth-Token' +api_key_header_scheme = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=True) +pwd_context = CryptContext(schemes=['argon2'], deprecated='auto') + +# In-Memory token Store +# token looks like this token : {"username": str, "expires": datetime} +# state is lost on server restart. +# Not inherently thread-safe for multi-worker setups without locks ? +ACTIVE_TOKENS: Dict[str, Dict[str, Any]] = {} + +logger = logging.getLogger(__name__) + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + return pwd_context.verify(plain_password, hashed_password) + + +def get_user_hash(username: str) -> Optional[str]: + """Fetches the hashed password from environment variables.""" + env_var_name = f"USER_{username.upper()}_HASH" + return os.getenv(env_var_name) + + +def create_access_token(username: str) -> str: + """Creates a new token, stores it, and returns the token string.""" + token = secrets.token_urlsafe(32) + expires = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRE_MINUTES) + # TODO locking needed for multi-threaded access, smth like with token_lock ? + ACTIVE_TOKENS[token] = {'username': username, 'expires': expires} + logger.info(f"Created new token for user '{username}' expiring at {expires}") + return token + + +def renew_token_expiry(token: str) -> bool: + """Updates the expiry time for an existing token. 
Returns True if successful.""" + token_data = ACTIVE_TOKENS.get(token) + if token_data: + token_data['expires'] = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRE_MINUTES) + logger.debug(f"Renewed token expiry for user '{token_data['username']}'") + return True + return False + + +def cleanup_expired_tokens(): + """Removes expired tokens from the store""" + now = datetime.now(timezone.utc) + expired_tokens = [token for token, data in ACTIVE_TOKENS.items() if data['expires'] < now] + for token in expired_tokens: + username = ACTIVE_TOKENS.get(token, {}).get('username', 'unknown') + del ACTIVE_TOKENS[token] + logger.info(f"Removed expired token for user '{username}'.") + + +# Authentication Dependency -> this gets passed to the routes +async def get_current_user(token: str = Depends(api_key_header_scheme)) -> str: + """ + validate token and return the username + renews the token's expiration on successful validation + cleanup of expired tokens. + """ + + cleanup_expired_tokens() + + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Invalid authentication credentials', + headers={'WWW-Authenticate': 'Bearer'}, + ) + + token_data = ACTIVE_TOKENS.get(token) + if not token_data: + logger.warning(f"Token not found: {token[:5]}...") + raise credentials_exception + + username: str = token_data['username'] + expires: datetime = token_data['expires'] + + if expires < datetime.now(timezone.utc): + logger.warning(f"Token expired for user '{username}'") + # Remove the expired token + if token in ACTIVE_TOKENS: + del ACTIVE_TOKENS[token] + raise credentials_exception + + renew_token_expiry(token) + + logger.debug(f"Token validated successfully for user: {username}") + return username diff --git a/remote_rest/client.py b/remote_rest/client.py new file mode 100644 index 00000000..c3091041 --- /dev/null +++ b/remote_rest/client.py @@ -0,0 +1,403 @@ +import argparse +import json +import logging +import os +import sys +from typing import Any, Dict, List, Optional + +import httpx +import yaml + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - CLIENT - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +# Authentication +CURRENT_TOKEN: Optional[str] = None +# token can then be saved in env var since it does not persit in client memory +TOKEN_ENV_VAR = 'ATTACKMATE_API_TOKEN' + + +def load_token(): + """Loads token from global or env""" + global CURRENT_TOKEN + if CURRENT_TOKEN: + return CURRENT_TOKEN + CURRENT_TOKEN = os.getenv(TOKEN_ENV_VAR) + if CURRENT_TOKEN: + logger.info('Token loaded') + return CURRENT_TOKEN + + +def save_token(token: Optional[str]): + """Saves token to global var and env""" + global CURRENT_TOKEN + CURRENT_TOKEN = token + if token: + # This is pretty hacky, client mainly for testing purposes + logger.info('updating env var') + logger.info(f"run in your shell: export ATTACKMATE_API_TOKEN={token}") + else: + os.environ.pop(TOKEN_ENV_VAR, None) + + +def get_auth_headers() -> Dict[str, str]: + token = load_token() + if token: + return {'X-Auth-Token': token} + return {} + + +def update_token_from_response(data: Dict[str, Any]): + """Updates the stored token if present in the response data.""" + new_token = data.get('current_token') + if new_token: + logger.info('Received renewed token in response') + save_token(new_token) + + +# Helper Functions +def parse_key_value_pairs(items: List[str] | None) -> Dict[str, str]: + """Helper to parse 'key=value' strings from a list into a dict.""" + result: Dict[str, 
str] = {} + if not items: + return result + for item in items: + if '=' in item: + key, value = item.split('=', 1) + result[key.strip()] = value.strip() + else: + logging.warning(f"Skipping malformed pair: {item}") + return result + + +# Login +def login(client: httpx.Client, base_url: str, username: str, password: str): + """Logs in and saves the token.""" + url = f"{base_url}/login" + logger.info(f"Attempting login for user '{username}' at {url}...") + try: + # standard form encoding for OAuth2PasswordRequestForm -> expected bei Fastapi + response = client.post(url, data={'username': username, 'password': password}) + response.raise_for_status() + data = response.json() + token = data.get('access_token') + if token: + save_token(token) # workaround, export to env var in shell + print(f"Login successful. Token received: {token[:5]}...") + else: + logger.error(' No access token received in response.') + sys.exit(1) + except httpx.RequestError as e: + logger.error(f"HTTP Request Error during login: {e}") + sys.exit(1) + except httpx.HTTPStatusError as e: + logger.error(f"Login failed: {e.response.status_code}") + sys.exit(1) + except Exception as e: + logger.error(f"Unexpected error during login: {e}", exc_info=True) + sys.exit(1) + + +def get_instance_state_from_server(client: httpx.Client, base_url: str, instance_id: str): + """Requests the state of a specific instance.""" + url = f"{base_url}/instances/{instance_id}/state" + logger.info(f"Requesting state for instance {instance_id} at {url}...") + try: + response = client.get(url, headers=get_auth_headers()) + response.raise_for_status() + data = response.json() + update_token_from_response(data) + print(f"\n State for Instance {instance_id} ") + print(yaml.dump(data.get('variables', {}), indent=2)) + except httpx.RequestError as e: + logger.error(f"HTTP Request Error getting state: {e}") + except httpx.HTTPStatusError as e: + logger.error(f"HTTP Status Error getting state: {e.response.status_code} - {e.response.text}") + except Exception as e: + logger.error(f"Unexpected error getting state: {e}", exc_info=True) + + +def run_playbook_yaml( + client: httpx.Client, base_url: str, playbook_file: str, debug: bool = False +): + """Sends playbook YAML content to the server.""" + url = f"{base_url}/playbooks/execute/yaml" + logger.info(f"Attempting to execute playbook from local file: {playbook_file}") + try: + with open(playbook_file, 'r') as f: + playbook_yaml_content = f.read() + except Exception as e: + logger.error(f"Error reading file '{playbook_file}': {e}") + sys.exit(1) + try: + params = {'debug': True} if debug else {} + response = client.post(url, content=playbook_yaml_content, headers={**get_auth_headers(), + 'Content-Type': 'application/yaml'}, params=params) + response.raise_for_status() + data = response.json() + update_token_from_response(data) + print('\n Playbook YAML Execution Result ') + print(f"Success: {data.get('success')}") + print(f"Message: {data.get('message')}") + print(f"ID: {data.get('instance_id')}") + if data.get('final_state'): + print('\n Final Variable Store State ') + print(yaml.dump(data['final_state'].get('variables', {}), indent=2)) + if not data.get('success'): + sys.exit(1) + except httpx.RequestError as e: + logger.error(f"HTTP Request Error executing playbook YAML: {e}") + sys.exit(1) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP Status Error (YAML): {e.response.status_code} - {e.response.text}") + sys.exit(1) + except Exception as e: + logger.error(f"Unexpected error (YAML): {e}", 
exc_info=True) + sys.exit(1) + + +def run_playbook_file( + client: httpx.Client, + base_url: str, + playbook_file_path_on_server: str, + debug: bool = False +): + """Requests server to execute a playbook from local path.""" + url = f"{base_url}/playbooks/execute/file" + logger.info(f"Requesting server execute playbook file: {playbook_file_path_on_server}") + payload = {'file_path': playbook_file_path_on_server} + try: + params = {'debug': True} if debug else {} + response = client.post(url, json=payload, params=params, headers=get_auth_headers()) + response.raise_for_status() + data = response.json() + update_token_from_response(data) + print('\n Playbook File Execution Result ') + print(f"Success: {data.get('success')}") + print(f"Message: {data.get('message')}") + print(f"ID: {data.get('instance_id')}") + if data.get('final_state'): + print('\n Final Variable Store State ') + print(yaml.dump(data['final_state'].get('variables', {}), indent=2)) + if not data.get('success'): + sys.exit(1) + except httpx.RequestError as e: + logger.error(f"HTTP Request Error executing playbook file: {e}") + sys.exit(1) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP Status Error (File): {e.response.status_code} - {e.response.text}") + sys.exit(1) + except Exception as e: + logger.error(f"Unexpected error (File): {e}", exc_info=True) + sys.exit(1) + + +def run_command(client: httpx.Client, base_url: str, args): + """Sends a single command to the server for execution against an instance.""" + # TODO this need more special handling for sliver commands? + type = args.type + + url = f"{base_url}/command/{type.replace('_', '-')}" + logger.info(f"Attempting command '{type}' on instance 'default_context' at {url}") + + # Construct the request body dictionary (matching Pydantic model) + # Exclude argparse internals + body_dict: Dict[str, Any] = {} + excluded_args = {'mode', 'func', 'server_url', 'instance_id'} + for arg_name, arg_value in vars(args).items(): + if arg_name not in excluded_args and arg_value is not None: + pydantic_field_name = arg_name + # Type conversions for body (Pydantic/FastAPI handles validation, but ensure basic types) + if arg_name in [ + 'option', 'payload_option', 'metadata', 'prompts', + 'output_map', 'header', 'cookie', 'data' + ]: + if isinstance(arg_value, list): + body_dict[pydantic_field_name] = parse_key_value_pairs(arg_value) + else: + body_dict[pydantic_field_name] = arg_value + + try: + logger.debug(f"Sending POST to {url}") + logger.debug(f"Request Body: {json.dumps(body_dict, indent=2)}") + response = client.post(url, json=body_dict, headers=get_auth_headers()) + response.raise_for_status() + data = response.json() + update_token_from_response(data) + logger.info(f"Received response from /{type} endpoint.") + logger.debug(f"Response data: {data}") + + result = data.get('result', {}) + state = data.get('state', {}).get('variables', {}) + + print('\n--- Command Result ---') + print(f"Success: {result.get('success')}") + print(f"Return Code: {result.get('returncode')}") + print(f"Stdout:\n{result.get('stdout')}") + if result.get('error_message'): + print(f"Error Message: {result.get('error_message')}") + + print('\n--- Updated Variable Store State --- ') + print(yaml.dump(state, indent=2, default_flow_style=False)) + + is_background = hasattr(args, 'background') and args.background is True + if not result.get('success') or (result.get('returncode') != 0 and not is_background): + sys.exit(1) + + except httpx.RequestError as e: + logger.error(f"HTTP Request Error executing 
command: {e}") + sys.exit(1) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP Status Error ({url}): {e.response.status_code} - {e.response.text}") + sys.exit(1) + except Exception as e: + logger.error(f"Unexpected error executing command: {e}", exc_info=True) + sys.exit(1) + + +# Main Execution Logic +def main(): + parser = argparse.ArgumentParser(description='AttackMate REST API Client') + parser.add_argument('--base-url', default='https://localhost:8443', + help='Base URL of the AttackMate API server') + parser.add_argument( + '--cacert', + help='Path to the server\'s self-signed certificate file (cert.pem) for verification.' + ) + subparsers = parser.add_subparsers(dest='mode', required=True, help='Operation mode') + + # Login Mode + parser_login = subparsers.add_parser('login', help='Authenticate and get a token') + parser_login.add_argument('username', help='API username') + parser_login.add_argument('password', help='API password') + + # Playbook Modes + parser_pb_yaml = subparsers.add_parser( + 'playbook-yaml', help='Execute a playbook from a local YAML file content') + parser_pb_yaml.add_argument('playbook_file', help='Path to the local playbook YAML file') + parser_pb_yaml.add_argument('--debug', action=argparse.BooleanOptionalAction, + help='Enable server debug logging for this instance') + + parser_pb_file = subparsers.add_parser( + 'playbook-file', help='Request server execute a playbook from its filesystem') + parser_pb_file.add_argument('server_playbook_path', + help='Path to the playbook file relative to the server\'s allowed directory') + parser_pb_file.add_argument('--debug', action=argparse.BooleanOptionalAction, + help='Enable server debug logging for this instance') + + parser_inst_state = subparsers.add_parser('instance-state', help='Get the state of an instance') + parser_inst_state.add_argument('instance_id', help='ID of the instance') + + # Command Mode + parser_command = subparsers.add_parser('command', help='Execute a single command on a specific instance') + command_subparsers = parser_command.add_subparsers( + dest='type', required=True, help='Specific command type') + + # Define Common Arguments Parser (used as parent for command types) + common_args_parser = argparse.ArgumentParser(add_help=False) + common_args_parser.add_argument('--only-if', help='Conditional execution string') + common_args_parser.add_argument('--error-if', help='Regex pattern for error on match') + common_args_parser.add_argument('--error-if-not', help='Regex pattern for error if no match') + common_args_parser.add_argument('--loop-if', help='Regex pattern to loop on match') + common_args_parser.add_argument('--loop-if-not', help='Regex pattern to loop if no match') + common_args_parser.add_argument('--loop-count', type=int, help='Maximum loop iterations') + common_args_parser.add_argument( + '--exit-on-error', + action=argparse.BooleanOptionalAction, + help='Exit if command return code is non-zero' + ) + common_args_parser.add_argument('--save', help='File path to save command stdout') + common_args_parser.add_argument( + '--background', action=argparse.BooleanOptionalAction, help='Run command in the background') + common_args_parser.add_argument( + '--kill-on-exit', action=argparse.BooleanOptionalAction, help='Kill process on server exit') + common_args_parser.add_argument('--metadata', action='append', + help='Metadata key=value pair (repeatable)') + + # Add Command Subparsers + # Shell + parser_shell = command_subparsers.add_parser( + 'shell', help='Execute shell command', 
parents=[common_args_parser]) + parser_shell.add_argument('cmd', help='The command to execute') + parser_shell.add_argument('--interactive', action=argparse.BooleanOptionalAction) + parser_shell.add_argument('--creates-session', help='Name of shell session to create') + parser_shell.add_argument('--session', help='Name of existing shell session to use') + parser_shell.add_argument('--command-timeout', type=int) + parser_shell.add_argument('--read', action=argparse.BooleanOptionalAction, default=True) + parser_shell.add_argument('--command-shell', help='Shell path') + parser_shell.add_argument('--bin', action=argparse.BooleanOptionalAction) + + # Sleep + parser_sleep = command_subparsers.add_parser( + 'sleep', help='Pause execution', parents=[common_args_parser]) + parser_sleep.add_argument('--seconds', type=int) + parser_sleep.add_argument('--min-sec', type=int) + parser_sleep.add_argument('--random', action=argparse.BooleanOptionalAction) + + # Debug + parser_debug = command_subparsers.add_parser( + 'debug', help='Debug output or pause', parents=[common_args_parser]) + parser_debug.add_argument('--varstore', action=argparse.BooleanOptionalAction) + parser_debug.add_argument('--exit', action=argparse.BooleanOptionalAction) + parser_debug.add_argument('--wait-for-key', action=argparse.BooleanOptionalAction) + + # SetVar + parser_setvar = command_subparsers.add_parser( + 'setvar', help='Set a variable', parents=[common_args_parser]) + parser_setvar.add_argument('variable', help='Name of the variable to set') + parser_setvar.add_argument('cmd', help='Value to assign to the variable') + parser_setvar.add_argument('--encoder', help='Encoder to use') + + # Mktemp (Tempfile) + parser_mktemp = command_subparsers.add_parser( + 'mktemp', help='Create temporary file/directory', parents=[common_args_parser]) + parser_mktemp.add_argument('variable', help='Variable name to store the path') + parser_mktemp.add_argument('--cmd', choices=['file', 'dir'], + default='file', help='create a file or directory') + + # ADD SUBPARSERS FOR ALL OTHER COMMAND TYPES HERE + + args = parser.parse_args() + if args.cacert: + cert_path = os.path.abspath(args.cacert) # Ensure absolute path + if os.path.exists(cert_path): + logger.info(f"Configured httpx to verify using CA cert: {cert_path}") + else: + logger.error(f"CA certificate file not found at specified path: {cert_path}") + sys.exit(1) + + # Create HTTP Client + with httpx.Client(base_url=args.base_url, timeout=60.0, verify=cert_path) as client: + try: + # Execute based on mode + if args.mode == 'login': + login(client, args.base_url, args.username, args.password) + if args.mode == 'playbook-yaml': + run_playbook_yaml(client, args.base_url, args.playbook_file, args.debug) + elif args.mode == 'playbook-file': + run_playbook_file(client, args.base_url, args.server_playbook_path, args.debug) + elif args.mode == 'instance-state': + get_instance_state_from_server(client, args.base_url, args.instance_id) + elif args.mode == 'command': + if hasattr(args, 'type') and args.type: + run_command(client, args.base_url, args) + else: + logger.error('Internal error: Command mode selected but no command type specified.') + parser.print_help() + sys.exit(1) + except httpx.ConnectError as e: + logger.error( + f"Connection Error: Could not connect to {args.base_url}. " + f"Is the server running with HTTPS? Did you provide cert? 
Details: {e}" + ) + sys.exit(1) + except Exception as main_err: + logger.error(f"Client execution failed: {main_err}", exc_info=True) + sys.exit(1) + + logger.info('Client finished.') + + +if __name__ == '__main__': + main() diff --git a/remote_rest/create_hashes.py b/remote_rest/create_hashes.py new file mode 100644 index 00000000..c073f928 --- /dev/null +++ b/remote_rest/create_hashes.py @@ -0,0 +1,18 @@ +import os + +from passlib.context import CryptContext + +pwd_context = CryptContext(schemes=['argon2'], deprecated='auto') + + +users = { + 'testuser': 'testuser', +} + +env_content = '' +print('\nCopy the following lines into your .env file:\n') +for username, plain_password in users.items(): + hashed_password = pwd_context.hash(plain_password) + env_line = f"USER_{username.upper()}_HASH=\"{hashed_password}\"" + print(env_line) + env_content += env_line + '\n' diff --git a/remote_rest/log_utils.py b/remote_rest/log_utils.py new file mode 100644 index 00000000..5a1a3b73 --- /dev/null +++ b/remote_rest/log_utils.py @@ -0,0 +1,96 @@ +import datetime +import logging +import os +from contextlib import contextmanager +from typing import Generator, List, Optional + +# directory for instance logs if running from project root: +LOG_DIR = os.path.join(os.getcwd(), "attackmate_server_logs") +# Or absolute path: +# LOG_DIR = "/var/log/attackmate_instances" # must exists and has write permissions + +# List of logger names to add instance-specific handlers to +TARGET_LOGGER_NAMES = ['playbook', 'output', 'json'] + +# Create formatter for the instance files +instance_log_formatter = logging.Formatter( + '%(asctime)s %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +json_log_formatter = logging.Formatter('%(message)s') + + +@contextmanager +def instance_logging(instance_id: str, log_level: int = logging.INFO): + """cd + Context manager to temporarily add a file handler for a specific instance + to the target AttackMate loggers. + """ + handlers: List[logging.FileHandler] = [] + instance_output_log_file = None # Initialize + instance_attackmate_log_file = None + + try: + # log directory exists + os.makedirs(LOG_DIR, exist_ok=True) + + timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + instance_output_log_file = os.path.join(LOG_DIR, f"{timestamp}_{instance_id}_output.log") + instance_attackmate_log_file = os.path.join(LOG_DIR, f"{timestamp}_{instance_id}_attackmate.log") + instance_json_log_file = os.path.join(LOG_DIR, f"{timestamp}_{instance_id}_attackmate.json") + + # instance-specific file handler + # 'a' to append within the same request if multiple logs occur + # each request gets a new timestamped file. 
+ output_file_handler = logging.FileHandler(instance_output_log_file, mode='a') + output_file_handler.setFormatter(instance_log_formatter) + output_file_handler.setLevel(log_level) + + attackmate_file_handler = logging.FileHandler(instance_attackmate_log_file, mode='a') + attackmate_file_handler.setFormatter(instance_log_formatter) + attackmate_file_handler.setLevel(log_level) + + attackmate_json_handler = logging.FileHandler(instance_json_log_file, mode='a') + attackmate_json_handler.setFormatter(json_log_formatter) + attackmate_json_handler.setLevel(log_level) + + # Add the handler to the target loggers + for logger_name in TARGET_LOGGER_NAMES: + logger = logging.getLogger(logger_name) + logger.setLevel(log_level) + logger.propagate = False + if logger_name == 'playbook': + logger.addHandler(attackmate_file_handler) + handlers.append(attackmate_file_handler) # remove later finally + if logger_name == 'output': + logger.addHandler(output_file_handler) + handlers.append(output_file_handler) # remove later in finally + if logger_name == 'json': + logger.addHandler(attackmate_json_handler) + handlers.append(attackmate_json_handler) # remove later in finally + logging.info( + (f"Added instance log handlers for '{instance_id}' to logger '{logger_name}' -> " + f"{instance_output_log_file} and {instance_attackmate_log_file} and {instance_json_log_file}.")) + yield [ + instance_attackmate_log_file, + instance_output_log_file, + instance_json_log_file + ] # 'with' block executes here and uses these paths + + except Exception as e: + logging.error(f"Error setting up instance logging for '{instance_id}': {e}", exc_info=True) + yield # main code execution if logging fails + + finally: + logging.info(f"Removing instance log handlers for '{instance_id}'...") + for handler in handlers: + for logger_name in TARGET_LOGGER_NAMES: + logger = logging.getLogger(logger_name) + if handler in logger.handlers: + logger.removeHandler(handler) + try: + handler.close() + except Exception as e: + logging.error( + f"Error removing/closing log handler for instance '{instance_id}': {e}", exc_info=True) + logging.info(f"Instance log handlers removed for '{instance_id}'.") diff --git a/remote_rest/main.py b/remote_rest/main.py new file mode 100644 index 00000000..8af51041 --- /dev/null +++ b/remote_rest/main.py @@ -0,0 +1,168 @@ +from contextlib import asynccontextmanager +import sys +from typing import AsyncGenerator +import os +import uvicorn +from attackmate.execexception import ExecException +from fastapi import Depends, FastAPI, HTTPException, Request, status +from fastapi.responses import JSONResponse +from fastapi.security import OAuth2PasswordRequestForm +from src.attackmate.attackmate import AttackMate +from src.attackmate.logging_setup import (initialize_json_logger, + initialize_logger, + initialize_output_logger, + initialize_api_logger) +from src.attackmate.playbook_parser import parse_config + +import remote_rest.state as state +from remote_rest.routers import commands, instances, playbooks + +from .auth_utils import create_access_token, get_user_hash, verify_password +from .schemas import TokenResponse + +CERT_DIR = os.path.dirname(os.path.abspath(__file__)) +KEY_FILE = os.path.join(CERT_DIR, 'key.pem') +CERT_FILE = os.path.join(CERT_DIR, 'cert.pem') + +# Logging +initialize_logger(debug=True, append_logs=False) +initialize_output_logger(debug=True, append_logs=False) +initialize_json_logger(json=True, append_logs=False) +logger = initialize_api_logger(debug=True, append_logs=False) + + 
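+# NOTE: debug and JSON logging are enabled unconditionally here; making this configurable
+# at runtime is still an open TODO in remote_rest/README.md. Per-request playbook logs are
+# handled separately by remote_rest.log_utils.instance_logging.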
+@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + # Code to run before the application starts accepting requests + logger.info('AttackMate API starting up (lifespan)...') + try: + # Load global config on startup and assign to the variable in state.py + loaded_config = parse_config(config_file=None, logger=logger) + if loaded_config: + state.attackmate_config = loaded_config + logger.info('Global AttackMate configuration loaded.') + else: + raise RuntimeError( + 'Failed to load essential AttackMate configuration (parse_config returned None).') + # Initialize the INSTANCES dict (it's already defined globally in state.py) + state.INSTANCES.clear() + # instantiate the Instance in the INSTANCES dict + state.INSTANCES['default_context'] = AttackMate(playbook=None, config=loaded_config, varstore=None) + logger.info('Instances dictionary initialized.') + # any other async startup tasks ? + + except Exception as e: + logger.critical(f"Failed to initialize during startup lifespan: {e}", exc_info=True) + raise RuntimeError(f"Failed to initialize application state: {e}") from e + + yield # Application runs here + + # Code to run when the application is shutting down + logger.warning('AttackMate API shutting down (lifespan)... Cleaning up instances...') + instance_ids = list(state.INSTANCES.keys()) + for instance_id in instance_ids: + instance = state.INSTANCES.pop(instance_id, None) + if instance: + logger.info(f"Cleaning up instance {instance_id}...") + try: + # blocking? + instance.clean_session_stores() + instance.pm.kill_or_wait_processes() + except Exception as e: + logger.error(f"Error cleaning up instance {instance_id}: {e}", exc_info=True) + logger.info('Instance cleanup complete (lifespan).') + + +app = FastAPI( + title='AttackMate API', + description='API for remote control of AttackMate instances and playbook execution.', + version='1.0.0', + lifespan=lifespan) + + +# Exception Handling +@app.exception_handler(ExecException) +async def attackmate_execution_exception_handler(request: Request, exc: ExecException): + logger.error(f"AttackMate Execution Exception: {exc}") + return JSONResponse( + status_code=400, + content={ + 'detail': 'AttackMate command execution failed', + 'error_message': str(exc), + 'instance_id': None + }, + ) + + +@app.exception_handler(Exception) +async def generic_exception_handler(request: Request, exc: Exception): + if isinstance(exc, SystemExit): + logger.error(f"Command triggered SystemExit with code {exc.code}") + return JSONResponse( + status_code=400, # client-side error pattern + content={ + 'detail': 'Command execution led to termination request', + 'error_message': ( + f"SystemExit triggered (likely due to error condition like 'exit_on_error'). " + f"Exit code: {exc.code}" + ), + 'instance_id': None + }, + ) + # Re-raise other exceptions for specific hanfling? 
+ raise exc + + +# Login endpoint +@app.post('/login', response_model=TokenResponse, tags=['Auth']) +async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): + """Authenticates user and returns an access token.""" + logger.info(f"Login attempt for user: {form_data.username}") + hashed_password = get_user_hash(form_data.username) + if not hashed_password: + logger.warning(f"Login failed: User '{form_data.username}' not found.") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Incorrect username or password', + headers={'WWW-Authenticate': 'Bearer'}, + ) + + if not verify_password(form_data.password, hashed_password): + logger.warning(f"Login failed: Invalid password for user '{form_data.username}'.") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Incorrect username or password', + headers={'WWW-Authenticate': 'Bearer'}, + ) + + # vlid password -> create token + access_token = create_access_token(username=form_data.username) + logger.info(f"Login successful for user '{form_data.username}'. Token created.") + # Return token + return TokenResponse(access_token=access_token, token_type='bearer') + +# Include Routers +app.include_router(instances.router, prefix='/instances') +app.include_router(playbooks.router) +app.include_router(commands.router) + + +# Root Endpoint +@app.get('/', include_in_schema=False) +async def root(): + return {'message': 'AttackMate API is running. Use /login to authenticate. See /docs.'} + +if __name__ == '__main__': + if not os.path.exists(KEY_FILE): + logger.critical(f"SSL Error: Key file not found at {KEY_FILE}") + sys.exit(1) + if not os.path.exists(CERT_FILE): + logger.critical(f"SSL Error: Certificate file not found at {CERT_FILE}") + sys.exit(1) + uvicorn.run('remote_rest.main:app', + host='0.0.0.0', + port=8443, + reload=False, + + ssl_keyfile=KEY_FILE, + ssl_certfile=CERT_FILE) diff --git a/remote_rest/routers/__init__.py b/remote_rest/routers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/remote_rest/routers/commands.py b/remote_rest/routers/commands.py new file mode 100644 index 00000000..ec703f01 --- /dev/null +++ b/remote_rest/routers/commands.py @@ -0,0 +1,65 @@ +import logging +from typing import Optional +from pydantic import BaseModel +from attackmate.attackmate import AttackMate + +from attackmate.schemas.base import BaseCommand +from fastapi import APIRouter, Depends, Header, HTTPException +from src.attackmate.execexception import ExecException +from src.attackmate.result import Result as AttackMateResult +from src.attackmate.schemas.command_types import Command + + +from remote_rest.auth_utils import API_KEY_HEADER_NAME, get_current_user +from remote_rest.schemas import CommandResultModel, ExecutionResponseModel +from remote_rest.utils import varstore_to_state_model + +from ..state import get_persistent_instance + + +router = APIRouter(prefix='/command', tags=['Commands']) +logger = logging.getLogger('attackmate_api') + + +class CommandRequest(BaseModel): + command: Command + + +async def run_command_on_instance(instance: AttackMate, command_data: BaseCommand) -> AttackMateResult: + """Runs a command on a given AttackMate instance.""" + try: + logger.info(f"Executing command type '{command_data.type}' on instance") # type: ignore + # TODO does this work? need to pass command class object here? + result = instance.run_command(command_data) + logger.info(f"Command execution finished. 
RC: {result.returncode}") + return result + except (ExecException, SystemExit) as e: + logger.error(f"AttackMate execution error: {e}", exc_info=True) + raise e + except Exception as e: + logger.error(f"Unexpected error during instance.run_command: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal server error during command execution: {e}") + + +@router.post('/execute', response_model=ExecutionResponseModel) +async def execute_unified_command( + command_request: CommandRequest, + instance: AttackMate = Depends(get_persistent_instance), + current_user: str = Depends(get_current_user), + x_auth_token: Optional[str] = Header(None, alias=API_KEY_HEADER_NAME) +): + # command_request.command will be the correct Pydantic type based on doscriminated union in RemoteCommand + attackmate_result = await run_command_on_instance(instance, command_request.command) + + result_model = CommandResultModel( + success=(attackmate_result.returncode == 0 if attackmate_result.returncode is not None else True), + stdout=attackmate_result.stdout, + returncode=attackmate_result.returncode + ) + state_model = varstore_to_state_model(instance.varstore) + return ExecutionResponseModel( + result=result_model, + state=state_model, + instance_id='default-context', + current_token=x_auth_token + ) diff --git a/remote_rest/routers/instances.py b/remote_rest/routers/instances.py new file mode 100644 index 00000000..2f4d7b0d --- /dev/null +++ b/remote_rest/routers/instances.py @@ -0,0 +1,29 @@ +import logging + +from fastapi import APIRouter, Depends +from src.attackmate.attackmate import AttackMate + +from remote_rest.auth_utils import get_current_user +from remote_rest.schemas import VariableStoreStateModel +from remote_rest.utils import varstore_to_state_model + +from ..state import get_instance_by_id, get_persistent_instance + +router = APIRouter(tags=['Instances']) +logger = logging.getLogger(__name__) + + +@router.get('/{instance_id}/state', response_model=VariableStoreStateModel) +async def get_instance_state( + instance: AttackMate = Depends(get_instance_by_id), + current_user: str = Depends(get_current_user) +): + return varstore_to_state_model(instance.varstore) + + +@router.get('/state', response_model=VariableStoreStateModel) +async def get_persistent_instance_state( + instance: AttackMate = Depends(get_persistent_instance), + current_user: str = Depends(get_current_user) +): + return varstore_to_state_model(instance.varstore) diff --git a/remote_rest/routers/playbooks.py b/remote_rest/routers/playbooks.py new file mode 100644 index 00000000..8e75d14e --- /dev/null +++ b/remote_rest/routers/playbooks.py @@ -0,0 +1,182 @@ +import logging +import os +import uuid +from typing import Optional + +import yaml +from attackmate.schemas.playbook import Playbook +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from pydantic import ValidationError +from src.attackmate.attackmate import AttackMate +from src.attackmate.playbook_parser import parse_playbook + +from remote_rest.auth_utils import API_KEY_HEADER_NAME, get_current_user +from remote_rest.schemas import PlaybookFileRequest, PlaybookResponseModel +from remote_rest.utils import varstore_to_state_model + +from ..log_utils import instance_logging +from ..state import attackmate_config + +router = APIRouter(prefix='/playbooks', tags=['Playbooks']) +logger = logging.getLogger(__name__) +ALLOWED_PLAYBOOK_DIR = '/usr/local/share/attackmate/remote_playbooks/' # MUST EXIST + + +# helper t0 read logfile +def 
read_log_file(log_path: Optional[str]) -> Optional[str]: + if not log_path or not os.path.exists(log_path): + return None + try: + with open(log_path, 'r') as f: + return f.read() + except Exception as e: + logger.error(f"Failed to read log file '{log_path}': {e}") + return f"Error reading log file: {e}" + +# Playbook Execution + + +@router.post('/execute/yaml', response_model=PlaybookResponseModel) +async def execute_playbook_from_yaml(playbook_yaml: str = Body(..., media_type='application/yaml'), + debug: bool = Query( + False, + description="Enable debug logging for this request's instance log." +), + current_user: str = Depends(get_current_user), + x_auth_token: Optional[str] = Header(None, alias=API_KEY_HEADER_NAME)): + """ + Executes a playbook provided as YAML content in the request body. + Use a transient AttackMate instance. + """ + logger.info('Received request to execute playbook from YAML content.') + instance_id = str(uuid.uuid4()) + log_level = logging.DEBUG if debug else logging.INFO + with instance_logging(instance_id, log_level) as log_files: + attackmate_log_path, output_log_path, json_log_path = log_files + try: + playbook_dict = yaml.safe_load(playbook_yaml) + if not playbook_dict: + raise ValueError('Received empty or invalid playbook YAML content.') + playbook = Playbook.model_validate(playbook_dict) + logger.info(f"Creating transient AttackMate instance, ID: {instance_id}") + am_instance = AttackMate(playbook=playbook, config=attackmate_config, varstore=None) + return_code = am_instance.main() + final_state = varstore_to_state_model(am_instance.varstore) + logger.info(f"Transient playbook execution finished. return code {return_code}") + attackmate_log = read_log_file(attackmate_log_path) + output_log = read_log_file(output_log_path) + json_log = read_log_file(json_log_path) + except (yaml.YAMLError, ValidationError, ValueError) as e: + logger.error(f"Playbook validation/parsing error: {e}") + raise HTTPException(status_code=422, detail=f"Invalid playbook YAML: {e}") + except Exception as e: + logger.error(f"Unexpected error during playbook execution: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Server error during playbook execution: {e}") + finally: + if am_instance: + logger.info('Cleaning up transient playbook instance.') + try: + am_instance.clean_session_stores() + am_instance.pm.kill_or_wait_processes() + except Exception as cleanup_e: + logger.error(f"Error cleaning transient instance: {cleanup_e}", exc_info=True) + + return PlaybookResponseModel( + success=(return_code == 0), + message='Playbook execution finished.', + final_state=final_state, + instance_id=instance_id, + attackmate_log=attackmate_log, + output_log=output_log, + json_log=json_log, + current_token=x_auth_token + ) + + +@router.post('/execute/file', response_model=PlaybookResponseModel) +async def execute_playbook_from_file(request_body: PlaybookFileRequest, + debug: bool = Query( + False, + description=( + "Enable debug level logging for this request's instance log." + ) + ), + current_user: str = Depends(get_current_user), + x_auth_token: Optional[str] = Header(None, alias=API_KEY_HEADER_NAME) + ): + """ + Executes a playbook located at a specific path *on the server*. + Uses a transient AttackMate instance. 
+ """ + # TODO ensure this only executes playbooks in certain locations -> read up on path traversal + logger.info(f"Received request to execute playbook from file: {request_body.file_path}") + try: + # base directory exists + if not os.path.isdir(ALLOWED_PLAYBOOK_DIR): + logger.error( + f"Configuration error: ALLOWED_PLAYBOOK_DIR '{ALLOWED_PLAYBOOK_DIR}' does not exist.") + raise HTTPException( + status_code=500, detail='Server configuration error: Playbook directory not found.') + + requested_path = os.path.normpath(request_body.file_path) + # Disallow absolute paths or paths trying to go up directories + if os.path.isabs(requested_path) or requested_path.startswith('..'): + raise ValueError('Invalid playbook path specified.') + + full_path = os.path.join(ALLOWED_PLAYBOOK_DIR, requested_path) + # Final check: ensure the resolved path is still within the allowed directory + if not os.path.abspath(full_path).startswith(os.path.abspath(ALLOWED_PLAYBOOK_DIR)): + raise ValueError('Invalid playbook path specified (path traversal attempt ).') + + # Check if the file exists + if not os.path.isfile(full_path): + raise FileNotFoundError(f"Playbook file not found atpath: {full_path}") + + except (ValueError, FileNotFoundError) as e: + logger.error(f"Invalid or non-existent playbook path requested: {request_body.file_path} -> {e}") + raise HTTPException(status_code=400, detail=f"Invalid or non-existent playbook file path: {e}") + except Exception as e: + logger.error(f"Error processing playbook path: {e}", exc_info=True) + raise HTTPException(status_code=500, detail='Server error processing file path.') + + instance_id = str(uuid.uuid4()) + log_level = logging.DEBUG if debug else logging.INFO + with instance_logging(instance_id, log_level) as log_files: + attackmate_log_path, output_log_path, json_log_path = log_files + try: + logger.info(f"Parsing playbook from: {full_path}") + playbook = parse_playbook(full_path, logger) + logger.info(f"Creating transient AttackMate instance, ID: {instance_id}") + am_instance = AttackMate(playbook=playbook, config=attackmate_config, varstore=None) + return_code = am_instance.main() + final_state = varstore_to_state_model(am_instance.varstore) + logger.info(f"Transient playbook execution finished. 
RC: {return_code}") + attackmate_log = read_log_file(attackmate_log_path) + output_log = read_log_file(output_log_path) + json_log = read_log_file(json_log_path) + except (ValidationError, ValueError) as e: + logger.error(f"Playbook validation error from file '{full_path}': {e}") + raise HTTPException( + status_code=400, detail=f"Invalid playbook content in file '{request_body.file_path}': {e}") + except Exception as e: + logger.error(f"Unexpected error during playbook file execution: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Server error during playbook execution: {e}") + finally: + if am_instance: + logger.info('Cleaning up transient playbook instance.') + try: + am_instance.clean_session_stores() + am_instance.pm.kill_or_wait_processes() + except Exception as e: + logger.error(f"Cleanup error: {e}") + + return PlaybookResponseModel( + success=(return_code == 0), + message=f"Playbook '{request_body.file_path}' execution finished.", + final_state=final_state, + instance_id=instance_id, + attackmate_log=attackmate_log, + output_log=output_log, + json_log=json_log, + current_token=x_auth_token + ) diff --git a/remote_rest/schemas.py b/remote_rest/schemas.py new file mode 100644 index 00000000..b1662098 --- /dev/null +++ b/remote_rest/schemas.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + + +class VariableStoreStateModel(BaseModel): + variables: Dict[str, Any] = {} + + +class CommandResultModel(BaseModel): + success: bool + stdout: Optional[str] = None + returncode: Optional[int] = None + error_message: Optional[str] = None + + +class ExecutionResponseModel(BaseModel): + result: CommandResultModel + state: VariableStoreStateModel + instance_id: Optional[str] = None + current_token: Optional[str] = Field(None, description='Renewed auth token for subsequent requests.') + + +class PlaybookResponseModel(BaseModel): + success: bool + message: str + final_state: Optional[VariableStoreStateModel] = None + instance_id: Optional[str] = None + attackmate_log: Optional[str] = Field(None, description='Content of the attackmate.log for this run.') + output_log: Optional[str] = Field(None, description='Content of the output.log for this run.') + json_log: Optional[str] = Field(None, description='Content of the attackmate.json for this run.') + current_token: Optional[str] = Field(None, description='Renewed auth token for subsequent requests.') + + +class InstanceCreationResponse(BaseModel): + instance_id: str + message: str + + +class PlaybookFileRequest(BaseModel): + file_path: str = Field(..., + description='Path to the playbook file RELATIVE to a predefined server directory.') + debug: bool = Field( + False, + description='If true, the playbook will be executed in debug mode. 
' + ) + + +class TokenResponse(BaseModel): + access_token: str + token_type: str = 'bearer' diff --git a/remote_rest/state.py b/remote_rest/state.py new file mode 100644 index 00000000..a8f0bbe1 --- /dev/null +++ b/remote_rest/state.py @@ -0,0 +1,43 @@ +from typing import Dict, Optional + +from fastapi import Depends, HTTPException, Path +from src.attackmate.attackmate import AttackMate +from src.attackmate.schemas.config import Config + +# Define shared state variables here +INSTANCES: Dict[str, AttackMate] = {} +attackmate_config: Optional[Config] = None + + +def get_instances_dict() -> Dict[str, AttackMate]: + """Dependency to get the shared INSTANCES dictionary.""" + # returns the global dict reference + return INSTANCES + + +def get_attackmate_config() -> Config: + """Dependency to get the shared AttackMate configuration.""" + if attackmate_config is None: + raise RuntimeError('Server configuration is not available.') + return attackmate_config + + +def get_instance_by_id( + instance_id: str = Path(...), + instances: Dict[str, AttackMate] = Depends(get_instances_dict) +) -> AttackMate: + """Dependency to get a specific AttackMate instance, raising 404 if not found.""" + instance = instances.get(instance_id) + if not instance: + raise HTTPException(status_code=404, detail=f"AttackMate instance '{instance_id}' not found.") + return instance + + +def get_persistent_instance( + instances: Dict[str, AttackMate] = Depends(get_instances_dict) +) -> AttackMate: + """Dependency to get the default context persistent AttackMate instance, raising 404 if not found.""" + instance = instances.get('default_context') + if not instance: + raise HTTPException(status_code=404, detail='Persistent AttackMate instance not found.') + return instance diff --git a/remote_rest/utils.py b/remote_rest/utils.py new file mode 100644 index 00000000..376264f5 --- /dev/null +++ b/remote_rest/utils.py @@ -0,0 +1,16 @@ +import logging +from typing import Any, Dict + +from src.attackmate.variablestore import VariableStore + +from remote_rest.schemas import VariableStoreStateModel + +logger = logging.getLogger(__name__) + + +def varstore_to_state_model(varstore: VariableStore) -> VariableStoreStateModel: + """Converts AttackMate VariableStore to Pydantic VariableStoreStateModel.""" + combined_vars: Dict[str, Any] = {} + combined_vars.update(varstore.variables) + combined_vars.update(varstore.lists) + return VariableStoreStateModel(variables=combined_vars) diff --git a/src/attackmate/attackmate.py b/src/attackmate/attackmate.py index 72230ad1..261b2dad 100644 --- a/src/attackmate/attackmate.py +++ b/src/attackmate/attackmate.py @@ -8,20 +8,21 @@ configuration. 
""" +import time from typing import Dict, Optional import logging from attackmate.result import Result import attackmate.executors as executors from attackmate.schemas.config import CommandConfig, Config, MsfConfig, SliverConfig -from attackmate.schemas.playbook import Playbook, Commands, Command -from .variablestore import VariableStore -from .processmanager import ProcessManager +from attackmate.schemas.playbook import Playbook +from attackmate.schemas.command_types import Commands, Command +from attackmate.variablestore import VariableStore +from attackmate.processmanager import ProcessManager from attackmate.executors.baseexecutor import BaseExecutor from attackmate.executors.executor_factory import executor_factory import asyncio - class AttackMate: def __init__( self, @@ -89,10 +90,14 @@ def _get_executor(self, command_type: str) -> BaseExecutor: return self.executors[command_type] def _run_commands(self, commands: Commands): + delay = self.pyconfig.cmd_config.command_delay or 0 + self.logger.info(f'Delay before commands: {delay} seconds') for command in commands: command_type = 'ssh' if command.type == 'sftp' else command.type executor = self._get_executor(command_type) if executor: + if command.type not in ('sleep', 'debug', 'setvar'): + time.sleep(delay) executor.run(command) def run_command(self, command: Command) -> Result: @@ -101,28 +106,25 @@ def run_command(self, command: Command) -> Result: if executor: result = executor.run(command) return result if result else Result(None, None) - + def clean_session_stores(self): self.logger.warning('Cleaning up session stores') # msf - if (msf_module_executor := self.executors.get("msf-module")): + if (msf_module_executor := self.executors.get('msf-module')): msf_module_executor.cleanup() - if (msf_session_executor := self.executors.get("msf-session")): + if (msf_session_executor := self.executors.get('msf-session')): msf_session_executor.cleanup() # ssh - if (ssh_executor := self.executors.get("ssh")): + if (ssh_executor := self.executors.get('ssh')): ssh_executor.cleanup() # vnc - if (vnc_executor := self.executors.get("vnc")): + if (vnc_executor := self.executors.get('vnc')): vnc_executor.cleanup() # sliver - if (sliver_executor := self.executors.get("sliver-session")): + if (sliver_executor := self.executors.get('sliver-session')): loop = asyncio.get_event_loop() loop.run_until_complete(sliver_executor.cleanup()) - - - def main(self): """The main function diff --git a/src/attackmate/executors/__init__.py b/src/attackmate/executors/__init__.py index 955058f0..9312ed95 100644 --- a/src/attackmate/executors/__init__.py +++ b/src/attackmate/executors/__init__.py @@ -1,5 +1,6 @@ from .browser.browserexecutor import BrowserExecutor from .shell.shellexecutor import ShellExecutor +from .remote.remoteexecutor import RemoteExecutor from .ssh.sshexecutor import SSHExecutor from .metasploit.msfsessionexecutor import MsfSessionExecutor from .metasploit.msfpayloadexecutor import MsfPayloadExecutor @@ -22,6 +23,7 @@ __all__ = [ + 'RemoteExecutor' 'BrowserExecutor', 'ShellExecutor', 'SSHExecutor', diff --git a/src/attackmate/executors/baseexecutor.py b/src/attackmate/executors/baseexecutor.py index fb673c44..4f71d71d 100644 --- a/src/attackmate/executors/baseexecutor.py +++ b/src/attackmate/executors/baseexecutor.py @@ -3,6 +3,8 @@ from datetime import datetime from typing import Any from collections import OrderedDict + +from pydantic import BaseModel from attackmate.executors.features.cmdvars import CmdVars from 
attackmate.executors.features.exitonerror import ExitOnError from attackmate.executors.features.looper import Looper @@ -79,13 +81,14 @@ def run(self, command: BaseCommand) -> Result: self.logger.debug(f"Template-Command: '{command.cmd}'") if command.background: # Background commands always return Result(None,None) + time_of_execution = datetime.now().isoformat() + self.log_json(self.json_logger, command, time_of_execution) result = self.exec_background(self.substitute_template_vars(command, self.substitute_cmd_vars)) else: result = self.exec(self.substitute_template_vars(command, self.substitute_cmd_vars)) return result - def log_command(self, command): """Log starting-status of the command""" self.logger.info(f"Executing '{command}'") @@ -117,14 +120,15 @@ def make_command_serializable(self, command, time): command_dict['parameters'] = dict() for key, value in command.__dict__.items(): - if key not in command_dict and key != 'commands': + if key not in command_dict and key != 'commands' and key != 'remote_command': command_dict['parameters'][key] = value # Handle nested "commands" recursively if key == 'commands' and isinstance(value, list): command_dict['parameters']['commands'] = [ self.make_command_serializable(sub_command, time) for sub_command in value ] - + if key == 'remote_command' and isinstance(value, BaseModel): + command_dict['parameters']['remote_command'] = self.make_command_serializable(value, time) return command_dict def save_output(self, command: BaseCommand, result: Result): diff --git a/src/attackmate/executors/common/loopexecutor.py b/src/attackmate/executors/common/loopexecutor.py index 1818e78d..7671a6fd 100644 --- a/src/attackmate/executors/common/loopexecutor.py +++ b/src/attackmate/executors/common/loopexecutor.py @@ -11,7 +11,7 @@ from attackmate.executors.features.conditional import Conditional from attackmate.result import Result from attackmate.schemas.loop import LoopCommand -from attackmate.schemas.playbook import Commands, Command +from attackmate.schemas.command_types import Commands, Command from attackmate.variablestore import VariableStore from attackmate.execexception import ExecException from attackmate.processmanager import ProcessManager @@ -54,7 +54,7 @@ def substitute_variables_in_command(self, command_obj, placeholders: dict): and command_obj.url = '$LOOP_ITEM', then it becomes 'https://example.com'. 
""" for attr_name, attr_val in vars(command_obj).items(): - if isinstance(attr_val, str) and "$" in attr_val: + if isinstance(attr_val, str) and '$' in attr_val: new_val = Template(attr_val).safe_substitute(placeholders) setattr(command_obj, attr_name, new_val) @@ -80,7 +80,7 @@ def loop_items(self, command: LoopCommand, varname: str, iterable: list[str]) -> 'LOOP_ITEM': x, **self.varstore.variables, } - + if self.break_condition_met(command, placeholders): return self.substitute_variables_in_command(template_cmd, placeholders) @@ -136,4 +136,4 @@ def _exec_cmd(self, command: LoopCommand) -> Result: # runfunc will replace global variables then self.execute_loop(command) self.logger.info('Loop execution complete') - return Result('', 0) \ No newline at end of file + return Result('', 0) diff --git a/src/attackmate/executors/features/background.py b/src/attackmate/executors/features/background.py index c2a9cb68..988b1240 100644 --- a/src/attackmate/executors/features/background.py +++ b/src/attackmate/executors/features/background.py @@ -1,3 +1,4 @@ +import json from attackmate.schemas.base import BaseCommand from attackmate.processmanager import ProcessManager from attackmate.result import Result @@ -30,7 +31,8 @@ def _create_queue(self) -> Optional[Queue]: def exec_background(self, command: BaseCommand) -> Result: self.logger.info(f'Run in background: {getattr(command, "type", "")}({command.cmd})') - + if command.metadata: + self.logger.info(f'Metadata: {json.dumps(command.metadata)}') queue = self._create_queue() if queue: diff --git a/src/attackmate/executors/metasploit/msfexecutor.py b/src/attackmate/executors/metasploit/msfexecutor.py index 6c7f9ed8..89dc516a 100644 --- a/src/attackmate/executors/metasploit/msfexecutor.py +++ b/src/attackmate/executors/metasploit/msfexecutor.py @@ -88,6 +88,7 @@ def prepare_exploit(self, command: MsfModuleCommand): for option, setting in command.options.items(): if setting.isnumeric(): exploit[option] = int(setting) + continue if setting.lower() in ['true', 'false', '1', '0', 'y', 'n', 'yes', 'no']: exploit[option] = CmdVars.variable_to_bool(option, setting) else: diff --git a/src/attackmate/executors/remote/remoteexecutor.py b/src/attackmate/executors/remote/remoteexecutor.py new file mode 100644 index 00000000..cb3df121 --- /dev/null +++ b/src/attackmate/executors/remote/remoteexecutor.py @@ -0,0 +1,173 @@ +import logging +import json +from typing import Dict, Any, Optional + +from attackmate.executors.executor_factory import executor_factory + +from attackmate.remote_client import RemoteAttackMateClient +from attackmate.result import Result +from attackmate.execexception import ExecException +from attackmate.schemas.remote import AttackMateRemoteCommand +from attackmate.executors.baseexecutor import BaseExecutor +from attackmate.processmanager import ProcessManager +from attackmate.variablestore import VariableStore + +output_logger = logging.getLogger('output') + + +@executor_factory.register_executor('remote') +class RemoteExecutor(BaseExecutor): + def __init__(self, pm: ProcessManager, varstore: VariableStore, cmdconfig=None): + super().__init__(pm, varstore, cmdconfig) + self.logger = logging.getLogger('playbook') + # Client class is instantiated per command execution with server_url context and chached in + # client_cache + self._clients_cache: Dict[str, RemoteAttackMateClient] = {} + + def log_command(self, command: AttackMateRemoteCommand): + self.logger.info( + f"Executing REMOTE AttackMate command: Type='{command.type}', " + 
f"RemoteCmd='{command.cmd}' on server {command.server_url}'" + ) + remote_command_json = ( + command.remote_command.model_dump() if command.remote_command else ' ' + ) + output_logger.info( + f"Remote Command'{remote_command_json}' sent to server {command.server_url}'" + ) + + def _exec_cmd(self, command: AttackMateRemoteCommand) -> Result: + try: + client = self._get_remote_client(command) + response_data = self._dispatch_remote_command(client, command) + success, error_msg, stdout, return_code = self._process_response(response_data) + + except (ExecException, IOError, FileNotFoundError) as e: + self.logger.error(f"Execution failed: {e}", exc_info=True) + success, error_msg, stdout, return_code = False, str(e), None, 1 + + except Exception as e: + error_message = f"Remote executor encountered an unexpected error: {e}" + self.logger.error(error_message, exc_info=True) + success, error_msg, stdout, return_code = False, error_message, None, 1 + + final_stdout = self._format_output(success, stdout, error_msg) + final_return_code = return_code if return_code is not None else (0 if success else 1) + + return Result(final_stdout, final_return_code) + + def _get_remote_client(self, command_config: AttackMateRemoteCommand) -> RemoteAttackMateClient: + """Gets or creates a client instance for the given server URL.""" + server_url = self.varstore.substitute(command_config.server_url) + if server_url in self._clients_cache: + return self._clients_cache[server_url] + else: + self.logger.info( + f"Creating new remote client for server: {server_url}" + ) + new_remote_client = self._create_remote_client(command_config) + self._clients_cache[server_url] = new_remote_client + return self._clients_cache[server_url] + + def _create_remote_client(self, command_config: AttackMateRemoteCommand) -> RemoteAttackMateClient: + """ + Creates and configures a new RemoteAttackMateClient + """ + server_url = self.varstore.substitute(command_config.server_url) + username = self.varstore.substitute(command_config.user) if command_config.user else None + password = ( + self.varstore.substitute(command_config.password) + if command_config.password else None + ) + cacert = self.varstore.substitute(command_config.cacert) if command_config.cacert else None + return RemoteAttackMateClient( + server_url=server_url, + username=username, + password=password, # noqa: E501 + cacert=cacert + ) + + def _dispatch_remote_command( + self, client: 'RemoteAttackMateClient', command: AttackMateRemoteCommand + ) -> Dict[str, Any]: + """ + Dispatches the command to the appropriate client method. 
+ """ + debug = getattr(command, 'debug', False) + self.logger.debug(f"Dispatching command '{command.cmd}' with debug={debug}") + + if command.cmd == 'execute_playbook_yaml' and command.playbook_yaml_content: + with open(command.playbook_yaml_content, 'r') as f: + yaml_content = f.read() + response = client.execute_remote_playbook_yaml(yaml_content, debug=debug) + + elif command.cmd == 'execute_playbook_file' and command.playbook_file_path: + response = client.execute_remote_playbook_file(command.playbook_file_path, debug=debug) + + elif command.cmd == 'execute_command': + response = client.execute_remote_command(command.remote_command, debug=debug) + + else: + raise ExecException(f"Unsupported remote command: '{command.cmd}'") + + return response if response is not None else {} + + def _process_response(self, response_data: Optional[Dict[str, Any]]) -> tuple: + """ + Processes the raw response from the remote client to determine success, + error message, return code, and stdout string. + """ + success: bool = False + error_message: Optional[str] = None + stdout_str: Optional[str] = None + return_code: int = 1 + + if not response_data: + error_message = 'No response received from remote server (client communication failed).' + self.logger.error(error_message) + return success, error_message, stdout_str, return_code + + self.logger.debug(f"Processing response data: {json.dumps(response_data)}") + + # Prioritize 'result' key for command-like responses + cmd_result = response_data.get('result', {}) + if cmd_result: + success = cmd_result.get('success', False) + stdout_str = cmd_result.get('stdout') + return_code = cmd_result.get('returncode', 1 if not success else 0) + if not success and 'error_message' in cmd_result: + error_message = cmd_result['error_message'] + self.logger.error(f"Remote command reported error: {error_message}") + else: + self.logger.info(f"Remote command execution success: {success}, return code: {return_code}") + + # Fallback to 'success' key for playbook-like responses + elif 'success' in response_data: + success = response_data.get('success', False) + stdout_str = json.dumps(response_data, indent=2) + return_code = 0 if success else 1 + if not success: + error_message = response_data.get('message', 'Unknown error during playbook execution.') + self.logger.error(f"Remote playbook execution failed: {error_message}") + else: + self.logger.info(f"Remote playbook execution success: {success}") + + # Catch all for unexpected response structures + else: + error_message = 'Received unexpected response structure from remote server.' + stdout_str = json.dumps(response_data, indent=2) + self.logger.warning(f"{error_message}: {stdout_str}") + + return success, error_message, stdout_str, return_code + + def _format_output(self, success: bool, stdout: Optional[str], error: Optional[str]) -> str: + """Creates the final stdout string based on the execution result.""" + if error: + # Prepend the error to the standard output if both exist + header = f"Error: {error}" + return f"{header}\n\nOutput/Response:\n{stdout}" if stdout else header + + if stdout is not None: + return stdout + + return 'Operation completed successfully.' if success else 'Operation failed with no output.' 
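The remote executor above is easiest to understand from the caller's side. The following sketch is not part of the patch: it shows how the new 'remote' command type might be built and dispatched, using the AttackMateRemoteCommand and ShellCommand schemas from this change set. The server URL, CA path and credentials are placeholder assumptions, and the no-argument construction of ProcessManager/VariableStore is assumed rather than taken from this diff.

# Hedged usage sketch for the RemoteExecutor added above (not part of the patch).
# server_url, cacert, user and password are placeholder values; running this
# performs a real HTTP round trip to the configured remote AttackMate server.
from attackmate.processmanager import ProcessManager
from attackmate.variablestore import VariableStore
from attackmate.executors.remote.remoteexecutor import RemoteExecutor
from attackmate.schemas.remote import AttackMateRemoteCommand
from attackmate.schemas.shell import ShellCommand

executor = RemoteExecutor(ProcessManager(), VariableStore())  # no-arg construction assumed
command = AttackMateRemoteCommand(
    type='remote',
    cmd='execute_command',
    server_url='https://127.0.0.1:8000',  # placeholder
    cacert='/etc/attackmate/ca.pem',      # placeholder path
    user='testuser',
    password='testuser',
    remote_command=ShellCommand(type='shell', cmd='id'),
)
result = executor.run(command)  # Result(stdout, returncode) assembled by _exec_cmd() above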
diff --git a/src/attackmate/logging_setup.py b/src/attackmate/logging_setup.py index 8e0fd549..1b8e2252 100644 --- a/src/attackmate/logging_setup.py +++ b/src/attackmate/logging_setup.py @@ -1,50 +1,44 @@ # logging_setup.py import logging +import sys from colorlog import ColoredFormatter +DATE_FORMAT = '%Y-%m-%d %H:%M:%S' +OUTPUT_LOG_FILE = 'output.log' +PLAYBOOK_LOG_FILE = 'attackmate.log' +JSON_LOG_FILE = 'attackmate.json' -def create_file_handler( - file_name: str, append_logs: bool, formatter: logging.Formatter -) -> logging.FileHandler: - mode = 'a' if append_logs else 'w' - file_handler = logging.FileHandler(file_name, mode=mode) - file_handler.setFormatter(formatter) - - return file_handler +PLAYBOOK_CONSOLE_FORMAT = ' %(asctime)s %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s' +API_CONSOLE_FORMAT = ' %(asctime)s %(log_color)s%(levelname)-8s%(reset)s | API | %(log_color)s%(message)s%(reset)s' +DEFAULT_FILE_FORMAT = '%(asctime)s %(levelname)s - %(message)s' +OUTPUT_FILE_FORMAT = '--- %(asctime)s %(levelname)s: ---\n\n%(message)s' def initialize_output_logger(debug: bool, append_logs: bool): output_logger = logging.getLogger('output') - if debug: - output_logger.setLevel(logging.DEBUG) - else: - output_logger.setLevel(logging.INFO) - formatter = logging.Formatter( - '--- %(asctime)s %(levelname)s: ---\n\n%(message)s', datefmt='%Y-%m-%d %H:%M:%S' - ) - file_handler = create_file_handler('output.log', append_logs, formatter) + output_logger.setLevel(logging.DEBUG if debug else logging.INFO) + formatter = logging.Formatter(OUTPUT_FILE_FORMAT, datefmt=DATE_FORMAT) + file_handler = create_file_handler(OUTPUT_LOG_FILE, append_logs, formatter) output_logger.addHandler(file_handler) def initialize_logger(debug: bool, append_logs: bool): playbook_logger = logging.getLogger('playbook') - if debug: - playbook_logger.setLevel(logging.DEBUG) - else: - playbook_logger.setLevel(logging.INFO) + playbook_logger.setLevel(logging.DEBUG if debug else logging.INFO) # output to console - console_handler = logging.StreamHandler() - LOGFORMAT = ' %(asctime)s %(log_color)s%(levelname)-8s%(reset)s' '| %(log_color)s%(message)s%(reset)s' - formatter = ColoredFormatter(LOGFORMAT, datefmt='%Y-%m-%d %H:%M:%S') - console_handler.setFormatter(formatter) + if not has_stdout_handler(playbook_logger): + console_handler = logging.StreamHandler(sys.stdout) # Explicitly target stdout + console_formatter = ColoredFormatter(PLAYBOOK_CONSOLE_FORMAT, datefmt=DATE_FORMAT) + console_handler.setFormatter(console_formatter) + playbook_logger.addHandler(console_handler) # plain text output - playbook_logger.addHandler(console_handler) - formatter = logging.Formatter('%(asctime)s %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') - file_handler = create_file_handler('attackmate.log', append_logs, formatter) + file_formatter = logging.Formatter(DEFAULT_FILE_FORMAT, datefmt=DATE_FORMAT) + file_handler = create_file_handler(PLAYBOOK_LOG_FILE, append_logs, file_formatter) playbook_logger.addHandler(file_handler) + playbook_logger.propagate = False return playbook_logger @@ -55,7 +49,49 @@ def initialize_json_logger(json: bool, append_logs: bool): json_logger = logging.getLogger('json') json_logger.setLevel(logging.DEBUG) formatter = logging.Formatter('%(message)s') - file_handler = create_file_handler('attackmate.json', append_logs, formatter) + file_handler = create_file_handler(JSON_LOG_FILE, append_logs, formatter) json_logger.addHandler(file_handler) return json_logger + + +def 
initialize_api_logger(debug: bool, append_logs: bool): + api_logger = logging.getLogger('attackmate_api') + api_logger.setLevel(logging.DEBUG if debug else logging.INFO) + + # Console handler for API logs + if not has_stdout_handler(api_logger): + console_handler = logging.StreamHandler(sys.stdout) + formatter = ColoredFormatter(API_CONSOLE_FORMAT, datefmt=DATE_FORMAT) + console_handler.setFormatter(formatter) + api_logger.addHandler(console_handler) + + # File handler for API logs ? + # api_file_formatter = logging.Formatter(API_CONSOLE_FORMAT, datefmt=DATE_FORMAT) + # api_file_handler = create_file_handler('attackmate_api.log', append_logs, api_file_formatter) + # api_logger.addHandler(api_file_handler) + + # Prevent propagation to avoid duplicate logs if root logger also has handlers + api_logger.propagate = False + + return api_logger + + +def create_file_handler( + file_name: str, append_logs: bool, formatter: logging.Formatter +) -> logging.FileHandler: + mode = 'a' if append_logs else 'w' + file_handler = logging.FileHandler(file_name, mode=mode) + file_handler.setFormatter(formatter) + return file_handler + + +def has_stdout_handler(logger: logging.Logger) -> bool: + """ + Checks if a logger already has a StreamHandler directed to stdout. + + """ + return any( + isinstance(handler, logging.StreamHandler) and handler.stream == sys.stdout + for handler in logger.handlers + ) diff --git a/src/attackmate/playbook_parser.py b/src/attackmate/playbook_parser.py index 5c128a3e..c50883ef 100644 --- a/src/attackmate/playbook_parser.py +++ b/src/attackmate/playbook_parser.py @@ -118,7 +118,7 @@ def parse_playbook(playbook_file: str, logger: logging.Logger) -> Playbook: target_file = default_playbook_location / playbook_file_path else: logger.error( - f"Error: Playbook file not found under '/non/existent/path/playbook.yml' or in the current directory or in /etc/attackmate/playbooks" + "Error: Playbook file not found under '/non/existent/path/playbook.yml' or in the current directory or in /etc/attackmate/playbooks" ) exit(1) @@ -130,13 +130,27 @@ def parse_playbook(playbook_file: str, logger: logging.Logger) -> Playbook: try: with open(target_file) as f: - pb_yaml = yaml.safe_load(f) - playbook_object = Playbook.model_validate(pb_yaml) + playbook_yaml = yaml.safe_load(f) + playbook_object = Playbook.model_validate(playbook_yaml) return playbook_object except OSError: logger.error(f'Error: Could not open playbook file {target_file}') exit(1) - except ValidationError: + except ValidationError as e: logger.error(f'A Validation error occured when parsing playbook file {playbook_file}') + for error in e.errors(): + if error['type'] == 'missing': + logger.error( + f'Missing field in {error["loc"][-2]} command: {error["loc"][-1]} - {error["msg"]}' + ) + elif error['type'] == 'literal_error': + logger.error( + f'Invalid value in {error["loc"][-2]} command: {error["loc"][-1]} - {error["msg"]}' + ) + elif error['type'] == 'value_error': + logger.error( + f'Value error in command {int(error["loc"][-2]) + 1}: ' + f'{error["loc"][-1]} - {error["msg"]}' + ) logger.error(traceback.format_exc()) exit(1) diff --git a/src/attackmate/remote_client.py b/src/attackmate/remote_client.py new file mode 100644 index 00000000..3f8be9e9 --- /dev/null +++ b/src/attackmate/remote_client.py @@ -0,0 +1,196 @@ +import httpx +import logging +import os +import json +from typing import Dict, Any, Optional + +_active_sessions: Dict[str, Dict[str, str]] = {} +DEFAULT_TIMEOUT = 60.0 # what should the timeout be for requests? 
what about background? +# make timeout configurable? + +logger = logging.getLogger('playbook') + + +class RemoteAttackMateClient: + """ + Client to interact with a remote AttackMate REST API. + Handles authentication and token management internally per server URL. + """ + def __init__( + self, + server_url: str, + cacert: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + timeout: float = DEFAULT_TIMEOUT, + ): + self.server_url = server_url.rstrip('/') + self.username = username + self.password = password + self.timeout_config = httpx.Timeout(10.0, connect=5.0, read=timeout) + + if cacert: + if os.path.exists(cacert): + self.verify_ssl = cacert + logger.info(f"Client will verify {self.server_url} SSL using CA: {cacert}") + else: + logger.error(f"CA certificate file not found: {cacert}.") + logger.debug(f"RemoteClient initialized for {self.server_url}") + + def _get_session_token(self) -> Optional[str]: + """Retrieves a valid token for the server_url from memory, logs in if necessary.""" + # There is a token for that server and user in memory, use that + session_data = _active_sessions.get(self.server_url) + if session_data and session_data.get('user') == self.username: + logger.debug(f"Using existing token for {self.server_url} by user {session_data['user']}") + return session_data['token'] + # if not try login with credentials + else: + if self.username and self.password: + return self._login(self.username, self.password) + return None + + def _login(self, username: str, password: str) -> Optional[str]: + """Internal login method, stores token.""" + login_url = f"{self.server_url}/login" + logger.info(f"Attempting login to {login_url} for user '{username}'...") + try: + with httpx.Client(verify=self.verify_ssl, timeout=self.timeout_config) as client: + response = client.post(login_url, data={'username': username, 'password': password}) + # does this need to be form data? + + response.raise_for_status() + data = response.json() + token = data.get('access_token') + + if token: + # with a session_lock? + _active_sessions[self.server_url] = { + 'token': token, + 'user': username + } + logger.info(f"Login successful for '{username}' at {self.server_url}. Token stored.") + return token + else: + logger.error(f"Login to {self.server_url} succeeded but no token received.") + return None + except httpx.HTTPStatusError as e: + logger.error(f"Login failed for '{username}' at {self.server_url}: {e.response.status_code}") + return None + except Exception as e: + logger.error(f"Login request to {self.server_url} failed: {e}", exc_info=True) + return None + + # Common Method to make request + def _make_request( + self, + method: str, + endpoint: str, + json_data: Optional[Dict[str, Any]] = None, + content_data: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + """Makes an authenticated request, handles token renewal implicitly by server.""" + token = self._get_session_token() + if not token: + # Attempt login if credentials are set on the client instance + if self.username and self.password: + logger.info( + f"No active token for {self.server_url}, try login with provided credentials." + ) + token = self._login(self.username, self.password) + if not token: + logger.error(f"Auth required for {self.server_url} but no token available and login failed") + return None # Or raise an AuthException? 
+ + headers = {'X-Auth-Token': token} + if content_data: + headers['Content-Type'] = 'application/yaml' + + url = f"{self.server_url}/{endpoint.lstrip('/')}" + logger.debug(f"Making {method.upper()} request to {url}") + try: + with httpx.Client(verify=self.verify_ssl, timeout=self.timeout_config) as client: + if method.upper() == 'POST': + if content_data: + # sending yaml playbook content + response = client.post(url, content=content_data, headers=headers, params=params) + else: + # sending command or file path + response = client.post(url, json=json_data, headers=headers, params=params) + elif method.upper() == 'GET': + response = client.get(url, headers=headers, params=params) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + response_data = response.json() + + # Server should renew token on any successful authenticated API call. + # Client just uses the token. Server sends back a new token if renewed, + # The current_token field in the response from server use to update client's token. + new_token_from_response = response_data.get('current_token') + if new_token_from_response and new_token_from_response != token: + logger.info(f"Server returned a renewed token for {self.server_url}. Updating client.") + _active_sessions[self.server_url]['token'] = new_token_from_response + return response_data + + except httpx.HTTPStatusError as e: + logger.error(f"API Error ({method} {url}): {e.response.status_code}") + if e.response.status_code == 401: # Unauthorized + logger.warning(f"Token likely expired or invalid for {self.server_url}. Clearing token") + # with _session_lock: + _active_sessions.pop(self.server_url, None) # Clear session on 401 + return None # Or raise custom Error? + except httpx.RequestError as e: + logger.error(f"Request Error ({method} {url}): {e}") + return None + except json.JSONDecodeError: + logger.error(f"JSON Decode Error ({method} {url}). 
Response: {response.text}") + return None + except Exception as e: + logger.error(f"Unexpected error during API request ({method} {url}): {e}", exc_info=True) + return None + + # API Methods all use the _make_request method + def execute_remote_playbook_yaml( + self, playbook_yaml_content: str, debug: bool = False + ) -> Optional[Dict[str, Any]]: + return self._make_request( + method='POST', + endpoint='playbooks/execute/yaml', + content_data=playbook_yaml_content, + params={'debug': True} if debug else None + ) + + def execute_remote_playbook_file( + self, server_playbook_path: str, debug: bool = False + ) -> Optional[Dict[str, Any]]: + return self._make_request( + method='POST', + endpoint='playbooks/execute/file', + json_data={'file_path': server_playbook_path}, + params={'debug': True} if debug else None + ) + + def execute_remote_command( + self, + command_pydantic_model, + debug: bool = False + ) -> Optional[Dict[str, Any]]: + # get the correct enpoint + endpoint = 'command/execute' + + # Convert Pydantic model to dict for JSON body + # handle None values for optional fields (exclude_none=True) + command_body_dict = command_pydantic_model.model_dump(exclude_none=True) + request_payload = { + 'command': command_body_dict + } + + return self._make_request( + method='POST', + endpoint=endpoint, + json_data=request_payload, + params={'debug': True} if debug else None + ) diff --git a/src/attackmate/result.py b/src/attackmate/result.py index 6675d0f3..3fde98ba 100644 --- a/src/attackmate/result.py +++ b/src/attackmate/result.py @@ -26,5 +26,4 @@ def __init__(self, stdout, returncode): self.returncode = returncode def __repr__(self): - return f"Result(stdout={repr(self.stdout)}, returncode={self.returncode})" - + return f"Result(stdout={repr(self.stdout)}, returncode={self.returncode})" diff --git a/src/attackmate/schemas/command_subtypes.py b/src/attackmate/schemas/command_subtypes.py new file mode 100644 index 00000000..f523bae0 --- /dev/null +++ b/src/attackmate/schemas/command_subtypes.py @@ -0,0 +1,83 @@ +from __future__ import annotations +from typing import Annotated, TypeAlias, Union +from pydantic import Field +# Core Commands +from .sleep import SleepCommand +from .shell import ShellCommand +from .setvar import SetVarCommand +from .include import IncludeCommand +from .loop import LoopCommand +from .http import WebServCommand, HttpClientCommand +from .father import FatherCommand +from .tempfile import TempfileCommand +from .debug import DebugCommand +from .regex import RegExCommand +from .vnc import VncCommand +from .json import JsonCommand +from .browser import BrowserCommand +from .ssh import SSHCommand, SFTPCommand +# Metasploit Commands +from .metasploit import MsfModuleCommand, MsfSessionCommand, MsfPayloadCommand +# Sliver Commands +from .sliver import ( + SliverSessionCDCommand, + SliverSessionLSCommand, + SliverSessionNETSTATCommand, + SliverSessionEXECCommand, + SliverSessionMKDIRCommand, + SliverSessionSimpleCommand, + SliverSessionDOWNLOADCommand, + SliverSessionUPLOADCommand, + SliverSessionPROCDUMPCommand, + SliverSessionRMCommand, + SliverSessionTERMINATECommand, + SliverHttpsListenerCommand, + SliverGenerateCommand, +) + +SliverSessionCommands: TypeAlias = Annotated[Union[ + SliverSessionCDCommand, + SliverSessionLSCommand, + SliverSessionNETSTATCommand, + SliverSessionEXECCommand, + SliverSessionMKDIRCommand, + SliverSessionSimpleCommand, + SliverSessionDOWNLOADCommand, + SliverSessionUPLOADCommand, + SliverSessionPROCDUMPCommand, + SliverSessionRMCommand, + 
SliverSessionTERMINATECommand], Field(discriminator='cmd')] + + +SliverCommands: TypeAlias = Annotated[Union[ + SliverHttpsListenerCommand, + SliverGenerateCommand], Field(discriminator='cmd')] + + +# This excludes the AttackMateRemoteCommand type +RemotelyExecutableCommand: TypeAlias = Annotated[ + Union[ + SliverSessionCommands, + SliverCommands, + BrowserCommand, + ShellCommand, + MsfModuleCommand, + MsfSessionCommand, + MsfPayloadCommand, + SleepCommand, + SSHCommand, + FatherCommand, + SFTPCommand, + DebugCommand, + SetVarCommand, + RegExCommand, + TempfileCommand, + IncludeCommand, + LoopCommand, + WebServCommand, + HttpClientCommand, + JsonCommand, + VncCommand, + ], + Field(discriminator='type'), +] diff --git a/src/attackmate/schemas/command_types.py b/src/attackmate/schemas/command_types.py new file mode 100644 index 00000000..4a54728f --- /dev/null +++ b/src/attackmate/schemas/command_types.py @@ -0,0 +1,18 @@ + + +from typing import List, Annotated, TypeAlias, Union +from pydantic import Field +from attackmate.schemas.command_subtypes import RemotelyExecutableCommand +from attackmate.schemas.remote import AttackMateRemoteCommand + + +Command: TypeAlias = Annotated[ + Union[ + RemotelyExecutableCommand, + AttackMateRemoteCommand + ], + Field(discriminator='type'), +] + + +Commands: TypeAlias = List[Command] diff --git a/src/attackmate/schemas/config.py b/src/attackmate/schemas/config.py index 18079103..f21264dc 100644 --- a/src/attackmate/schemas/config.py +++ b/src/attackmate/schemas/config.py @@ -16,9 +16,10 @@ class MsfConfig(BaseModel): class CommandConfig(BaseModel): loop_sleep: int = 5 + command_delay: float = 0 class Config(BaseModel): sliver_config: SliverConfig = SliverConfig(config_file=None) msf_config: MsfConfig = MsfConfig(password=None) - cmd_config: CommandConfig = CommandConfig(loop_sleep=5) + cmd_config: CommandConfig = CommandConfig(loop_sleep=5, command_delay=0) diff --git a/src/attackmate/schemas/http.py b/src/attackmate/schemas/http.py index 1dbc741d..9cb37e22 100644 --- a/src/attackmate/schemas/http.py +++ b/src/attackmate/schemas/http.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Dict +from typing import Literal, Optional, Dict, Any from .base import BaseCommand, StringNumber from attackmate.command import CommandRegistry @@ -20,7 +20,7 @@ class HttpClientCommand(BaseCommand): output_headers: bool = False headers: Optional[Dict[str, str]] = None cookies: Optional[Dict[str, str]] = None - data: Optional[Dict[str, str]] = None + data: Optional[Dict[str, Any]] = None local_path: Optional[str] = None useragent: str = 'AttackMate' follow: bool = False diff --git a/src/attackmate/schemas/loop.py b/src/attackmate/schemas/loop.py index 77ada26b..1d8c0c77 100644 --- a/src/attackmate/schemas/loop.py +++ b/src/attackmate/schemas/loop.py @@ -1,6 +1,5 @@ from attackmate.schemas.base import BaseCommand from pydantic import field_validator -from typing import Literal from attackmate.command import CommandRegistry from typing import Literal, Union, Optional, List from .sleep import SleepCommand @@ -8,6 +7,14 @@ from .vnc import VncCommand from .setvar import SetVarCommand from .include import IncludeCommand +from .ssh import SSHCommand, SFTPCommand +from .http import WebServCommand, HttpClientCommand +from .father import FatherCommand +from .tempfile import TempfileCommand +from .debug import DebugCommand +from .regex import RegExCommand +from .browser import BrowserCommand + from .metasploit import MsfModuleCommand, MsfSessionCommand, MsfPayloadCommand from 
.sliver import ( @@ -25,13 +32,6 @@ SliverHttpsListenerCommand, SliverGenerateCommand, ) -from .ssh import SSHCommand, SFTPCommand -from .http import WebServCommand, HttpClientCommand -from .father import FatherCommand -from .tempfile import TempfileCommand -from .debug import DebugCommand -from .regex import RegExCommand -from .browser import BrowserCommand Commands = List[ diff --git a/src/attackmate/schemas/playbook.py b/src/attackmate/schemas/playbook.py index d7aaac95..fd333d89 100644 --- a/src/attackmate/schemas/playbook.py +++ b/src/attackmate/schemas/playbook.py @@ -1,75 +1,9 @@ -from typing import List, Optional, Dict, Union +from __future__ import annotations +from typing import List, Optional, Dict + from .base import StrInt from pydantic import BaseModel -from .sleep import SleepCommand -from .shell import ShellCommand -from .setvar import SetVarCommand -from .include import IncludeCommand -from .loop import LoopCommand -from .metasploit import MsfModuleCommand, MsfSessionCommand, MsfPayloadCommand - -from .sliver import ( - SliverSessionCDCommand, - SliverSessionLSCommand, - SliverSessionNETSTATCommand, - SliverSessionEXECCommand, - SliverSessionMKDIRCommand, - SliverSessionSimpleCommand, - SliverSessionDOWNLOADCommand, - SliverSessionUPLOADCommand, - SliverSessionPROCDUMPCommand, - SliverSessionRMCommand, - SliverSessionTERMINATECommand, - SliverHttpsListenerCommand, - SliverGenerateCommand, -) -from .ssh import SSHCommand, SFTPCommand -from .http import WebServCommand, HttpClientCommand -from .father import FatherCommand -from .tempfile import TempfileCommand -from .debug import DebugCommand -from .regex import RegExCommand -from .vnc import VncCommand -from .json import JsonCommand -from .browser import BrowserCommand - -Command = Union[ - BrowserCommand, - ShellCommand, - MsfModuleCommand, - MsfSessionCommand, - MsfPayloadCommand, - SleepCommand, - SSHCommand, - FatherCommand, - SFTPCommand, - DebugCommand, - SetVarCommand, - RegExCommand, - TempfileCommand, - IncludeCommand, - LoopCommand, - WebServCommand, - HttpClientCommand, - JsonCommand, - SliverSessionCDCommand, - SliverSessionLSCommand, - SliverSessionNETSTATCommand, - SliverSessionEXECCommand, - SliverSessionMKDIRCommand, - SliverSessionSimpleCommand, - SliverSessionDOWNLOADCommand, - SliverSessionUPLOADCommand, - SliverSessionPROCDUMPCommand, - SliverSessionRMCommand, - SliverSessionTERMINATECommand, - SliverHttpsListenerCommand, - SliverGenerateCommand, - VncCommand, -] - - -Commands = List[Command] +from .command_types import Commands class Playbook(BaseModel): diff --git a/src/attackmate/schemas/remote.py b/src/attackmate/schemas/remote.py new file mode 100644 index 00000000..c365f7c9 --- /dev/null +++ b/src/attackmate/schemas/remote.py @@ -0,0 +1,35 @@ +from __future__ import annotations +from pydantic import model_validator +from typing import Literal, Optional + +from .base import BaseCommand +from attackmate.command import CommandRegistry + +from .command_subtypes import RemotelyExecutableCommand + + +@CommandRegistry.register('remote') +class AttackMateRemoteCommand(BaseCommand): + + type: Literal['remote'] + cmd: Literal['execute_command', 'execute_playbook_yaml', 'execute_playbook_file'] + server_url: str + cacert: str # TODO configure this file path in some configs elsewhere? 
+ user: str + password: str + playbook_yaml_content: Optional[str] = None + playbook_file_path: Optional[str] = None + remote_command: Optional[RemotelyExecutableCommand] = None + + # Common command parameters (like background, only_if) from BaseCommand + # will be applied to the command itself, not the remote_command executed on the remote instance + + @model_validator(mode='after') + def check_remote_command(self) -> 'AttackMateRemoteCommand': + if self.cmd == 'execute_command' and not self.remote_command: + raise ValueError("remote_command must be provided when cmd is 'execute_command'") + if self.cmd == 'execute_playbook_yaml' and not self.playbook_yaml_content: + raise ValueError("playbook_yaml_content must be provided when cmd is 'execute_playbook_yaml'") + if self.cmd == 'execute_playbook_file' and not self.playbook_file_path: + raise ValueError("playbook_file_path must be provided when cmd is 'execute_playbook_file'") + return self diff --git a/test/units/test_browserexecutor.py b/test/units/test_browserexecutor.py index 271c4ee3..a642a519 100644 --- a/test/units/test_browserexecutor.py +++ b/test/units/test_browserexecutor.py @@ -1,10 +1,27 @@ import pytest from pydantic import ValidationError from unittest.mock import patch, MagicMock +from urllib.parse import quote from attackmate.executors.browser.sessionstore import BrowserSessionStore, SessionThread from attackmate.executors.browser.browserexecutor import BrowserExecutor, BrowserCommand +# Minimal, stable inline HTML (no network needed) +HTML_SIMPLE = "
<html><body><h1>Hello World</h1></body></html>
" +HTML_WITH_LINK = """ + + + + go + + + + """ + +DATA_URL_SIMPLE = "data:text/html," + quote(HTML_SIMPLE) +DATA_URL_WITH_LINK = "data:text/html," + quote(HTML_WITH_LINK) + + @pytest.fixture def mock_playwright(): """ @@ -89,7 +106,7 @@ def test_session_thread_lifecycle(mock_playwright): thread = SessionThread(session_name='test_session', headless=True) # Submit a command. This should go onto the queue, be processed, and return 'OK' - result = thread.submit_command('visit', url='http://example.org') + result = thread.submit_command('visit', url=DATA_URL_SIMPLE) assert result == 'OK', "Expected the visit command to return 'OK'" # Stop the thread @@ -149,7 +166,7 @@ def test_browser_executor_ephemeral_session(browser_executor): command = BrowserCommand( type='browser', cmd='visit', - url='http://example.org', + url=DATA_URL_SIMPLE, headless=True ) result = browser_executor._exec_cmd(command) @@ -164,7 +181,7 @@ def test_browser_executor_named_session(browser_executor): create_cmd = BrowserCommand( type='browser', cmd='visit', - url='http://example.org', + url=DATA_URL_WITH_LINK, creates_session='my_session', headless=True ) @@ -176,7 +193,7 @@ def test_browser_executor_named_session(browser_executor): reuse_cmd = BrowserCommand( type='browser', cmd='click', - selector='a[href="https://www.iana.org/domains/example"]', + selector='#test-link', # matches the anchor in DATA_URL_WITH_LINK session='my_session' ) result2 = browser_executor._exec_cmd(reuse_cmd) @@ -207,7 +224,7 @@ def test_browser_executor_recreate_same_session(browser_executor): cmd1 = BrowserCommand( type='browser', cmd='visit', - url='http://example.org', + url=DATA_URL_SIMPLE, creates_session='my_session', headless=True ) @@ -217,7 +234,7 @@ def test_browser_executor_recreate_same_session(browser_executor): cmd2 = BrowserCommand( type='browser', cmd='visit', - url='http://example.com', + url=DATA_URL_WITH_LINK, creates_session='my_session', headless=True ) @@ -236,5 +253,5 @@ def test_browser_executor_unknown_command_validation(): BrowserCommand( type='browser', cmd='zoom', # invalid literal, should be one of [visit, click, type, screenshot] - url='http://example.org' + url=DATA_URL_SIMPLE ) diff --git a/test/units/test_commanddelay.py b/test/units/test_commanddelay.py new file mode 100644 index 00000000..228ecdc7 --- /dev/null +++ b/test/units/test_commanddelay.py @@ -0,0 +1,91 @@ +import time +from attackmate.attackmate import AttackMate +from attackmate.schemas.config import Config, CommandConfig +from attackmate.schemas.playbook import Playbook +from attackmate.schemas.debug import DebugCommand +from attackmate.schemas.setvar import SetVarCommand +from attackmate.schemas.shell import ShellCommand +from attackmate.schemas.sleep import SleepCommand + + +def test_command_delay_is_applied(): + """ + Tests that command_delay is applied between applicable commands. 
+ """ + delay = 0.2 + num_commands = 3 + playbook = Playbook( + commands=[ + ShellCommand(type='shell', cmd='echo 1'), + ShellCommand(type='shell', cmd='echo 2'), + ShellCommand(type='shell', cmd='echo 3'), + ] + ) + config = Config(cmd_config=CommandConfig(command_delay=delay)) + attackmate_instance = AttackMate(playbook=playbook, config=config) + + start_time = time.monotonic() + attackmate_instance._run_commands(attackmate_instance.playbook.commands) + end_time = time.monotonic() + elapsed_time = end_time - start_time + + expected_minimum_time = num_commands * delay + # Allow for command execution overhead + expected_maximum_time = expected_minimum_time + 0.5 + + assert elapsed_time >= expected_minimum_time, ( + f"Execution faster ({elapsed_time:.4f}s) than the minimum expected delay " + f"({expected_minimum_time:.4f}s)." + ) + assert elapsed_time < expected_maximum_time, ( + f"Execution slower ({elapsed_time:.4f}s) than expected." + ) + + +def test_zero_command_delay(): + """ + Tests that no delay is applied when command_delay is 0. + """ + playbook = Playbook( + commands=[ + ShellCommand(type='shell', cmd='echo 1'), + ShellCommand(type='shell', cmd='echo 2'), + ] + ) + config = Config(cmd_config=CommandConfig(command_delay=0)) + attackmate_instance = AttackMate(playbook=playbook, config=config) + + start_time = time.monotonic() + attackmate_instance._run_commands(attackmate_instance.playbook.commands) + end_time = time.monotonic() + elapsed_time = end_time - start_time + + # With no delay, execution should be very fast. + assert elapsed_time < 0.1, ( + f"Execution with no delay took too long: {elapsed_time:.4f}s." + ) + + +def test_delay_is_not_applied_for_exempt_commands(): + """ + Tests that delay is skipped for 'sleep', 'debug', and 'setvar' commands. + """ + playbook = Playbook( + commands=[ + SetVarCommand(type='setvar', cmd='x', variable='y'), + DebugCommand(type='debug', cmd='test message'), + SleepCommand(type='sleep', seconds=0), + ] + ) + # This delay should be ignored + config = Config(cmd_config=CommandConfig(command_delay=5)) + attackmate_instance = AttackMate(playbook=playbook, config=config) + + start_time = time.monotonic() + attackmate_instance._run_commands(attackmate_instance.playbook.commands) + end_time = time.monotonic() + elapsed_time = end_time - start_time + + assert elapsed_time < 0.1, ( + f"Execution with exempt commands took too long: {elapsed_time:.4f}s." + ) diff --git a/test/units/test_logging.py b/test/units/test_logging.py index a7ba684e..517fe683 100644 --- a/test/units/test_logging.py +++ b/test/units/test_logging.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import patch, MagicMock from attackmate.logging_setup import create_file_handler +from attackmate.logging_setup import initialize_json_logger +import logging @patch('attackmate.logging_setup.logging.FileHandler') @@ -31,3 +32,39 @@ def test_create_file_handler_write_mode(MockFileHandler): MockFileHandler.assert_called_with(file_name, mode='w') mock_handler.setFormatter.assert_called_with(formatter) + + +@patch('attackmate.logging_setup.logging.getLogger') +@patch('attackmate.logging_setup.logging.Formatter') +@patch('attackmate.logging_setup.create_file_handler') +def test_initialize_json_logger_when_enabled(mock_create_handler, mock_formatter, mock_get_logger): + """ + Tests that the JSON logger is configured correctly when enabled. 
+ """ + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + mock_file_handler = MagicMock() + mock_create_handler.return_value = mock_file_handler + mock_formatter_instance = MagicMock() + mock_formatter.return_value = mock_formatter_instance + + json_logger = initialize_json_logger(json=True, append_logs=True) + + mock_get_logger.assert_called_once_with('json') + mock_logger.setLevel.assert_called_once_with(logging.DEBUG) + + mock_create_handler.assert_called_once_with('attackmate.json', True, mock_formatter_instance) + mock_logger.addHandler.assert_called_once_with(mock_file_handler) + + +@patch('attackmate.logging_setup.logging.getLogger') +@patch('attackmate.logging_setup.create_file_handler') +def test_initialize_json_logger_when_disabled(mock_create_handler, mock_get_logger): + """ + Tests that the JSON logger is not created and returns None when disabled. + """ + returned_value = initialize_json_logger(json=False, append_logs=True) + + assert returned_value is None + mock_get_logger.assert_not_called() + mock_create_handler.assert_not_called()