diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 7d1a7b655..a9de450c0 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -17,7 +17,7 @@ jobs: permissions: read-all strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: # Print out details about the run - name: Dump GitHub context @@ -178,18 +178,11 @@ jobs: python -m pip install -r requirements-dev.txt else echo "Missing requirements-dev.txt. Installing minimal requirements for testing." - python -m pip install flake8 black bandit mypy pylint types-attrs pydocstyle pyroma + python -m pip install bandit mypy types-attrs pyroma fi - - name: black + - name: ruff run: | - black --diff --check --exclude venv msticpy - if: ${{ always() }} - - name: flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 msticpy --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 --max-line-length=90 --exclude=tests* . --ignore=E501,W503 --jobs=auto + ruff check msticpy --ignore PLW0603 if: ${{ always() }} - name: pylint run: | @@ -221,14 +214,6 @@ jobs: run: | bandit -x tests -r -s B303,B404,B603,B607,B608,B113 msticpy if: ${{ always() }} - - name: flake8 - run: | - flake8 --max-line-length=90 --exclude=tests* . --ignore=E501,W503 --jobs=auto - if: ${{ always() }} - - name: pydocstyle - run: | - pydocstyle --convention=numpy msticpy - if: ${{ always() }} - name: pyroma run: | pyroma --min 10 . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 928742cd9..752744fd8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,11 +7,6 @@ repos: exclude: .*devcontainer.json - id: trailing-whitespace args: [--markdown-linebreak-ext=md] - - repo: https://github.com/ambv/black - rev: 25.9.0 - hooks: - - id: black - language: python - repo: https://github.com/PyCQA/pylint rev: v4.0.2 hooks: @@ -19,22 +14,6 @@ repos: args: - --disable=duplicate-code,import-error - --ignore-patterns=test_ - - repo: https://github.com/pycqa/flake8 - rev: 7.3.0 - hooks: - - id: flake8 - args: - - --extend-ignore=E401,E501,W503 - - --max-line-length=90 - - --exclude=tests,test*.py - - repo: https://github.com/pycqa/isort - rev: 7.0.0 - hooks: - - id: isort - name: isort (python) - args: - - --profile - - black - repo: https://github.com/pycqa/pydocstyle rev: 6.3.0 hooks: @@ -43,14 +22,17 @@ repos: - --convention=numpy - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.14.1 + rev: v0.14.8 hooks: # Run the linter. - - id: ruff - types_or: [ python, pyi, jupyter ] + - id: ruff-check args: - msticpy - --fix + # Run the formatter. 
+ - id: ruff-format + args: + - msticpy - repo: local hooks: - id: check_reqs_all diff --git a/README.md b/README.md index 7f7e71105..40fb20a30 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ ![GitHub Actions build](https://github.com/microsoft/msticpy/actions/workflows/python-package.yml/badge.svg?branch=main) [![Azure Pipelines build](https://dev.azure.com/mstic-detections/mstic-jupyter/_apis/build/status/microsoft.msticpy?branchName=main)](https://dev.azure.com/mstic-detections/mstic-jupyter/_build/latest?definitionId=14&branchName=main) [![Downloads](https://pepy.tech/badge/msticpy)](https://pepy.tech/project/msticpy) -[![BlackHat Arsenal 2020](https://raw.githubusercontent.com/toolswatch/badges/master/arsenal/usa/2020.svg)](https://www.blackhat.com/us-20/arsenal/schedule/#msticpy-the-security-analysis-swiss-army-knife-19872) Microsoft Threat Intelligence Python Security Tools. @@ -29,7 +28,7 @@ alt="Timeline" title="Msticpy Timeline Control" height="300" /> The **msticpy** package was initially developed to support [Jupyter Notebooks](https://jupyter-notebook-beginner-guide.readthedocs.io/en/latest/) authoring for -[Azure Sentinel](https://azure.microsoft.com/en-us/services/azure-sentinel/). +[Microsoft Sentinel](https://www.microsoft.com/en-us/security/business/siem-and-xdr/microsoft-sentinel/). While Azure Sentinel is still a big focus of our work, we are extending the data query/acquisition components to pull log data from other sources (currently Splunk, Microsoft Defender for Endpoint and @@ -55,11 +54,6 @@ For core install: `pip install msticpy` -If you are using *MSTICPy* with Azure Sentinel you should install with -the "azsentinel" extra package: - -`pip install msticpy[azsentinel]` - or for the latest dev build `pip install git+https://github.com/microsoft/msticpy` @@ -90,8 +84,8 @@ functions in this interactive demo on mybinder.org. ## Log Data Acquisition -QueryProvider is an extensible query library targeting Azure Sentinel/Log Analytics, -Splunk, OData +QueryProvider is an extensible query library targeting Microsoft Sentinel/Log Analytics, +Microsoft XDR, Splunk, OData and other log data sources. It also has special support for [Mordor](https://github.com/OTRF/mordor) data sets and using local data. @@ -325,7 +319,7 @@ See the following notebooks for more examples of the use of this package in prac ## Supported Platforms and Packages - msticpy is OS-independent -- Requires [Python 3.8 or later](https://www.python.org/dev/peps/pep-0494/) +- Requires [Python 3.10 or later](https://www.python.org/dev/peps/pep-0619/) - See [requirements.txt](requirements.txt) for more details and version requirements.
--- diff --git a/conda/conda-reqs-dev.txt b/conda/conda-reqs-dev.txt index 8dc758ce9..d80b398be 100644 --- a/conda/conda-reqs-dev.txt +++ b/conda/conda-reqs-dev.txt @@ -4,10 +4,7 @@ beautifulsoup4 black>=20.8b1 coverage>=5.5 filelock>=3.0.0 -flake8>=3.8.4 -isort>=5.10.1 markdown>=3.3.4 -mccabe>=0.6.1 mypy>=0.821 nbconvert>=6.1.0 nbdime>=2.1.0 @@ -15,7 +12,6 @@ pandas>=1.4.0 pep8-naming>=0.10.0 pep8>=1.7.1 pipreqs>=0.4.9 -pycodestyle>=2.6.0 pydocstyle>=6.0.0 pyflakes>=2.2.0 pylint>=2.5.3 diff --git a/conda/conda-reqs-pip.txt b/conda/conda-reqs-pip.txt index eddd19ca6..ab3048b3c 100644 --- a/conda/conda-reqs-pip.txt +++ b/conda/conda-reqs-pip.txt @@ -1,6 +1,5 @@ azure-mgmt-resourcegraph>=8.0.0 azure-monitor-query>=1.0.0, <=2.0.0 -# KqlmagicCustom[jupyter-basic,auth_code_clipboard]>=0.1.114.post22 mo-sql-parsing>=11, <12.0.0 nest_asyncio>=1.4.0 passivetotal>=2.5.3 @@ -8,6 +7,5 @@ sumologic-sdk>=0.1.11 splunk-sdk>=1.6.0,!=2.0.0 packaging>=24.0 requests>=2.31.0 -importlib-resources >= 6.4.0; python_version <= "3.8" rrcf==0.4.4 joblib>=1.3.0 diff --git a/docs/source/api/msticpy.data.drivers.prismacloud_driver.rst b/docs/source/api/msticpy.data.drivers.prismacloud_driver.rst new file mode 100644 index 000000000..de1e0e5e8 --- /dev/null +++ b/docs/source/api/msticpy.data.drivers.prismacloud_driver.rst @@ -0,0 +1,7 @@ +msticpy.data.drivers.prismacloud\_driver module +=============================================== + +.. automodule:: msticpy.data.drivers.prismacloud_driver + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/msticpy.data.drivers.rst b/docs/source/api/msticpy.data.drivers.rst index dd85fac7d..7248102e4 100644 --- a/docs/source/api/msticpy.data.drivers.rst +++ b/docs/source/api/msticpy.data.drivers.rst @@ -26,6 +26,7 @@ Submodules msticpy.data.drivers.mdatp_driver msticpy.data.drivers.mordor_driver msticpy.data.drivers.odata_driver + msticpy.data.drivers.prismacloud_driver msticpy.data.drivers.resource_graph_driver msticpy.data.drivers.security_graph_driver msticpy.data.drivers.sentinel_query_reader diff --git a/docs/source/conf.py b/docs/source/conf.py index b41087ea0..f27de895c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -257,7 +257,6 @@ "ipywidgets", "jwt", "keyring", - "Kqlmagic", "matplotlib.pyplot", "matplotlib", "mo_sql_parsing", diff --git a/msticpy/__init__.py b/msticpy/__init__.py index b9b6fae23..3030c573d 100644 --- a/msticpy/__init__.py +++ b/msticpy/__init__.py @@ -114,8 +114,8 @@ initialization and checks are performed. """ -import os -from typing import Iterable, Union + +from collections.abc import Iterable from . import nbwidgets # noqa: F401 @@ -134,9 +134,6 @@ get_config = settings.get_config setup_logging() -if not os.environ.get("KQLMAGIC_EXTRAS_REQUIRES"): - os.environ["KQLMAGIC_EXTRAS_REQUIRES"] = "jupyter-basic" - _LAZY_IMPORTS = { "msticpy.auth.azure_auth.az_connect", "msticpy.common.timespan.TimeSpan", @@ -159,7 +156,7 @@ module, __getattr__, __dir__ = lazy_import(__name__, _LAZY_IMPORTS) -def load_plugins(plugin_paths: Union[str, Iterable[str]]): +def load_plugins(plugin_paths: str | Iterable[str]): """ Load plugins from specified paths or configuration. 
@@ -177,6 +174,6 @@ def load_plugins(plugin_paths: Union[str, Iterable[str]]): """ # pylint: disable=import-outside-toplevel - from .init.mp_plugins import read_plugins + from .init.mp_plugins import read_plugins # noqa: PLC0415 read_plugins(plugin_paths) diff --git a/msticpy/_version.py b/msticpy/_version.py index 76d658835..aa42e781e 100644 --- a/msticpy/_version.py +++ b/msticpy/_version.py @@ -1,3 +1,3 @@ """Version file.""" -VERSION = "2.18.0" +VERSION = "3.0.0.pre1" diff --git a/msticpy/analysis/anomalous_sequence/anomalous.py b/msticpy/analysis/anomalous_sequence/anomalous.py index fb38877e1..16aabbd5e 100644 --- a/msticpy/analysis/anomalous_sequence/anomalous.py +++ b/msticpy/analysis/anomalous_sequence/anomalous.py @@ -8,6 +8,7 @@ In particular, this module is for both modelling and visualising your session data. """ + from __future__ import annotations import pandas as pd @@ -61,7 +62,7 @@ def score_sessions( raise MsticpyException(f'"{session_column}" should be a column in the `data`') sessions_df = data.copy() - sessions = sessions_df[session_column].values.tolist() # type: ignore + sessions = sessions_df[session_column].values.tolist() model = Model(sessions=sessions) model.train() @@ -69,9 +70,9 @@ def score_sessions( window_len=window_length, use_geo_mean=False, use_start_end_tokens=True ) - sessions_df[f"rarest_window{window_length}_likelihood"] = ( - model.rare_window_likelihoods[window_length] - ) + sessions_df[f"rarest_window{window_length}_likelihood"] = model.rare_window_likelihoods[ + window_length + ] sessions_df[f"rarest_window{window_length}"] = model.rare_windows[window_length] return sessions_df diff --git a/msticpy/analysis/anomalous_sequence/model.py b/msticpy/analysis/anomalous_sequence/model.py index b708f8599..a56e021c9 100644 --- a/msticpy/analysis/anomalous_sequence/model.py +++ b/msticpy/analysis/anomalous_sequence/model.py @@ -6,7 +6,6 @@ """Module for Model class for modelling sessions data.""" from collections import defaultdict -from typing import Dict, List, Union from ...common.exceptions import MsticpyException from .utils import cmds_only, cmds_params_only, cmds_params_values, probabilities @@ -18,9 +17,7 @@ class Model: """Class for modelling sessions data.""" - def __init__( - self, sessions: List[List[Union[str, Cmd]]], modellable_params: set = None - ): + def __init__(self, sessions: list[list[str | Cmd]], modellable_params: set = None): """ Instantiate the Model class. 
@@ -105,16 +102,16 @@ def __init__( self.value_probs = None self.value_cond_param_probs = None - self.set_params_cond_cmd_probs: Dict[str, Dict[str, float]] = {} + self.set_params_cond_cmd_probs: dict[str, dict[str, float]] = {} self.session_likelihoods = None self.session_geomean_likelihoods = None - self.rare_windows: Dict[int, list] = {} - self.rare_window_likelihoods: Dict[int, list] = {} + self.rare_windows: dict[int, list] = {} + self.rare_window_likelihoods: dict[int, list] = {} - self.rare_windows_geo: Dict[int, list] = {} - self.rare_window_likelihoods_geo: Dict[int, list] = {} + self.rare_windows_geo: dict[int, list] = {} + self.rare_window_likelihoods_geo: dict[int, list] = {} def train(self): """ @@ -154,9 +151,7 @@ def compute_scores(self, use_start_end_tokens: bool): """ if self.prior_probs is None: - raise MsticpyException( - "please train the model first before using this method" - ) + raise MsticpyException("please train the model first before using this method") self.compute_likelihoods_of_sessions(use_start_end_tokens=use_start_end_tokens) self.compute_geomean_lik_of_sessions() self.compute_rarest_windows( @@ -339,7 +334,7 @@ def _compute_probs(self): if self.session_type == SessionType.cmds_params_values: self._compute_probs_values() - def compute_setof_params_cond_cmd(self, use_geo_mean: bool): # noqa: MC0001 + def compute_setof_params_cond_cmd(self, use_geo_mean: bool): """ Compute likelihood of combinations of params conditional on the cmd. @@ -370,9 +365,7 @@ def compute_setof_params_cond_cmd(self, use_geo_mean: bool): # noqa: MC0001 """ if self.param_probs is None: - raise MsticpyException( - "please train the model first before using this method" - ) + raise MsticpyException("please train the model first before using this method") if self.session_type is None: raise MsticpyException("session_type attribute should not be None") @@ -442,9 +435,7 @@ def compute_likelihoods_of_sessions(self, use_start_end_tokens: bool = True): """ if self.prior_probs is None: - raise MsticpyException( - "please train the model first before using this method" - ) + raise MsticpyException("please train the model first before using this method") result = [] @@ -556,9 +547,7 @@ def compute_rarest_windows( """ if self.prior_probs is None: - raise MsticpyException( - "please train the model first before using this method" - ) + raise MsticpyException("please train the model first before using this method") if self.session_type == SessionType.cmds_only: rare_tuples = [ @@ -609,9 +598,7 @@ def compute_rarest_windows( if use_geo_mean: self.rare_windows_geo[window_len] = [rare[0] for rare in rare_tuples] - self.rare_window_likelihoods_geo[window_len] = [ - rare[1] for rare in rare_tuples - ] + self.rare_window_likelihoods_geo[window_len] = [rare[1] for rare in rare_tuples] else: self.rare_windows[window_len] = [rare[0] for rare in rare_tuples] self.rare_window_likelihoods[window_len] = [rare[1] for rare in rare_tuples] diff --git a/msticpy/analysis/anomalous_sequence/sessionize.py b/msticpy/analysis/anomalous_sequence/sessionize.py index 2eb51af20..311d9ca35 100644 --- a/msticpy/analysis/anomalous_sequence/sessionize.py +++ b/msticpy/analysis/anomalous_sequence/sessionize.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Module for creating sessions out of raw data.""" + from __future__ import annotations import numpy as np diff --git a/msticpy/analysis/anomalous_sequence/utils/cmds_only.py b/msticpy/analysis/anomalous_sequence/utils/cmds_only.py index d1f88f792..7a828c49c 100644 --- a/msticpy/analysis/anomalous_sequence/utils/cmds_only.py +++ b/msticpy/analysis/anomalous_sequence/utils/cmds_only.py @@ -7,7 +7,6 @@ import copy from collections import defaultdict -from typing import DefaultDict, List, Tuple, Union import numpy as np @@ -17,8 +16,8 @@ def compute_counts( # nosec - sessions: List[List[str]], start_token: str, end_token: str, unk_token: str -) -> Tuple[DefaultDict[str, int], DefaultDict[str, DefaultDict[str, int]]]: + sessions: list[list[str]], start_token: str, end_token: str, unk_token: str +) -> tuple[defaultdict[str, int], defaultdict[str, defaultdict[str, int]]]: """ Compute counts of individual commands and of sequences of two commands. @@ -46,12 +45,11 @@ """ if not start_token != end_token != unk_token: raise MsticpyException( - "start_token, end_token, unk_tokens should all be set to something " - "different" + "start_token, end_token, unk_token should all be set to something different" ) - seq1_counts: DefaultDict[str, int] = defaultdict(lambda: 0) - seq2_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict( + seq1_counts: defaultdict[str, int] = defaultdict(lambda: 0) + seq2_counts: defaultdict[str, defaultdict[str, int]] = defaultdict( lambda: defaultdict(lambda: 0) ) @@ -69,12 +67,12 @@ def laplace_smooth_counts( - seq1_counts: DefaultDict[str, int], - seq2_counts: DefaultDict[str, DefaultDict[str, int]], + seq1_counts: defaultdict[str, int], + seq2_counts: defaultdict[str, defaultdict[str, int]], start_token: str, end_token: str, unk_token: str, -) -> Tuple[StateMatrix, StateMatrix]: +) -> tuple[StateMatrix, StateMatrix]: """ Laplace smoothing is applied to the counts. @@ -121,9 +119,9 @@ # pylint: disable=too-many-arguments, too-many-branches def compute_likelihood_window( - window: List[str], - prior_probs: Union[StateMatrix, dict], - trans_probs: Union[StateMatrix, dict], + window: list[str], + prior_probs: StateMatrix | dict, + trans_probs: StateMatrix | dict, use_start_token: bool, use_end_token: bool, start_token: str = None, @@ -168,9 +166,7 @@ if use_end_token: if end_token is None: - raise MsticpyException( - "end_token should not be None, when use_end_token is True" - ) + raise MsticpyException("end_token should not be None, when use_end_token is True") w_len = len(window) if w_len == 0: @@ -196,15 +192,15 @@ # pylint: disable=too-many-locals, too-many-arguments, too-many-branches # pylint: disable=too-many-locals, too-many-branches def compute_likelihood_windows_in_session( - session: List[str], - prior_probs: Union[StateMatrix, dict], - trans_probs: Union[StateMatrix, dict], + session: list[str], + prior_probs: StateMatrix | dict, + trans_probs: StateMatrix | dict, window_len: int, use_start_end_tokens: bool, start_token: str = None, end_token: str = None, use_geo_mean: bool = False, -) -> List[float]: +) -> list[float]: """ Compute the likelihoods of a sliding window of length `window_len` in the session.
@@ -278,15 +274,15 @@ def compute_likelihood_windows_in_session( # pylint: disable=too-many-arguments def rarest_window_session( - session: List[str], - prior_probs: Union[StateMatrix, dict], - trans_probs: Union[StateMatrix, dict], + session: list[str], + prior_probs: StateMatrix | dict, + trans_probs: StateMatrix | dict, window_len: int, use_start_end_tokens: bool, start_token: str, end_token: str, use_geo_mean: bool = False, -) -> Tuple[List[str], float]: +) -> tuple[list[str], float]: """ Find and compute likelihood of the rarest window in the session. diff --git a/msticpy/analysis/anomalous_sequence/utils/cmds_params_only.py b/msticpy/analysis/anomalous_sequence/utils/cmds_params_only.py index bf541fcce..e8428e90c 100644 --- a/msticpy/analysis/anomalous_sequence/utils/cmds_params_only.py +++ b/msticpy/analysis/anomalous_sequence/utils/cmds_params_only.py @@ -13,7 +13,6 @@ import copy from collections import defaultdict -from typing import DefaultDict, List, Tuple, Union import numpy as np @@ -27,12 +26,12 @@ # pylint: disable=too-many-locals, too-many-branches def compute_counts( # nosec - sessions: List[List[Cmd]], start_token: str, end_token: str -) -> Tuple[ - DefaultDict[str, int], - DefaultDict[str, DefaultDict[str, int]], - DefaultDict[str, int], - DefaultDict[str, DefaultDict[str, int]], + sessions: list[list[Cmd]], start_token: str, end_token: str +) -> tuple[ + defaultdict[str, int], + defaultdict[str, defaultdict[str, int]], + defaultdict[str, int], + defaultdict[str, defaultdict[str, int]], ]: """ Compute the training counts for the sessions. @@ -66,13 +65,13 @@ def compute_counts( # nosec param conditional on command counts """ - seq1_counts: DefaultDict[str, int] = defaultdict(lambda: 0) - seq2_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict( + seq1_counts: defaultdict[str, int] = defaultdict(lambda: 0) + seq2_counts: defaultdict[str, defaultdict[str, int]] = defaultdict( lambda: defaultdict(lambda: 0) ) - param_counts: DefaultDict[str, int] = defaultdict(lambda: 0) - cmd_param_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict( + param_counts: defaultdict[str, int] = defaultdict(lambda: 0) + cmd_param_counts: defaultdict[str, defaultdict[str, int]] = defaultdict( lambda: defaultdict(lambda: 0) ) @@ -93,10 +92,10 @@ def compute_counts( # nosec def laplace_smooth_counts( - seq1_counts: DefaultDict[str, int], - seq2_counts: DefaultDict[str, DefaultDict[str, int]], - param_counts: DefaultDict[str, int], - cmd_param_counts: DefaultDict[str, DefaultDict[str, int]], + seq1_counts: defaultdict[str, int], + seq2_counts: defaultdict[str, defaultdict[str, int]], + param_counts: defaultdict[str, int], + cmd_param_counts: defaultdict[str, defaultdict[str, int]], start_token: str, end_token: str, unk_token: str, @@ -136,7 +135,7 @@ def laplace_smooth_counts( param conditional on command counts """ - cmds: List[str] = list(seq1_counts.keys()) + [unk_token] + cmds: list[str] = list(seq1_counts.keys()) + [unk_token] # apply laplace smoothing for cmds seq1_counts_ls, seq2_counts_ls = laplace_smooth_cmd_counts( @@ -165,8 +164,8 @@ def laplace_smooth_counts( def compute_prob_setofparams_given_cmd( cmd: str, - params: Union[set, dict], - param_cond_cmd_probs: Union[StateMatrix, dict], + params: set | dict, + param_cond_cmd_probs: StateMatrix | dict, use_geo_mean: bool = True, ) -> float: """ @@ -223,10 +222,10 @@ def compute_prob_setofparams_given_cmd( # pylint: disable=too-many-locals, too-many-arguments, too-many-branches def compute_likelihood_window( - window: 
List[Cmd], - prior_probs: Union[StateMatrix, dict], - trans_probs: Union[StateMatrix, dict], - param_cond_cmd_probs: Union[StateMatrix, dict], + window: list[Cmd], + prior_probs: StateMatrix | dict, + trans_probs: StateMatrix | dict, + param_cond_cmd_probs: StateMatrix | dict, use_start_token: bool, use_end_token: bool, start_token: str = None, @@ -268,9 +267,7 @@ def compute_likelihood_window( """ if use_end_token: if end_token is None: - raise MsticpyException( - "end_token should not be None, when use_end_token is True" - ) + raise MsticpyException("end_token should not be None, when use_end_token is True") if use_start_token: if start_token is None: @@ -318,16 +315,16 @@ def compute_likelihood_window( # pylint: disable=too-many-locals, too-many-arguments, too-many-branches def compute_likelihood_windows_in_session( - session: List[Cmd], - prior_probs: Union[StateMatrix, dict], - trans_probs: Union[StateMatrix, dict], - param_cond_cmd_probs: Union[StateMatrix, dict], + session: list[Cmd], + prior_probs: StateMatrix | dict, + trans_probs: StateMatrix | dict, + param_cond_cmd_probs: StateMatrix | dict, window_len: int, use_start_end_tokens: bool, start_token: str = None, end_token: str = None, use_geo_mean: bool = False, -) -> List[float]: +) -> list[float]: """ Compute the likelihoods of a sliding window in the session. @@ -407,7 +404,7 @@ def compute_likelihood_windows_in_session( # pylint: disable=too-many-arguments def rarest_window_session( - session: List[Cmd], + session: list[Cmd], prior_probs: StateMatrix, trans_probs: StateMatrix, param_cond_cmd_probs: StateMatrix, @@ -416,7 +413,7 @@ def rarest_window_session( start_token: str, end_token: str, use_geo_mean=False, -) -> Tuple[List[Cmd], float]: +) -> tuple[list[Cmd], float]: """ Find and compute the likelihood of the rarest window of `window_len` in the session. diff --git a/msticpy/analysis/anomalous_sequence/utils/cmds_params_values.py b/msticpy/analysis/anomalous_sequence/utils/cmds_params_values.py index 6a76d53ed..619b4c373 100644 --- a/msticpy/analysis/anomalous_sequence/utils/cmds_params_values.py +++ b/msticpy/analysis/anomalous_sequence/utils/cmds_params_values.py @@ -13,7 +13,6 @@ import copy from collections import defaultdict -from typing import DefaultDict, List, Tuple, Union import numpy as np @@ -28,14 +27,14 @@ # pylint: disable=too-many-locals, too-many-branches def compute_counts( # noqa MC0001 # nosec - sessions: List[List[Cmd]], start_token: str, end_token: str -) -> Tuple[ - DefaultDict[str, int], - DefaultDict[str, DefaultDict[str, int]], - DefaultDict[str, int], - DefaultDict[str, DefaultDict[str, int]], - DefaultDict[str, int], - DefaultDict[str, DefaultDict[str, int]], + sessions: list[list[Cmd]], start_token: str, end_token: str +) -> tuple[ + defaultdict[str, int], + defaultdict[str, defaultdict[str, int]], + defaultdict[str, int], + defaultdict[str, defaultdict[str, int]], + defaultdict[str, int], + defaultdict[str, defaultdict[str, int]], ]: """ Compute the training counts for the sessions. 
@@ -82,18 +81,18 @@ def compute_counts( # noqa MC0001 # nosec value conditional on param counts """ - seq1_counts: DefaultDict[str, int] = defaultdict(lambda: 0) - seq2_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict( + seq1_counts: defaultdict[str, int] = defaultdict(lambda: 0) + seq2_counts: defaultdict[str, defaultdict[str, int]] = defaultdict( lambda: defaultdict(lambda: 0) ) - param_counts: DefaultDict[str, int] = defaultdict(lambda: 0) - cmd_param_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict( + param_counts: defaultdict[str, int] = defaultdict(lambda: 0) + cmd_param_counts: defaultdict[str, defaultdict[str, int]] = defaultdict( lambda: defaultdict(lambda: 0) ) - value_counts: DefaultDict[str, int] = defaultdict(lambda: 0) - param_value_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict( + value_counts: defaultdict[str, int] = defaultdict(lambda: 0) + param_value_counts: defaultdict[str, defaultdict[str, int]] = defaultdict( lambda: defaultdict(lambda: 0) ) @@ -128,18 +127,16 @@ def compute_counts( # noqa MC0001 # nosec # pylint: disable=too-many-arguments def laplace_smooth_counts( - seq1_counts: DefaultDict[str, int], - seq2_counts: DefaultDict[str, DefaultDict[str, int]], - param_counts: DefaultDict[str, int], - cmd_param_counts: DefaultDict[str, DefaultDict[str, int]], - value_counts: DefaultDict[str, int], - param_value_counts: DefaultDict[str, DefaultDict[str, int]], + seq1_counts: defaultdict[str, int], + seq2_counts: defaultdict[str, defaultdict[str, int]], + param_counts: defaultdict[str, int], + cmd_param_counts: defaultdict[str, defaultdict[str, int]], + value_counts: defaultdict[str, int], + param_value_counts: defaultdict[str, defaultdict[str, int]], start_token: str, end_token: str, unk_token: str, -) -> Tuple[ - StateMatrix, StateMatrix, StateMatrix, StateMatrix, StateMatrix, StateMatrix -]: +) -> tuple[StateMatrix, StateMatrix, StateMatrix, StateMatrix, StateMatrix, StateMatrix]: """ Laplace smoothing is applied to the counts. @@ -181,7 +178,7 @@ def laplace_smooth_counts( value conditional on param counts """ - cmds: List[str] = list(seq1_counts.keys()) + [unk_token] + cmds: list[str] = list(seq1_counts.keys()) + [unk_token] # apply laplace smoothing to the cmds seq1_counts_ls, seq2_counts_ls = laplace_smooth_cmd_counts( @@ -192,7 +189,7 @@ def laplace_smooth_counts( unk_token=unk_token, ) - params: List[str] = list(param_counts.keys()) + [unk_token] + params: list[str] = list(param_counts.keys()) + [unk_token] # apply laplace smoothing to the params param_counts_ls, cmd_param_counts_ls = laplace_smooth_param_counts( @@ -215,9 +212,7 @@ def laplace_smooth_counts( param_counts_sm = StateMatrix(states=param_counts_ls, unk_token=unk_token) cmd_param_counts_sm = StateMatrix(states=cmd_param_counts_ls, unk_token=unk_token) value_counts_sm = StateMatrix(states=value_counts_ls, unk_token=unk_token) - param_value_counts_sm = StateMatrix( - states=param_value_counts_ls, unk_token=unk_token - ) + param_value_counts_sm = StateMatrix(states=param_value_counts_ls, unk_token=unk_token) return ( seq1_counts_sm, @@ -230,7 +225,7 @@ def laplace_smooth_counts( def get_params_to_model_values( - param_counts: Union[StateMatrix, dict], param_value_counts: Union[StateMatrix, dict] + param_counts: StateMatrix | dict, param_value_counts: StateMatrix | dict ) -> set: """ Determine using heuristics which params take categoricals vs arbitrary strings. 
@@ -256,9 +251,7 @@ def get_params_to_model_values( ] modellable_params = [ - param[0] - for param in param_stats - if param[1] <= 20 <= param[2] and param[3] <= 10 + param[0] for param in param_stats if param[1] <= 20 <= param[2] and param[3] <= 10 ] return set(modellable_params) @@ -267,10 +260,10 @@ def get_params_to_model_values( # pylint: disable=too-many-arguments, too-many-branches def compute_prob_setofparams_given_cmd( cmd: str, - params_with_vals: Union[dict, set], - param_cond_cmd_probs: Union[StateMatrix, dict], - value_cond_param_probs: Union[StateMatrix, dict], - modellable_params: Union[set, list], + params_with_vals: dict | set, + param_cond_cmd_probs: StateMatrix | dict, + value_cond_param_probs: StateMatrix | dict, + modellable_params: set | list, use_geo_mean: bool = True, ) -> float: """ @@ -337,11 +330,11 @@ def compute_prob_setofparams_given_cmd( # pylint: disable=too-many-locals, too-many-arguments, too-many-branches def compute_likelihood_window( - window: List[Cmd], - prior_probs: Union[StateMatrix, dict], - trans_probs: Union[StateMatrix, dict], - param_cond_cmd_probs: Union[StateMatrix, dict], - value_cond_param_probs: Union[StateMatrix, dict], + window: list[Cmd], + prior_probs: StateMatrix | dict, + trans_probs: StateMatrix | dict, + param_cond_cmd_probs: StateMatrix | dict, + value_cond_param_probs: StateMatrix | dict, modellable_params: set, use_start_token: bool, use_end_token: bool, @@ -397,9 +390,7 @@ def compute_likelihood_window( ) if use_end_token: if end_token is None: - raise MsticpyException( - "end_token should not be None, when use_end_token is True" - ) + raise MsticpyException("end_token should not be None, when use_end_token is True") w_len = len(window) if w_len == 0: @@ -445,18 +436,18 @@ def compute_likelihood_window( # pylint: disable=too-many-locals, too-many-arguments def compute_likelihood_windows_in_session( - session: List[Cmd], - prior_probs: Union[StateMatrix, dict], - trans_probs: Union[StateMatrix, dict], - param_cond_cmd_probs: Union[StateMatrix, dict], - value_cond_param_probs: Union[StateMatrix, dict], + session: list[Cmd], + prior_probs: StateMatrix | dict, + trans_probs: StateMatrix | dict, + param_cond_cmd_probs: StateMatrix | dict, + value_cond_param_probs: StateMatrix | dict, modellable_params: set, window_len: int, use_start_end_tokens: bool, start_token: str = None, end_token: str = None, use_geo_mean: bool = False, -) -> List[float]: +) -> list[float]: """ Compute the likelihoods of a sliding window of `window_len` in the session. @@ -543,18 +534,18 @@ def compute_likelihood_windows_in_session( # pylint: disable=too-many-arguments def rarest_window_session( - session: List[Cmd], - prior_probs: Union[StateMatrix, dict], - trans_probs: Union[StateMatrix, dict], - param_cond_cmd_probs: Union[StateMatrix, dict], - value_cond_param_probs: Union[StateMatrix, dict], + session: list[Cmd], + prior_probs: StateMatrix | dict, + trans_probs: StateMatrix | dict, + param_cond_cmd_probs: StateMatrix | dict, + value_cond_param_probs: StateMatrix | dict, modellable_params: set, window_len: int, use_start_end_tokens: bool, start_token: str, end_token: str, use_geo_mean: bool = False, -) -> Tuple[List[Cmd], float]: +) -> tuple[list[Cmd], float]: """ Find and compute likelihood of the rarest window of `window_len` in the session. 
diff --git a/msticpy/analysis/anomalous_sequence/utils/data_structures.py b/msticpy/analysis/anomalous_sequence/utils/data_structures.py index a0fbc159e..665c9729b 100644 --- a/msticpy/analysis/anomalous_sequence/utils/data_structures.py +++ b/msticpy/analysis/anomalous_sequence/utils/data_structures.py @@ -6,7 +6,6 @@ """Useful helper data structure classes for modelling sessions.""" from collections import defaultdict -from typing import Union from ....common.exceptions import MsticpyException @@ -14,7 +13,7 @@ class StateMatrix(dict): """Class for storing trained counts/probabilities.""" - def __init__(self, states: Union[dict, defaultdict], unk_token: str): + def __init__(self, states: dict | defaultdict, unk_token: str): """ Container for dict of counts/probs or dict of dicts of cond counts/probs. @@ -76,7 +75,7 @@ def __getitem__(self, item): class Cmd: """Class to store commands with accompanying params (and optionally values).""" - def __init__(self, name: str, params: Union[set, dict]): + def __init__(self, name: str, params: set | dict): """ Instantiate the Cmd class. diff --git a/msticpy/analysis/anomalous_sequence/utils/laplace_smooth.py b/msticpy/analysis/anomalous_sequence/utils/laplace_smooth.py index 4387583b5..b4fc8c471 100644 --- a/msticpy/analysis/anomalous_sequence/utils/laplace_smooth.py +++ b/msticpy/analysis/anomalous_sequence/utils/laplace_smooth.py @@ -6,16 +6,16 @@ """Helper module for laplace smoothing counts.""" import copy -from typing import DefaultDict, List, Tuple +from collections import defaultdict def laplace_smooth_cmd_counts( - seq1_counts: DefaultDict[str, int], - seq2_counts: DefaultDict[str, DefaultDict[str, int]], + seq1_counts: defaultdict[str, int], + seq2_counts: defaultdict[str, defaultdict[str, int]], start_token: str, end_token: str, unk_token: str, -) -> Tuple[DefaultDict[str, int], DefaultDict[str, DefaultDict[str, int]]]: +) -> tuple[defaultdict[str, int], defaultdict[str, defaultdict[str, int]]]: """ Apply laplace smoothing to the input counts for the cmds. @@ -45,7 +45,7 @@ seq1_counts_ls = copy.deepcopy(seq1_counts) seq2_counts_ls = copy.deepcopy(seq2_counts) - cmds: List[str] = list(seq1_counts_ls.keys()) + [unk_token] + cmds: list[str] = list(seq1_counts_ls.keys()) + [unk_token] for cmd1 in cmds: for cmd2 in cmds: if cmd1 != end_token and cmd2 != start_token: @@ -57,11 +57,11 @@ def laplace_smooth_param_counts( - cmds: List[str], - param_counts: DefaultDict[str, int], - cmd_param_counts: DefaultDict[str, DefaultDict[str, int]], + cmds: list[str], + param_counts: defaultdict[str, int], + cmd_param_counts: defaultdict[str, defaultdict[str, int]], unk_token: str, -) -> Tuple[DefaultDict[str, int], DefaultDict[str, DefaultDict[str, int]]]: +) -> tuple[defaultdict[str, int], defaultdict[str, defaultdict[str, int]]]: """ Apply laplace smoothing to the input counts for the params.
@@ -89,7 +89,7 @@ def laplace_smooth_param_counts( param_counts_ls = copy.deepcopy(param_counts) cmd_param_counts_ls = copy.deepcopy(cmd_param_counts) - params: List[str] = list(param_counts.keys()) + [unk_token] + params: list[str] = list(param_counts.keys()) + [unk_token] for cmd in cmds: for param in params: if param in cmd_param_counts_ls[cmd] or param == unk_token: @@ -100,11 +100,11 @@ def laplace_smooth_param_counts( def laplace_smooth_value_counts( - params: List[str], - value_counts: DefaultDict[str, int], - param_value_counts: DefaultDict[str, DefaultDict[str, int]], + params: list[str], + value_counts: defaultdict[str, int], + param_value_counts: defaultdict[str, defaultdict[str, int]], unk_token: str, -) -> Tuple[DefaultDict[str, int], DefaultDict[str, DefaultDict[str, int]]]: +) -> tuple[defaultdict[str, int], defaultdict[str, defaultdict[str, int]]]: """ Apply laplace smoothing to the input counts for the values. @@ -132,7 +132,7 @@ def laplace_smooth_value_counts( value_counts_ls = copy.deepcopy(value_counts) param_value_counts_ls = copy.deepcopy(param_value_counts) - values: List[str] = list(value_counts_ls.keys()) + [unk_token] + values: list[str] = list(value_counts_ls.keys()) + [unk_token] for param in params: for value in values: if value in param_value_counts_ls[param] or value == unk_token: diff --git a/msticpy/analysis/anomalous_sequence/utils/probabilities.py b/msticpy/analysis/anomalous_sequence/utils/probabilities.py index b24c3135a..ac7f9c282 100644 --- a/msticpy/analysis/anomalous_sequence/utils/probabilities.py +++ b/msticpy/analysis/anomalous_sequence/utils/probabilities.py @@ -6,16 +6,15 @@ """Helper module for computing training probabilities when modelling sessions.""" from collections import defaultdict -from typing import DefaultDict, Tuple, Union from ..utils.data_structures import StateMatrix def compute_cmds_probs( # nosec - seq1_counts: Union[StateMatrix, dict], - seq2_counts: Union[StateMatrix, dict], + seq1_counts: StateMatrix | dict, + seq2_counts: StateMatrix | dict, unk_token: str, -) -> Tuple[StateMatrix, StateMatrix]: +) -> tuple[StateMatrix, StateMatrix]: """ Compute command related probabilities. @@ -40,8 +39,8 @@ def compute_cmds_probs( # nosec """ total_cmds = sum(seq1_counts.values()) - prior_probs: DefaultDict[str, float] = defaultdict(lambda: 0) - trans_probs: DefaultDict[str, DefaultDict[str, float]] = defaultdict( + prior_probs: defaultdict[str, float] = defaultdict(lambda: 0) + trans_probs: defaultdict[str, defaultdict[str, float]] = defaultdict( lambda: defaultdict(lambda: 0) ) @@ -52,9 +51,7 @@ def compute_cmds_probs( # nosec # compute trans probs for prev, currents in seq2_counts.items(): for current in currents: - trans_probs[prev][current] = seq2_counts[prev][current] / sum( - seq2_counts[prev].values() - ) + trans_probs[prev][current] = currents[current] / sum(currents.values()) prior_probs_sm = StateMatrix(states=prior_probs, unk_token=unk_token) trans_probs_sm = StateMatrix(states=trans_probs, unk_token=unk_token) @@ -63,11 +60,11 @@ def compute_cmds_probs( # nosec def compute_params_probs( # nosec - param_counts: Union[StateMatrix, dict], - cmd_param_counts: Union[StateMatrix, dict], - seq1_counts: Union[StateMatrix, dict], + param_counts: StateMatrix | dict, + cmd_param_counts: StateMatrix | dict, + seq1_counts: StateMatrix | dict, unk_token: str, -) -> Tuple[StateMatrix, StateMatrix]: +) -> tuple[StateMatrix, StateMatrix]: """ Compute param related probabilities. 
@@ -108,8 +105,8 @@ def compute_params_probs( # nosec param conditional on command probabilities """ - param_probs: DefaultDict[str, float] = defaultdict(lambda: 0) - param_cond_cmd_probs: DefaultDict[str, DefaultDict[str, float]] = defaultdict( + param_probs: defaultdict[str, float] = defaultdict(lambda: 0) + param_cond_cmd_probs: defaultdict[str, defaultdict[str, float]] = defaultdict( lambda: defaultdict(lambda: 0) ) @@ -123,18 +120,16 @@ def compute_params_probs( # nosec param_probs[param] = count / tot_cmd param_probs_sm = StateMatrix(states=param_probs, unk_token=unk_token) - param_cond_cmd_probs_sm = StateMatrix( - states=param_cond_cmd_probs, unk_token=unk_token - ) + param_cond_cmd_probs_sm = StateMatrix(states=param_cond_cmd_probs, unk_token=unk_token) return param_probs_sm, param_cond_cmd_probs_sm def compute_values_probs( # nosec - value_counts: Union[StateMatrix, dict], - param_value_counts: Union[StateMatrix, dict], + value_counts: StateMatrix | dict, + param_value_counts: StateMatrix | dict, unk_token: str, -) -> Tuple[StateMatrix, StateMatrix]: +) -> tuple[StateMatrix, StateMatrix]: """ Compute value related probabilities. @@ -164,8 +159,8 @@ def compute_values_probs( # nosec value conditional on param probabilities """ - value_probs: DefaultDict[str, float] = defaultdict(lambda: 0) - value_cond_param_probs: DefaultDict[str, DefaultDict[str, float]] = defaultdict( + value_probs: defaultdict[str, float] = defaultdict(lambda: 0) + value_cond_param_probs: defaultdict[str, defaultdict[str, float]] = defaultdict( lambda: defaultdict(lambda: 0) ) @@ -179,8 +174,6 @@ def compute_values_probs( # nosec value_probs[value] = count / tot_val value_probs_sm = StateMatrix(states=value_probs, unk_token=unk_token) - value_cond_param_probs_sm = StateMatrix( - states=value_cond_param_probs, unk_token=unk_token - ) + value_cond_param_probs_sm = StateMatrix(states=value_cond_param_probs, unk_token=unk_token) return value_probs_sm, value_cond_param_probs_sm diff --git a/msticpy/analysis/cluster_auditd.py b/msticpy/analysis/cluster_auditd.py index df696359e..847f615f8 100644 --- a/msticpy/analysis/cluster_auditd.py +++ b/msticpy/analysis/cluster_auditd.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Auditd cluster function.""" + import pandas as pd from .._version import VERSION diff --git a/msticpy/analysis/code_cleanup.py b/msticpy/analysis/code_cleanup.py index d3255fa1b..a1e96ab60 100644 --- a/msticpy/analysis/code_cleanup.py +++ b/msticpy/analysis/code_cleanup.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Code cleanup functions to re-format obfuscated code.""" + import re from .._version import VERSION diff --git a/msticpy/analysis/eventcluster.py b/msticpy/analysis/eventcluster.py index c2aeab913..1f207fab5 100644 --- a/msticpy/analysis/eventcluster.py +++ b/msticpy/analysis/eventcluster.py @@ -31,11 +31,12 @@ commandline and process path. 
""" + import re from binascii import crc32 from functools import lru_cache from math import floor, log10 -from typing import Any, List, Tuple, Union +from typing import Any import numpy as np import pandas as pd @@ -65,14 +66,14 @@ @export def dbcluster_events( data: Any, - cluster_columns: List[Any] = None, + cluster_columns: list[Any] = None, verbose: bool = False, normalize: bool = True, time_column: str = "TimeCreatedUtc", max_cluster_distance: float = 0.01, min_cluster_samples: int = 2, **kwargs, -) -> Tuple[pd.DataFrame, DBSCAN, np.ndarray]: +) -> tuple[pd.DataFrame, DBSCAN, np.ndarray]: """ Cluster data set according to cluster_columns features. @@ -126,9 +127,7 @@ def dbcluster_events( ) # Create DBSCAN cluster object - db_cluster = DBSCAN( - eps=max_cluster_distance, min_samples=min_cluster_samples, **kwargs - ) + db_cluster = DBSCAN(eps=max_cluster_distance, min_samples=min_cluster_samples, **kwargs) # Normalize the data (most clustering algorithms don't do well with # unnormalized data) @@ -147,9 +146,7 @@ def dbcluster_events( ) print("Individual cluster sizes: ", ", ".join(str(c) for c in counts)) - clustered_events = _merge_clustered_items( - cluster_set, labels, data, time_column, counts - ) + clustered_events = _merge_clustered_items(cluster_set, labels, data, time_column, counts) if verbose: print("Cluster output rows: ", len(clustered_events)) @@ -160,7 +157,7 @@ def dbcluster_events( def _merge_clustered_items( cluster_set: np.ndarray, labels: np.ndarray, - data: Union[pd.DataFrame, np.ndarray], + data: pd.DataFrame | np.ndarray, time_column: str, counts: np.ndarray, ) -> pd.DataFrame: @@ -320,9 +317,7 @@ def add_process_features( return output_df -def _add_processname_features( - output_df: pd.DataFrame, force: bool, path_separator: str -): +def _add_processname_features(output_df: pd.DataFrame, force: bool, path_separator: str): """ Add process name default features. 
@@ -349,9 +344,7 @@ def _add_processname_features( lambda x: log10(x.pathScore) if x.pathScore else 0, axis=1 ) if "pathHash" not in output_df or force: - output_df["pathHash"] = output_df.apply( - lambda x: crc32_hash(x.NewProcessName), axis=1 - ) + output_df["pathHash"] = output_df.apply(lambda x: crc32_hash(x.NewProcessName), axis=1) def _add_commandline_features(output_df: pd.DataFrame, force: bool): @@ -367,9 +360,7 @@ def _add_commandline_features(output_df: pd.DataFrame, force: bool): """ if "commandlineLen" not in output_df or force: - output_df["commandlineLen"] = output_df.apply( - lambda x: len(x.CommandLine), axis=1 - ) + output_df["commandlineLen"] = output_df.apply(lambda x: len(x.CommandLine), axis=1) if "commandlineLogLen" not in output_df or force: output_df["commandlineLogLen"] = output_df.apply( lambda x: log10(x.commandlineLen) if x.commandlineLen else 0, axis=1 @@ -630,13 +621,13 @@ def crc32_hash_df(data: pd.DataFrame, column: str) -> pd.Series: # pylint: disable=too-many-arguments, too-many-statements -@export # noqa: C901, MC0001 -def plot_cluster( # noqa: C901, MC0001 +@export # noqa: C901 +def plot_cluster( # noqa: C901 db_cluster: DBSCAN, data: pd.DataFrame, x_predict: np.ndarray, plot_label: str = None, - plot_features: Tuple[int, int] = (0, 1), + plot_features: tuple[int, int] = (0, 1), verbose: bool = False, cut_off: int = 3, xlabel: str = None, @@ -701,14 +692,10 @@ def plot_cluster( # noqa: C901, MC0001 # print("Silhouette Coefficient: %0.3f" # % metrics.silhouette_score(x_predict, labels)) - if ( - not isinstance(data, pd.DataFrame) - or plot_label is not None - and plot_label not in data - ): + if not isinstance(data, pd.DataFrame) or plot_label is not None and plot_label not in data: plot_label = None p_label = None - for cluster_id, color in zip(unique_labels, colors): + for cluster_id, color in zip(unique_labels, colors, strict=False): if cluster_id == -1: # Black used for noise. color = [0, 0, 0, 1] diff --git a/msticpy/analysis/observationlist.py b/msticpy/analysis/observationlist.py index d7104d4ee..9729d7478 100644 --- a/msticpy/analysis/observationlist.py +++ b/msticpy/analysis/observationlist.py @@ -4,9 +4,11 @@ # license information. # -------------------------------------------------------------------------- """Observation summary collector.""" + from collections import OrderedDict +from collections.abc import Iterator, Mapping from datetime import datetime -from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple +from typing import Any import attr import pandas as pd @@ -57,20 +59,20 @@ class Observation: caption: str data: Any - description: Optional[str] = None - data_type: Optional[str] = None - link: Optional[str] = None + description: str | None = None + data_type: str | None = None + link: str | None = None score: int = 0 - tags: List[str] = Factory(list) - additional_properties: Dict[str, Any] = Factory(dict) - timestamp: Optional[datetime] = None - time_span: Optional[TimeSpan] = None - time_column: Optional[str] = None - filter: Optional[str] = None - schema: Optional[str] = None + tags: list[str] = Factory(list) + additional_properties: dict[str, Any] = Factory(dict) + timestamp: datetime | None = None + time_span: TimeSpan | None = None + time_column: str | None = None + filter: str | None = None + schema: str | None = None @classmethod - def required_fields(cls) -> List[str]: + def required_fields(cls) -> list[str]: """ Return required fields for Observation instance. 
@@ -83,7 +85,7 @@ def required_fields(cls) -> List[str]: return ["caption", "data"] @classmethod - def all_fields(cls) -> Set[str]: + def all_fields(cls) -> set[str]: """ Return all fields of Observation class. @@ -93,7 +95,7 @@ def all_fields(cls) -> Set[str]: Set of all field names. """ - return {field.name for field in attr.fields(cls)} # type: ignore[misc] + return {field.name for field in attr.fields(cls)} def display(self): """Display the observation.""" @@ -105,7 +107,7 @@ def display(self): if self.link: display(Markdown(f"[Go to details](#{self.link})")) if self.tags: - display(Markdown(f'tags: {", ".join(self.tags)}')) + display(Markdown(f"tags: {', '.join(self.tags)}")) display(self.filtered_data) if self.additional_properties: display(Markdown("### Additional Properties")) @@ -144,7 +146,7 @@ def __init__(self, observationlist: "Observations" = None): (the default is None) """ - self.observation_list: Dict[str, Observation] = OrderedDict() + self.observation_list: dict[str, Observation] = OrderedDict() if observationlist is not None: self.observation_list.update(observationlist.observations) @@ -152,7 +154,7 @@ def __getitem__(self, key: str) -> Observation: """Return the observation with a caption.""" return self.observation_list[key] - def __iter__(self) -> Iterator[Tuple[str, Observation]]: + def __iter__(self) -> Iterator[tuple[str, Observation]]: """Return iterator over observations.""" yield from self.observation_list.items() @@ -208,9 +210,7 @@ def add_observation(self, observation: Observation = None, **kwargs): ) core_fields = { - key: value - for key, value in kwargs.items() - if key in Observation.all_fields() + key: value for key, value in kwargs.items() if key in Observation.all_fields() } new_observation = Observation(**core_fields) addl_fields = { diff --git a/msticpy/analysis/outliers.py b/msticpy/analysis/outliers.py index 00a6996d9..18b56bc4a 100644 --- a/msticpy/analysis/outliers.py +++ b/msticpy/analysis/outliers.py @@ -227,7 +227,7 @@ def _select_train_samples(self, rows: int) -> np.ndarray: rng = np.random.RandomState(42) return rng.choice(rows, n_samples, replace=False) - def fit(self, x: np.ndarray) -> "RobustRandomCutForest": + def fit(self, x: np.ndarray) -> RobustRandomCutForest: """ Build the forest from training data. @@ -322,9 +322,7 @@ def decision_function(self, x: np.ndarray) -> np.ndarray: scores = np.sum(tree_scores, axis=0) / self.num_trees return scores - def _process_tree( - self, tree: rrcf.RCTree, x_sub: np.ndarray, batches: list - ) -> np.ndarray: + def _process_tree(self, tree: rrcf.RCTree, x_sub: np.ndarray, batches: list) -> np.ndarray: """ Process a single tree with batched operations. 
@@ -349,7 +347,7 @@ def _process_tree( temp_indices = np.arange(1000000 + start, 1000000 + end) # Insert batch - for idx, point in zip(temp_indices, batch): + for idx, point in zip(temp_indices, batch, strict=False): tree.insert_point(point, index=idx) # Calculate CoDisp @@ -476,7 +474,7 @@ def identify_outliers_rrcf( # pylint: disable=too-many-arguments, too-many-locals -def plot_outlier_results( +def plot_outlier_results( # noqa: PLR0915 clf: IsolationForest | RobustRandomCutForest, x: np.ndarray, x_predict: np.ndarray, @@ -516,9 +514,7 @@ def plot_outlier_results( np.c_[ xx.ravel(), yy.ravel(), - np.zeros( - (xx.ravel().shape[0], clf.n_features_in_ - len(feature_columns)) - ), + np.zeros((xx.ravel().shape[0], clf.n_features_in_ - len(feature_columns))), ] ) z = z.reshape(xx.shape) @@ -530,9 +526,7 @@ def plot_outlier_results( plt.contourf(xx, yy, z, cmap=plt.cm.Blues_r) # type: ignore b1 = plt.scatter(x[:, 0], x[:, 1], c="white", s=20, edgecolor="k") - b2 = plt.scatter( - x_predict[:, 0], x_predict[:, 1], c="green", s=40, edgecolor="k" - ) + b2 = plt.scatter(x_predict[:, 0], x_predict[:, 1], c="green", s=40, edgecolor="k") c = plt.scatter(x_outliers[:, 0], x_outliers[:, 1], c="red", marker="x", s=200) plt.axis("tight") @@ -584,11 +578,9 @@ def plot_outlier_results( z = z.reshape(xx.shape) # pylint: disable=no-member - axes[i, j].contourf(xx, yy, z, cmap=plt.cm.Blues_r) # type: ignore[index,attr-defined] + axes[i, j].contourf(xx, yy, z, cmap=plt.cm.Blues_r) # type: ignore[attr-defined, index] - b1 = axes[i, j].scatter( # type: ignore[index] - x[:, j], x[:, i], c="white", edgecolor="k" - ) + b1 = axes[i, j].scatter(x[:, j], x[:, i], c="white", edgecolor="k") # type: ignore[index] b2 = axes[i, j].scatter( # type: ignore[index] x_predict[:, j], x_predict[:, i], c="green", edgecolor="k" ) @@ -648,7 +640,7 @@ def remove_common_items(data: pd.DataFrame, columns: list[str]) -> pd.DataFrame: # pylint: disable=cell-var-from-loop for col in columns: filtered_df = filtered_df.filter( - lambda x: (x[col].std() == 0 and x[col].count() > 10) # type: ignore + lambda x, col=col: (x[col].std() == 0 and x[col].count() > 10) ) return filtered_df diff --git a/msticpy/analysis/polling_detection.py b/msticpy/analysis/polling_detection.py index 87cb73dd1..c6842593a 100644 --- a/msticpy/analysis/polling_detection.py +++ b/msticpy/analysis/polling_detection.py @@ -13,6 +13,7 @@ There is currently only one technique available for filtering polling data which is the class PeriodogramPollingDetector. """ + from __future__ import annotations from collections import Counter @@ -34,9 +35,7 @@ _PD_VERSION = Version(pd.__version__) -GROUP_APPLY_PARAMS = ( - {"include_groups": False} if Version("2.2.1") <= _PD_VERSION else {} -) +GROUP_APPLY_PARAMS = {"include_groups": False} if Version("2.2.1") <= _PD_VERSION else {} POWER_SPECTRAL_DENSITY_THRESHOLD: int = 700 diff --git a/msticpy/analysis/syslog_utils.py b/msticpy/analysis/syslog_utils.py index b0d34c60e..fb09e16e3 100644 --- a/msticpy/analysis/syslog_utils.py +++ b/msticpy/analysis/syslog_utils.py @@ -12,8 +12,9 @@ auditd is not available. 
""" + import datetime as dt -from typing import Any, Dict +from typing import Any import ipywidgets as widgets import pandas as pd @@ -58,7 +59,7 @@ def create_host_record( Details of the host data collected """ - host_entity = Host(src_event=syslog_df.iloc[0]) # type: ignore + host_entity = Host(src_event=syslog_df.iloc[0]) # Produce list of processes on the host that are not # part of a 'standard' linux distro _apps = syslog_df["ProcessName"].unique().tolist() @@ -91,16 +92,16 @@ def create_host_record( host_entity.ComputerEnvironment = host_hb["ComputerEnvironment"] # type: ignore host_entity.OmsSolutions = [ # type: ignore sol.strip() for sol in host_hb["Solutions"].split(",") - ] # type: ignore + ] host_entity.Applications = applications # type: ignore host_entity.VMUUID = host_hb["VMUUID"] # type: ignore ip_entity = IpAddress() ip_entity.Address = host_hb["ComputerIP"] geoloc_entity = GeoLocation() - geoloc_entity.CountryOrRegionName = host_hb["RemoteIPCountry"] # type: ignore - geoloc_entity.Longitude = host_hb["RemoteIPLongitude"] # type: ignore - geoloc_entity.Latitude = host_hb["RemoteIPLatitude"] # type: ignore - ip_entity.Location = geoloc_entity # type: ignore + geoloc_entity.CountryOrRegionName = host_hb["RemoteIPCountry"] + geoloc_entity.Longitude = host_hb["RemoteIPLongitude"] + geoloc_entity.Latitude = host_hb["RemoteIPLatitude"] + ip_entity.Location = geoloc_entity host_entity.IPAddress = ip_entity # type: ignore # If Azure network data present add this to host record @@ -150,20 +151,12 @@ def cluster_syslog_logons_df(logon_events: pd.DataFrame) -> pd.DataFrame: ses_closed = 0 # Extract logon session opened and logon session closed data. logons_opened = ( - ( - logon_events[ - logon_events["SyslogMessage"].str.contains("pam_unix.+session opened") - ] - ) + (logon_events[logon_events["SyslogMessage"].str.contains("pam_unix.+session opened")]) .set_index("TimeGenerated") .sort_index(ascending=True) ) logons_closed = ( - ( - logon_events[ - logon_events["SyslogMessage"].str.contains("pam_unix.+session closed") - ] - ) + (logon_events[logon_events["SyslogMessage"].str.contains("pam_unix.+session closed")]) .set_index("TimeGenerated") .sort_index(ascending=True) ) @@ -171,9 +164,7 @@ def cluster_syslog_logons_df(logon_events: pd.DataFrame) -> pd.DataFrame: raise MsticpyException("There are no logon sessions in the supplied data set") # For each session identify the likely start and end times - while ses_opened < len(logons_opened.index) and ses_closed < len( - logons_closed.index - ): + while ses_opened < len(logons_opened.index) and ses_closed < len(logons_closed.index): ses_start = (logons_opened.iloc[ses_opened]).name ses_end = (logons_closed.iloc[ses_closed]).name # If we can identify a user for the session add this to the details @@ -186,7 +177,7 @@ def cluster_syslog_logons_df(logon_events: pd.DataFrame) -> pd.DataFrame: if ses_start <= ses_close_time and ses_opened != 0: ses_opened += 1 continue - if ses_end < ses_start: # type: ignore + if ses_end < ses_start: ses_closed += 1 continue users.append(user) @@ -231,8 +222,8 @@ def risky_sudo_sessions( # Depending on whether we have risky or suspicious acitons or both # identify sessions which these actions occur in - risky_act_sessions: Dict[str, Any] = {} - susp_act_sessions: Dict[str, Any] = {} + risky_act_sessions: dict[str, Any] = {} + susp_act_sessions: dict[str, Any] = {} if risky_actions is not None: risky_act_sessions = _find_risky_sudo_session( risky_actions=risky_actions, sudo_sessions=sessions diff --git 
a/msticpy/analysis/timeseries.py b/msticpy/analysis/timeseries.py index e2b988b7c..b9ef424f9 100644 --- a/msticpy/analysis/timeseries.py +++ b/msticpy/analysis/timeseries.py @@ -4,9 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Module for timeseries analysis functions.""" + import inspect from datetime import datetime -from typing import Dict, List, Optional import pandas as pd @@ -290,9 +290,7 @@ def ts_anomalies_stl(data: pd.DataFrame, **kwargs) -> pd.DataFrame: # this column does not contain seasonal/trend components result["score"] = stats.zscore(result["residual"]) # create spikes(1) and dips(-1) based on threshold and seasonal columns - result.loc[ - (result["score"] > score_threshold) & (result["seasonal"] > 0), "anomalies" - ] = 1 + result.loc[(result["score"] > score_threshold) & (result["seasonal"] > 0), "anomalies"] = 1 result.loc[ (result["score"] > score_threshold) & (result["seasonal"] < 0), "anomalies" ] = -1 @@ -313,7 +311,7 @@ def extract_anomaly_periods( period: str = "1h", pos_only: bool = True, anomalies_column: str = "anomalies", -) -> Dict[datetime, datetime]: +) -> dict[datetime, datetime]: """ Return dictionary of anomaly periods, merging adjacent ones. @@ -360,12 +358,10 @@ def extract_anomaly_periods( if not end_period: # If we're not already in an anomaly period # create start/end for a new one - start_period = time - pd.Timedelta(period) # type: ignore - end_period = time + pd.Timedelta(period) # type: ignore + start_period = time - pd.Timedelta(period) + end_period = time + pd.Timedelta(period) periods[start_period] = end_period - elif (time - end_period) <= pd.Timedelta( - period - ) * 2 and start_period is not None: + elif (time - end_period) <= pd.Timedelta(period) * 2 and start_period is not None: # if the current time is less than 2x the period away # from our current end_period time, update the end_time periods[start_period] = time + pd.Timedelta(period) @@ -383,7 +379,7 @@ def find_anomaly_periods( period: str = "1h", pos_only: bool = True, anomalies_column: str = "anomalies", -) -> List[TimeSpan]: +) -> list[TimeSpan]: """ Return list of anomaly period as TimeSpans. @@ -421,7 +417,7 @@ def find_anomaly_periods( ] -def create_time_period_kqlfilter(periods: Dict[datetime, datetime]) -> str: +def create_time_period_kqlfilter(periods: dict[datetime, datetime]) -> str: """ Return KQL time filter expression from anomaly periods. @@ -448,7 +444,7 @@ def create_time_period_kqlfilter(periods: Dict[datetime, datetime]) -> str: def set_new_anomaly_threshold( data: pd.DataFrame, threshold: float, - threshold_low: Optional[float] = None, + threshold_low: float | None = None, anomalies_column: str = "anomalies", ) -> pd.DataFrame: """ diff --git a/msticpy/auth/azure_auth.py b/msticpy/auth/azure_auth.py index 972c06dc3..0447b625f 100644 --- a/msticpy/auth/azure_auth.py +++ b/msticpy/auth/azure_auth.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Azure authentication handling.""" + from __future__ import annotations import os @@ -115,7 +116,7 @@ def az_connect( ) sub_client = SubscriptionClient( credential=credentials.modern, - base_url=az_cloud_config.resource_manager, # type: ignore + base_url=az_cloud_config.resource_manager, credential_scopes=[az_cloud_config.token_uri], ) if not sub_client: @@ -204,7 +205,7 @@ def fallback_devicecode_creds( title="Azure authentication error", ) - return AzCredentials(legacy_creds, ChainedTokenCredential(creds)) # type: ignore[arg-type] + return AzCredentials(legacy_creds, ChainedTokenCredential(creds)) def get_default_resource_name(resource_uri: str) -> str: diff --git a/msticpy/auth/azure_auth_core.py b/msticpy/auth/azure_auth_core.py index 847e6e3cd..1f48022fe 100644 --- a/msticpy/auth/azure_auth_core.py +++ b/msticpy/auth/azure_auth_core.py @@ -10,10 +10,11 @@ import logging import os import sys +from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass from datetime import datetime from enum import Enum -from typing import Any, Callable, ClassVar, Iterator +from typing import Any, ClassVar from azure.common.credentials import get_cli_profile from azure.core.credentials import TokenCredential @@ -45,9 +46,7 @@ logger: logging.Logger = logging.getLogger(__name__) -_HELP_URI = ( - "https://msticpy.readthedocs.io/en/latest/getting_started/AzureAuthentication.html" -) +_HELP_URI = "https://msticpy.readthedocs.io/en/latest/getting_started/AzureAuthentication.html" @dataclass @@ -71,9 +70,7 @@ def __getitem__(self, item) -> Any: class AzureCredEnvNames: """Enumeration of Azure environment credential names.""" - AZURE_CLIENT_ID: ClassVar[str] = ( - "AZURE_CLIENT_ID" # The app ID for the service principal - ) + AZURE_CLIENT_ID: ClassVar[str] = "AZURE_CLIENT_ID" # The app ID for the service principal AZURE_TENANT_ID: ClassVar[str] = ( "AZURE_TENANT_ID" # The service principal's Azure AD tenant ID ) @@ -89,15 +86,11 @@ class AzureCredEnvNames: AZURE_CLIENT_CERTIFICATE_PATH: ClassVar[str] = "AZURE_CLIENT_CERTIFICATE_PATH" # (Optional) The password protecting the certificate file # (for PFX (PKCS12) certificates). - AZURE_CLIENT_CERTIFICATE_PASSWORD: ClassVar[str] = ( - "AZURE_CLIENT_CERTIFICATE_PASSWORD" # nosec # noqa - ) + AZURE_CLIENT_CERTIFICATE_PASSWORD: ClassVar[str] = "AZURE_CLIENT_CERTIFICATE_PASSWORD" # nosec # noqa # (Optional) Specifies whether an authentication request will include an x5c # header to support subject name / issuer based authentication. # When set to `true` or `1`, authentication requests include the x5c header. 
- AZURE_CLIENT_SEND_CERTIFICATE_CHAIN: ClassVar[str] = ( - "AZURE_CLIENT_SEND_CERTIFICATE_CHAIN" - ) + AZURE_CLIENT_SEND_CERTIFICATE_CHAIN: ClassVar[str] = "AZURE_CLIENT_SEND_CERTIFICATE_CHAIN" # Username and password: AZURE_USERNAME: ClassVar[str] = ( @@ -204,9 +197,7 @@ def _build_msi_client( return cred except ClientAuthenticationError: # If we fail again, just create with no params - logger.info( - "Managed Identity credential failed auth - retrying with no params" - ) + logger.info("Managed Identity credential failed auth - retrying with no params") return ManagedIdentityCredential() @@ -270,15 +261,13 @@ def _build_certificate_client( ) -> CertificateCredential | None: """Build a credential from Certificate.""" if not client_id: - logger.info( - "'certificate' credential requested but client_id param not supplied" - ) + logger.info("'certificate' credential requested but client_id param not supplied") return None return CertificateCredential( authority=aad_uri, tenant_id=tenant_id, # type: ignore client_id=client_id, - **kwargs, # type: ignore + **kwargs, ) @@ -288,27 +277,25 @@ def _build_powershell_client(**kwargs) -> AzurePowerShellCredential: return AzurePowerShellCredential() -_CLIENTS: dict[str, Callable[..., TokenCredential | None]] = dict( - { - "env": _build_env_client, - "cli": _build_cli_client, - "msi": _build_msi_client, - "vscode": _build_vscode_client, - "powershell": _build_powershell_client, - "interactive": _build_interactive_client, - "interactive_browser": _build_interactive_client, - "devicecode": _build_device_code_client, - "device_code": _build_device_code_client, - "device": _build_device_code_client, - "environment": _build_env_client, - "managedidentity": _build_msi_client, - "managed_identity": _build_msi_client, - "clientsecret": _build_client_secret_client, - "client_secret": _build_client_secret_client, - "certificate": _build_certificate_client, - "cert": _build_certificate_client, - } -) +_CLIENTS: dict[str, Callable[..., TokenCredential | None]] = { + "env": _build_env_client, + "cli": _build_cli_client, + "msi": _build_msi_client, + "vscode": _build_vscode_client, + "powershell": _build_powershell_client, + "interactive": _build_interactive_client, + "interactive_browser": _build_interactive_client, + "devicecode": _build_device_code_client, + "device_code": _build_device_code_client, + "device": _build_device_code_client, + "environment": _build_env_client, + "managedidentity": _build_msi_client, + "managed_identity": _build_msi_client, + "clientsecret": _build_client_secret_client, + "client_secret": _build_client_secret_client, + "certificate": _build_certificate_client, + "cert": _build_certificate_client, +} def list_auth_methods() -> list[str]: @@ -412,12 +399,12 @@ def _az_connect_core( wrapped_credentials: CredentialWrapper = CredentialWrapper( chained_credential, resource_id=az_config.token_uri ) - return AzCredentials(wrapped_credentials, chained_credential) # type: ignore[arg-type] + return AzCredentials(wrapped_credentials, chained_credential) # Create the wrapped credential using the passed credential wrapped_credentials = CredentialWrapper(credential, resource_id=az_config.token_uri) return AzCredentials( - wrapped_credentials, # type: ignore[arg-type] + wrapped_credentials, ChainedTokenCredential(credential), # type: ignore[arg-type] ) @@ -514,10 +501,7 @@ def only_interactive_cred(chained_cred: ChainedTokenCredential): def _filter_credential_warning(record) -> bool: """Rewrite out credential not found message.""" - if ( - not 
record.name.startswith("azure.identity") - or record.levelno != logging.WARNING - ): + if not record.name.startswith("azure.identity") or record.levelno != logging.WARNING: return True message = record.getMessage() if ".get_token" in message: @@ -555,16 +539,9 @@ def check_cli_credentials() -> tuple[AzureCliStatus, str | None]: cli_profile = get_cli_profile() raw_token = cli_profile.get_raw_token() bearer_token = None - if ( - isinstance(raw_token, tuple) - and len(raw_token) == 3 - and len(raw_token[0]) == 3 - ): + if isinstance(raw_token, tuple) and len(raw_token) == 3 and len(raw_token[0]) == 3: bearer_token = raw_token[0][2] - if ( - parser.parse(bearer_token.get("expiresOn", datetime.min)) - < datetime.now() - ): + if parser.parse(bearer_token.get("expiresOn", datetime.min)) < datetime.now(): raise ValueError("AADSTS70043: The refresh token has expired") return AzureCliStatus.CLI_OK, "Azure CLI credentials available." diff --git a/msticpy/auth/cloud_mappings.py b/msticpy/auth/cloud_mappings.py index 83ed94119..c6180efed 100644 --- a/msticpy/auth/cloud_mappings.py +++ b/msticpy/auth/cloud_mappings.py @@ -4,9 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Azure Cloud Mappings.""" + import contextlib -from functools import lru_cache -from typing import Any, Dict, List, Optional +from functools import cache +from typing import Any import httpx @@ -61,10 +62,8 @@ def format_endpoint(endpoint: str) -> str: return endpoint if endpoint.endswith("/") else f"{endpoint}/" -@lru_cache(maxsize=None) -def get_cloud_endpoints( - cloud: str, resource_manager_url: Optional[str] = None -) -> Dict[str, Any]: +@cache +def get_cloud_endpoints(cloud: str, resource_manager_url: str | None = None) -> dict[str, Any]: """ Get the cloud endpoints for a specific cloud. @@ -103,7 +102,7 @@ def get_cloud_endpoints( ) -def get_cloud_endpoints_by_cloud(cloud: str) -> Dict[str, Any]: +def get_cloud_endpoints_by_cloud(cloud: str) -> dict[str, Any]: """ Get the cloud endpoints for a specific cloud. @@ -124,7 +123,7 @@ def get_cloud_endpoints_by_cloud(cloud: str) -> Dict[str, Any]: def get_cloud_endpoints_by_resource_manager_url( resource_manager_url: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Get the cloud endpoints for a specific resource manager url. @@ -150,11 +149,7 @@ def get_cloud_endpoints_by_resource_manager_url( except httpx.RequestError: cloud = next( - ( - key - for key, val in CLOUD_MAPPING.items() - if val == f_resource_manager_url - ), + (key for key, val in CLOUD_MAPPING.items() if val == f_resource_manager_url), "global", ) return cloud_mappings_offline[cloud] @@ -171,7 +166,7 @@ def get_azure_config_value(key, default): return default -def default_auth_methods() -> List[str]: +def default_auth_methods() -> list[str]: """Get the default (all) authentication options.""" return get_azure_config_value( "auth_methods", ["env", "msi", "vscode", "cli", "powershell", "devicecode"] @@ -200,9 +195,9 @@ class AzureCloudConfig: def __init__( self, - cloud: Optional[str] = None, - tenant_id: Optional[str] = None, - resource_manager_url: Optional[str] = None, + cloud: str | None = None, + tenant_id: str | None = None, + resource_manager_url: str | None = None, ): """ Initialize AzureCloudConfig from `cloud` or configuration. 
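For context, `functools.cache` (Python 3.9+) is an exact alias for `lru_cache(maxsize=None)`, which is why the decorator swap on `get_cloud_endpoints` above is behavior-preserving. An illustrative standalone sketch, with `fetch_endpoints` as a hypothetical stand-in for the real metadata lookup:

```python
from functools import cache


@cache  # functools.cache (3.9+) is exactly lru_cache(maxsize=None)
def fetch_endpoints(cloud: str) -> dict[str, str]:
    # Hypothetical stand-in for the HTTP metadata lookup the real
    # function performs; all arguments must remain hashable.
    print(f"fetching endpoints for {cloud}")
    return {"cloud": cloud}


fetch_endpoints("global")  # performs the "lookup" and caches the result
fetch_endpoints("global")  # answered from the cache, no second print
```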
@@ -232,14 +227,14 @@ def __init__( self.endpoints = get_cloud_endpoints(self.cloud, self.resource_manager_url) @property - def cloud_names(self) -> List[str]: + def cloud_names(self) -> list[str]: """Return a list of current cloud names.""" return list(CLOUD_MAPPING.keys()) @staticmethod def resolve_cloud_alias( alias, - ) -> Optional[str]: + ) -> str | None: """Return match of cloud alias or name.""" alias_cf = alias.casefold() aliases = {alias.casefold(): cloud for alias, cloud in CLOUD_ALIASES.items()} @@ -248,7 +243,7 @@ def resolve_cloud_alias( return alias_cf if alias_cf in aliases.values() else None @property - def suffixes(self) -> Dict[str, str]: + def suffixes(self) -> dict[str, str]: """ Get CloudSuffixes class for an Azure cloud. @@ -275,9 +270,7 @@ def token_uri(self) -> str: @property def authority_uri(self) -> str: """Return the AAD authority URI.""" - return format_endpoint( - self.endpoints.get("authentication", {}).get("loginEndpoint") - ) + return format_endpoint(self.endpoints.get("authentication", {}).get("loginEndpoint")) @property def log_analytics_uri(self) -> str: diff --git a/msticpy/auth/cred_wrapper.py b/msticpy/auth/cred_wrapper.py index a1284aac7..f2c0c5e19 100644 --- a/msticpy/auth/cred_wrapper.py +++ b/msticpy/auth/cred_wrapper.py @@ -4,7 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Credential wrapper to expose ADAL and MSAL credentials.""" -from typing import Any, Dict + +from typing import Any from azure.core.pipeline import PipelineContext, PipelineRequest from azure.core.pipeline.policies import BearerTokenCredentialPolicy @@ -42,7 +43,7 @@ def __init__( """ super().__init__(None) # type: ignore - self.token: Dict[str, Any] = {} + self.token: dict[str, Any] = {} if credential is None: credential = DefaultAzureCredential() diff --git a/msticpy/auth/keyring_client.py b/msticpy/auth/keyring_client.py index 6f280d2f2..477dae59b 100644 --- a/msticpy/auth/keyring_client.py +++ b/msticpy/auth/keyring_client.py @@ -4,7 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Settings provider for secrets.""" -from typing import Any, Set + +from typing import Any import keyring from keyring.errors import KeyringError, KeyringLocked @@ -34,7 +35,7 @@ def __init__(self, name: str = "key-cache", debug: bool = False): """ self.debug = debug self.keyring = name - self._secret_names: Set[str] = set() + self._secret_names: set[str] = set() def __getitem__(self, key: str): """Get key name.""" diff --git a/msticpy/auth/keyvault_client.py b/msticpy/auth/keyvault_client.py index f0e6ef6a4..d0924bb92 100644 --- a/msticpy/auth/keyvault_client.py +++ b/msticpy/auth/keyvault_client.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Keyvault client - adapted from Bluehound code.""" + from __future__ import annotations import logging @@ -322,10 +323,7 @@ def get_secret(self: Self, secret_name: str) -> str: "Secret was empty in vault %s", self.vault_uri, ) - err_msg = ( - f"Secret name {secret_name} in {self.vault_uri}" - "has blank or null value." - ) + err_msg = f"Secret name {secret_name} in {self.vault_uri} has blank or null value." 
raise MsticpyKeyVaultMissingSecretError( err_msg, title=f"secret {secret_name} empty.", diff --git a/msticpy/auth/keyvault_settings.py b/msticpy/auth/keyvault_settings.py index e00a81318..2018e6f3a 100644 --- a/msticpy/auth/keyvault_settings.py +++ b/msticpy/auth/keyvault_settings.py @@ -5,8 +5,10 @@ # -------------------------------------------------------------------------- """Keyvault client settings.""" +from __future__ import annotations + import warnings -from typing import Any, List, Optional +from typing import Any from .._version import VERSION from ..common import pkg_config as config @@ -61,8 +63,8 @@ def __init__(self): msticpyconfig.yaml. """ - self.authority: Optional[str] = None - self.auth_methods: List[str] = [] + self.authority: str | None = None + self.auth_methods: list[str] = [] try: kv_config = config.get_config("KeyVault") except KeyError as err: @@ -120,27 +122,25 @@ def authority_uri(self) -> str: return self.az_cloud_config.authority_uri @property - def keyvault_uri(self) -> Optional[str]: + def keyvault_uri(self) -> str | None: """Return KeyVault URI template for current cloud.""" suffix = self.az_cloud_config.suffixes.get("keyVaultDns") kv_uri = f"https://{{vault}}.{suffix}" if not kv_uri: mssg = f"Could not find a valid KeyVault endpoint for {self.cloud}" - warnings.warn(mssg) + warnings.warn(mssg, stacklevel=2) return kv_uri @property - def mgmt_uri(self) -> Optional[str]: + def mgmt_uri(self) -> str | None: """Return Azure management URI template for current cloud.""" mgmt_uri = self.az_cloud_config.resource_manager if not mgmt_uri: mssg = f"Could not find a valid KeyVault endpoint for {self.cloud}" - warnings.warn(mssg) + warnings.warn(mssg, stacklevel=2) return mgmt_uri - def get_tenant_authority_uri( - self, authority_uri: str = None, tenant: str = None - ) -> str: + def get_tenant_authority_uri(self, authority_uri: str = None, tenant: str = None) -> str: """ Return authority URI for tenant. @@ -177,9 +177,7 @@ def get_tenant_authority_uri( return f"{auth}{tenant.strip()}" return f"{auth}/{tenant.strip()}" - def get_tenant_authority_host( - self, authority_uri: str = None, tenant: str = None - ) -> str: + def get_tenant_authority_host(self, authority_uri: str = None, tenant: str = None) -> str: """ Return tenant authority URI with no leading scheme. diff --git a/msticpy/auth/msal_auth.py b/msticpy/auth/msal_auth.py index c84f0b979..a49552ddc 100644 --- a/msticpy/auth/msal_auth.py +++ b/msticpy/auth/msal_auth.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """MSAL delegated app authentication class.""" + import json import logging from sys import platform diff --git a/msticpy/auth/secret_settings.py b/msticpy/auth/secret_settings.py index 4ef26b7fa..91a80b9e0 100644 --- a/msticpy/auth/secret_settings.py +++ b/msticpy/auth/secret_settings.py @@ -4,9 +4,11 @@ # license information. 
# -------------------------------------------------------------------------- """Settings provider for secrets.""" + import re +from collections.abc import Callable from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any from .._version import VERSION from ..common import pkg_config as config @@ -34,7 +36,7 @@ def __init__( self, tenant_id: str = None, use_keyring: bool = False, - auth_methods: Optional[List[str]] = None, + auth_methods: list[str] | None = None, credential: Any = None, **kwargs, ): @@ -96,8 +98,8 @@ def __init__( "Please add this to the KeyVault section of msticpyconfig.yaml", title="missing tenant ID value.", ) - self.kv_secret_vault: Dict[str, str] = {} - self.kv_vaults: Dict[str, BHKeyVaultClient] = {} + self.kv_secret_vault: dict[str, str] = {} + self.kv_vaults: dict[str, BHKeyVaultClient] = {} self._use_keyring = ( _KEYRING_INSTALLED and KeyringClient.is_keyring_available() @@ -139,9 +141,7 @@ def format_kv_name(setting_path): """Return normalized name for use as a KeyVault secret name.""" return re.sub("[^0-9a-zA-Z-]", "-", setting_path) - def _get_kv_vault_and_name( - self, setting_path: str - ) -> Tuple[Optional[str], Optional[str]]: + def _get_kv_vault_and_name(self, setting_path: str) -> tuple[str | None, str | None]: """Return the vault and secret name for a config path.""" setting_item = config.get_config(setting_path, None) diff --git a/msticpy/common/azure_auth.py b/msticpy/common/azure_auth.py deleted file mode 100644 index 2683256da..000000000 --- a/msticpy/common/azure_auth.py +++ /dev/null @@ -1,27 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module common.azure_auth has moved. - -See :py:mod:`msticpy.auth` -""" -import warnings - -from .._version import VERSION - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..auth.azure_auth import * # noqa: F401 - -__version__ = VERSION -__author__ = "Pete Bryan" - -WARN_MSSG = ( - "This module has moved to msticpy.auth\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/common/check_version.py b/msticpy/common/check_version.py index 727e7601e..08c6045dd 100644 --- a/msticpy/common/check_version.py +++ b/msticpy/common/check_version.py @@ -4,10 +4,11 @@ # license information. # -------------------------------------------------------------------------- """Check current version against PyPI.""" + from importlib.metadata import version -from packaging.version import Version -from packaging.version import parse as parse_version +from packaging.version import Version # pylint: disable=no-name-in-module +from packaging.version import parse as parse_version # pylint: disable=no-name-in-module from .._version import VERSION diff --git a/msticpy/common/data_types.py b/msticpy/common/data_types.py index 1ae805e25..ab875ef01 100644 --- a/msticpy/common/data_types.py +++ b/msticpy/common/data_types.py @@ -4,7 +4,8 @@ # license information. 
# -------------------------------------------------------------------------- """Object container class.""" -from typing import Any, Dict, Optional, Type + +from typing import Any from .._version import VERSION from ..common.utility import check_kwarg @@ -16,7 +17,7 @@ class ObjectContainer: """Empty class used to create hierarchical attributes.""" - _subclasses: Dict[str, Type] = {} + _subclasses: dict[str, type] = {} def __len__(self): """Return number of items in the attribute collection.""" @@ -35,7 +36,7 @@ def __getattr__(self, name): pass else: return attr - nm_err: Optional[Exception] = None + nm_err: Exception | None = None try: # check for similar-named attributes in __dict__ check_kwarg(name, list(self.__dict__.keys())) @@ -45,9 +46,7 @@ def __getattr__(self, name): raise AttributeError( f"{self.__class__.__name__} object has no attribute {name}" ) from nm_err - raise AttributeError( - f"{self.__class__.__name__} object has no attribute {name}" - ) + raise AttributeError(f"{self.__class__.__name__} object has no attribute {name}") def __repr__(self): """Return list of attributes.""" diff --git a/msticpy/common/data_utils.py b/msticpy/common/data_utils.py index 9f45e5b5e..ca4bb933e 100644 --- a/msticpy/common/data_utils.py +++ b/msticpy/common/data_utils.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- """Data utility functions.""" -from typing import List, Union import pandas as pd @@ -16,7 +15,7 @@ def ensure_df_datetimes( data: pd.DataFrame, - columns: Union[str, List[str], None] = None, + columns: str | list[str] | None = None, add_utc_tz: bool = True, ) -> pd.DataFrame: """ @@ -42,7 +41,7 @@ def ensure_df_datetimes( """ if not columns: - columns = list(data.filter(regex=".*[Tt]ime.*").columns) # type: ignore + columns = list(data.filter(regex=".*[Tt]ime.*").columns) if isinstance(columns, str): columns = [columns] col_map = { @@ -54,9 +53,7 @@ def ensure_df_datetimes( # Look for any TZ-naive columns in the list if add_utc_tz: - localize_cols = { - col for col in columns if col in data.select_dtypes("datetime") - } + localize_cols = {col for col in columns if col in data.select_dtypes("datetime")} for col in localize_cols: converted_data[col] = converted_data[col].dt.tz_localize( "UTC", ambiguous="infer", nonexistent="shift_forward" diff --git a/msticpy/common/exceptions.py b/msticpy/common/exceptions.py index 9d7a293fa..fd9e553f1 100644 --- a/msticpy/common/exceptions.py +++ b/msticpy/common/exceptions.py @@ -4,12 +4,14 @@ # license information. 
# -------------------------------------------------------------------------- """Miscellaneous helper methods for Jupyter Notebooks.""" + from __future__ import annotations import contextlib import sys import traceback -from typing import Any, ClassVar, Generator +from collections.abc import Generator +from typing import Any, ClassVar from IPython.display import display @@ -207,9 +209,7 @@ def _get_exception_text(self) -> str: if isinstance(l_content, tuple): l_content = l_content[0] if l_type == "title": - out_lines.extend( - ("-" * len(l_content), l_content, "-" * len(l_content)) - ) + out_lines.extend(("-" * len(l_content), l_content, "-" * len(l_content))) elif l_type == "uri": if isinstance(l_content, tuple): out_lines.append(f" - {': '.join(l_content)}") @@ -247,9 +247,7 @@ class MsticpyUserConfigError(MsticpyUserError): "https://msticpy.readthedocs.io/en/latest/getting_started/msticpyconfig.html", ) - def __init__( - self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs - ) -> None: + def __init__(self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs) -> None: """ Create generic user configuration exception. @@ -284,9 +282,7 @@ class MsticpyKeyVaultConfigError(MsticpyUserConfigError): "#specifying-secrets-as-key-vault-secrets", ) - def __init__( - self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs - ) -> None: + def __init__(self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs) -> None: """ Create Key Vault configuration exception. @@ -308,9 +304,7 @@ def __init__( class MsticpyKeyVaultMissingSecretError(MsticpyKeyVaultConfigError): """Missing secret exception.""" - def __init__( - self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs - ) -> None: + def __init__(self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs) -> None: """ Create Key Vault missing key exception. @@ -338,9 +332,7 @@ class MsticpyAzureConfigError(MsticpyUserConfigError): + "#instantiating-and-connecting-with-an-azure-data-connector", ) - def __init__( - self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs - ) -> None: + def __init__(self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs) -> None: """ Create Azure data missing configuration exception. @@ -411,9 +403,7 @@ class MsticpyImportExtraError(MsticpyUserError, ImportError): "https://msticpy.readthedocs.io/en/latest/getting_started/Installing.html", ) - def __init__( - self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs - ) -> None: + def __init__(self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs) -> None: """ Create import missing extra exception. @@ -449,9 +439,7 @@ class MsticpyMissingDependencyError(MsticpyUserError, ImportError): "https://msticpy.readthedocs.io/en/latest/getting_started/Installing.html", ) - def __init__( - self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs - ) -> None: + def __init__(self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs) -> None: """ Create import missing extra exception. @@ -501,9 +489,7 @@ class MsticpyParameterError(MsticpyUserError): "https://msticpy.readthedocs.io", ) - def __init__( - self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs - ) -> None: + def __init__(self, *args, help_uri: tuple[str, str] | str | None = None, **kwargs) -> None: """ Create parameter exception. 
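The signature rewrites collapsed above all apply the same modernization used across this changeset: builtin generics (PEP 585) and `|` unions (PEP 604), which the new Python 3.10 floor permits. An illustrative before/after; `load_settings` is a made-up name, not an msticpy function:

```python
# Before: typing aliases, required on Python 3.8
# from typing import Dict, List, Optional
# def load_settings(path: Optional[str] = None) -> Dict[str, List[int]]: ...


# After: builtin generics and "|" unions, native at runtime on 3.10+
def load_settings(path: str | None = None) -> dict[str, list[int]]:
    # Return an empty mapping when no path is given.
    return {} if path is None else {path: []}
```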
diff --git a/msticpy/common/pkg_config.py b/msticpy/common/pkg_config.py index 222bc78c6..f785ddbde 100644 --- a/msticpy/common/pkg_config.py +++ b/msticpy/common/pkg_config.py @@ -12,14 +12,16 @@ a file `msticpyconfig.yaml` in the current directory. """ + import contextlib import numbers import os from collections import UserDict +from collections.abc import Callable from contextlib import AbstractContextManager from importlib.util import find_spec from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any import httpx import yaml @@ -69,11 +71,11 @@ def get(self, key, default=None): _settings = SettingsDict() -def _get_current_config() -> Callable[[Any], Optional[str]]: +def _get_current_config() -> Callable[[Any], str | None]: """Closure for holding path of config file.""" - _current_conf_file: Optional[str] = None + _current_conf_file: str | None = None - def _current_config(file_path: Optional[str] = None) -> Optional[str]: + def _current_config(file_path: str | None = None) -> str | None: nonlocal _current_conf_file # noqa if file_path is not None: _current_conf_file = file_path @@ -85,7 +87,7 @@ def _current_config(file_path: Optional[str] = None) -> Optional[str]: _CURRENT_CONF_FILE = _get_current_config() -def current_config_path() -> Optional[str]: +def current_config_path() -> str | None: """ Return the path of the current config file, if any. @@ -106,7 +108,7 @@ def get_settings(): def refresh_config(): """Re-read the config settings.""" # pylint: disable=global-statement - global _default_settings, _custom_settings, _settings + global _default_settings, _custom_settings, _settings # noqa: PLW0603 _default_settings = _get_default_config() _custom_settings = _get_custom_config() _custom_settings = _create_data_providers(_custom_settings) @@ -125,9 +127,7 @@ def has_config(setting_path: str) -> bool: _DEFAULT_SENTINEL = "@@@NO-DEFAULT-VALUE@@@" -def get_config( - setting_path: Optional[str] = None, default: Any = _DEFAULT_SENTINEL -) -> Any: +def get_config(setting_path: str | None = None, default: Any = _DEFAULT_SENTINEL) -> Any: """ Return setting item for path. @@ -242,7 +242,7 @@ def _del_config(setting_path: str, settings_dict: SettingsDict) -> Any: return current_value -def _read_config_file(config_file: Union[str, Path]) -> SettingsDict: +def _read_config_file(config_file: str | Path) -> SettingsDict: """ Read a yaml config definition file. 
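For context, `_get_current_config` above returns a closure that keeps the current file path in `nonlocal` state rather than in a module-level global (the same PLW0603 concern that `refresh_config` has to `# noqa`). A stripped-down sketch of that pattern, with hypothetical names:

```python
from collections.abc import Callable


def _make_config_holder() -> Callable[..., str | None]:
    # The closure, not a module global, owns the mutable state.
    current: str | None = None

    def _holder(new_path: str | None = None) -> str | None:
        nonlocal current
        if new_path is not None:
            current = new_path
        return current

    return _holder


current_config = _make_config_holder()
current_config("./msticpyconfig.yaml")  # set
assert current_config() == "./msticpyconfig.yaml"  # get
```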
@@ -258,7 +258,7 @@ def _read_config_file(config_file: Union[str, Path]) -> SettingsDict: """ if Path(config_file).is_file(): - with open(config_file, "r", encoding="utf-8") as f_handle: + with open(config_file, encoding="utf-8") as f_handle: # use safe_load instead of load try: return SettingsDict(yaml.safe_load(f_handle)) @@ -272,9 +272,7 @@ def _read_config_file(config_file: Union[str, Path]) -> SettingsDict: return SettingsDict() -def _consolidate_configs( - def_config: SettingsDict, cust_config: SettingsDict -) -> SettingsDict: +def _consolidate_configs(def_config: SettingsDict, cust_config: SettingsDict) -> SettingsDict: resultant_config = SettingsDict() resultant_config.update(def_config) @@ -286,31 +284,31 @@ def _override_config(base_config: SettingsDict, new_config: SettingsDict): for c_key, c_item in new_config.items(): if c_item is None: continue - if isinstance(base_config.get(c_key), (dict, SettingsDict)): - _override_config(base_config[c_key], new_config[c_key]) + if isinstance(base_config.get(c_key), dict | SettingsDict): + _override_config(base_config[c_key], c_item) else: - base_config[c_key] = new_config[c_key] + base_config[c_key] = c_item def _get_default_config(): """Return the package default config file.""" package = "msticpy" try: - from importlib.resources import ( # pylint: disable=import-outside-toplevel + from importlib.resources import ( # pylint: disable=import-outside-toplevel # noqa: PLC0415 as_file, files, ) - package_path: AbstractContextManager = as_file( - files(package).joinpath(_CONFIG_FILE) - ) + package_path: AbstractContextManager = as_file(files(package).joinpath(_CONFIG_FILE)) except ImportError: # If importlib.resources is not available we fall back to # older Python method - from importlib.resources import path # pylint: disable=import-outside-toplevel + from importlib.resources import ( # pylint: disable=import-outside-toplevel # noqa: PLC0415 + path, + ) # pylint: disable=deprecated-method - package_path = path(package, _CONFIG_FILE) # noqa: W4902 + package_path = path(package, _CONFIG_FILE) try: with package_path as config_path: @@ -318,7 +316,7 @@ def _get_default_config(): except ModuleNotFoundError as mod_err: # if all else fails we try to find the package default config somewhere # in the package tree - we use the first one we find - pkg_root: Optional[Path] = _get_pkg_path("msticpy") + pkg_root: Path | None = _get_pkg_path("msticpy") if not pkg_root: raise MsticpyUserConfigError( f"Unable to locate the package default {_CONFIG_FILE}", @@ -355,7 +353,7 @@ def _get_pkg_path(pkg_name): return current_path -def _create_data_providers(mp_config: Dict[str, Any]) -> Dict[str, Any]: +def _create_data_providers(mp_config: dict[str, Any]) -> dict[str, Any]: if mp_config.get(_DP_KEY) is None: mp_config[_DP_KEY] = {} data_providers = mp_config[_DP_KEY] @@ -379,28 +377,26 @@ def _create_data_providers(mp_config: Dict[str, Any]) -> Dict[str, Any]: def get_http_timeout( *, - timeout: Optional[int] = None, - def_timeout: Optional[int] = None, + timeout: int | None = None, + def_timeout: int | None = None, **kwargs, ) -> httpx.Timeout: """Return timeout from settings or overridden in `kwargs`.""" del kwargs - config_timeout: Union[int, Dict, httpx.Timeout, List, Tuple] = get_config( + config_timeout: int | dict | httpx.Timeout | list | tuple = get_config( "msticpy.http_timeout", get_config("http_timeout", None) ) - timeout_params: Union[int, Dict, httpx.Timeout, List[Union[float, None]], Tuple] = ( + timeout_params: int | dict | httpx.Timeout | 
list[float | None] | tuple = ( timeout or def_timeout or config_timeout ) if isinstance(timeout_params, dict): - timeout_params = { - name: _valid_timeout(val) for name, val in timeout_params.items() - } + timeout_params = {name: _valid_timeout(val) for name, val in timeout_params.items()} return httpx.Timeout(**timeout_params) if isinstance(timeout_params, httpx.Timeout): return timeout_params if isinstance(timeout_params, numbers.Real): return httpx.Timeout(_valid_timeout(timeout_params)) - if isinstance(timeout_params, (list, tuple)): + if isinstance(timeout_params, list | tuple): timeout_params = [_valid_timeout(val) for val in timeout_params] if len(timeout_params) >= 2: return httpx.Timeout(timeout=timeout_params[0], connect=timeout_params[1]) @@ -410,8 +406,8 @@ def get_http_timeout( def _valid_timeout( - timeout_val: Optional[Union[float, numbers.Real]] -) -> Union[float, None]: + timeout_val: float | numbers.Real | None, +) -> float | None: """Return float in valid range or None.""" if isinstance(timeout_val, numbers.Real) and float(timeout_val) >= 0.0: return float(timeout_val) @@ -426,7 +422,7 @@ def _valid_timeout( def validate_config( - mp_config: Union[SettingsDict, Dict[str, Any], None] = None, config_file: str = None + mp_config: SettingsDict | dict[str, Any] | None = None, config_file: str = None ): """ Validate msticpy config settings. @@ -445,7 +441,7 @@ def validate_config( if not mp_config and not config_file: mp_config = _settings - if not isinstance(mp_config, (dict, SettingsDict)): + if not isinstance(mp_config, dict | SettingsDict): raise TypeError("Unknown format for configuration settings.") mp_errors, mp_warn = _validate_azure_sentinel(mp_config=mp_config) @@ -477,16 +473,12 @@ def validate_config( def _print_validation_report(mp_errors, mp_warn): if mp_errors: - _print_validation_item( - "\nThe following configuration errors were found:", mp_errors - ) + _print_validation_item("\nThe following configuration errors were found:", mp_errors) else: print("No errors found.") if mp_warn: - _print_validation_item( - "\nThe following configuration warnings were found:", mp_warn - ) + _print_validation_item("\nThe following configuration warnings were found:", mp_warn) else: print("No warnings found.") @@ -511,7 +503,7 @@ def _validate_azure_sentinel(mp_config): mp_errors.append("Missing or empty 'Workspaces' key in 'AzureSentinel' section") return mp_errors, mp_warnings no_default = True - for ws, ws_settings in ws_settings.items(): + for ws, ws_settings in ws_settings.items(): # noqa: B020 if ws == "Default": no_default = False ws_id = ws_settings.get("WorkspaceId") @@ -548,9 +540,7 @@ def _check_provider_settings(mp_config, section, key_provs): _check_required_provider_settings(sec_args, sec_path, p_name, key_provs) ) - mp_errors.extend( - _check_env_vars(args_key=p_setting.get("Args"), section=sec_path) - ) + mp_errors.extend(_check_env_vars(args_key=p_setting.get("Args"), section=sec_path)) return mp_errors, mp_warnings @@ -576,11 +566,7 @@ def _check_required_provider_settings(sec_args, sec_path, p_name, key_provs): ) ) - if ( - p_name == _AZ_CLI - and "clientId" in sec_args - and sec_args["clientId"] is not None - ): + if p_name == _AZ_CLI and "clientId" in sec_args and sec_args["clientId"] is not None: # only warn if partially filled - since these are optional errs.extend( ( diff --git a/msticpy/common/provider_settings.py b/msticpy/common/provider_settings.py index e8382c0bf..2fdd3e2d5 100644 --- a/msticpy/common/provider_settings.py +++ 
b/msticpy/common/provider_settings.py @@ -4,13 +4,15 @@ # license information. # -------------------------------------------------------------------------- """Helper functions for configuration settings.""" + from __future__ import annotations import os import warnings from collections import UserDict +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Callable +from typing import Any from .._version import VERSION from .exceptions import MsticpyImportExtraError @@ -55,7 +57,7 @@ def _secrets_enabled() -> bool: return _SECRETS_ENABLED and _SECRETS_CLIENT -def get_secrets_client_func() -> Callable[..., "SecretsClient" | None]: +def get_secrets_client_func() -> Callable[..., SecretsClient | None]: """ Return function to get or create secrets client. @@ -78,11 +80,11 @@ def get_secrets_client_func() -> Callable[..., "SecretsClient" | None]: replace the SecretsClient instance and return that. """ - _secrets_client: "SecretsClient" | None = None + _secrets_client: SecretsClient | None = None def _return_secrets_client( - secrets_client: "SecretsClient" | None = None, **kwargs - ) -> "SecretsClient" | None: + secrets_client: SecretsClient | None = None, **kwargs + ) -> SecretsClient | None: """Return (optionally setting or creating) a SecretsClient.""" nonlocal _secrets_client if not _SECRETS_ENABLED: @@ -101,7 +103,7 @@ def _return_secrets_client( # pylint: disable=invalid-name _SECRETS_CLIENT: Any = None # Create the secrets client closure -_SET_SECRETS_CLIENT: Callable[..., "SecretsClient" | None] = get_secrets_client_func() +_SET_SECRETS_CLIENT: Callable[..., SecretsClient | None] = get_secrets_client_func() # Create secrets client instance if SecretsClient can be imported # and config has KeyVault settings. if get_config("KeyVault", None) and _SECRETS_ENABLED: @@ -124,13 +126,11 @@ def get_provider_settings(config_section="TIProviders") -> dict[str, ProviderSet """ # pylint: disable=global-statement - global _SECRETS_CLIENT + global _SECRETS_CLIENT # noqa: PLW0603 # pylint: enable=global-statement if get_config("KeyVault", None): if _SECRETS_CLIENT is None and _SECRETS_ENABLED: - print( - "KeyVault enabled. Secrets access may require additional authentication." - ) + print("KeyVault enabled. Secrets access may require additional authentication.") _SECRETS_CLIENT = _SET_SECRETS_CLIENT() else: _SECRETS_CLIENT = None @@ -141,7 +141,7 @@ def get_provider_settings(config_section="TIProviders") -> dict[str, ProviderSet settings = {} for provider, item_settings in section_settings.items(): prov_args = item_settings.get("Args") - prov_settings = ProviderSettings( # type: ignore[call-arg] + prov_settings = ProviderSettings( name=provider, description=item_settings.get("Description"), args=_get_setting_args( @@ -290,7 +290,10 @@ def _get_protected_settings( f"{setting_path}.{arg_name}", arg_value ) except NotImplementedError: - warnings.warn(f"Setting type for setting {arg_value} not yet implemented. ") + warnings.warn( + f"Setting type for setting {arg_value} not yet implemented. 
", + stacklevel=2, + ) return setting_dict @@ -321,7 +324,7 @@ def _fetch_secret_setting( _description_ """ - if isinstance(config_setting, (str, int, float)): + if isinstance(config_setting, str | int | float): return str(config_setting) if not isinstance(config_setting, dict): err_msg: str = ( @@ -336,7 +339,8 @@ def _fetch_secret_setting( warnings.warn( f"Environment variable {config_setting['EnvironmentVar']}" f" ({setting_path})" - " was not set" + " value not found.", + stacklevel=2, ) return env_value if "KeyVault" in config_setting: diff --git a/msticpy/common/proxy_settings.py b/msticpy/common/proxy_settings.py index f46be434e..fcc1cd9eb 100644 --- a/msticpy/common/proxy_settings.py +++ b/msticpy/common/proxy_settings.py @@ -29,13 +29,12 @@ - KeyVault: vault_name/secret_name """ -from typing import Dict, Optional from .pkg_config import get_config from .provider_settings import get_protected_setting -def get_http_proxies() -> Optional[Dict[str, str]]: +def get_http_proxies() -> dict[str, str] | None: """Return proxy settings from config.""" proxy_config = get_config("msticpy.Proxies", None) if not proxy_config: diff --git a/msticpy/common/timespan.py b/msticpy/common/timespan.py index d8b3a828d..c76beb4c2 100644 --- a/msticpy/common/timespan.py +++ b/msticpy/common/timespan.py @@ -5,14 +5,13 @@ # -------------------------------------------------------------------------- """Timespan class.""" - import contextlib from datetime import datetime, timedelta, timezone from numbers import Number -from typing import Any, Optional, Tuple, Union +from typing import Any, Union import pandas as pd -from dateutil.parser import ParserError # type: ignore +from dateutil.parser import ParserError from .._version import VERSION @@ -27,10 +26,10 @@ class TimeSpan: def __init__( self, *args, - timespan: Optional[Union["TimeSpan", Tuple[Any, Any], Any]] = None, - start: Optional[Union[datetime, str]] = None, - end: Optional[Union[datetime, str]] = None, - period: Optional[Union[timedelta, str]] = None, + timespan: Union["TimeSpan", tuple[Any, Any], Any] | None = None, + start: datetime | str | None = None, + end: datetime | str | None = None, + period: timedelta | str | None = None, ): """ Initialize Timespan. @@ -153,7 +152,7 @@ def _process_args(*args, timespan, start, end, period): timespan = args[0] # e.g. a tuple of start, end if len(args) == 2: start = args[0] - if isinstance(args[1], (str, datetime)): + if isinstance(args[1], str | datetime): end = args[1] elif isinstance(args[1], Number): period = args[1] diff --git a/msticpy/common/utility/__init__.py b/msticpy/common/utility/__init__.py index 8b40950c1..9dd699156 100644 --- a/msticpy/common/utility/__init__.py +++ b/msticpy/common/utility/__init__.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Utility sub-package.""" + from ..._version import VERSION from .format import * # noqa: F401, F403 from .package import * # noqa: F401, F403 diff --git a/msticpy/common/utility/format.py b/msticpy/common/utility/format.py index f507dfc64..841e876fc 100644 --- a/msticpy/common/utility/format.py +++ b/msticpy/common/utility/format.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Formatting and checking functions.""" + import builtins import re import uuid diff --git a/msticpy/common/utility/ipython.py b/msticpy/common/utility/ipython.py index 1e73f5502..93cbfa635 100644 --- a/msticpy/common/utility/ipython.py +++ b/msticpy/common/utility/ipython.py @@ -4,10 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Notebook utility functions.""" + # pickle only used here for storing data. import pickle # nosec from base64 import b64encode -from typing import Any, Iterable, Optional, Union +from collections.abc import Iterable +from typing import Any from IPython import get_ipython from IPython.display import HTML, DisplayHandle, display @@ -23,9 +25,9 @@ @export def md( string: str, - styles: Union[str, Iterable[str]] = None, - disp_id: Optional[Union[bool, DisplayHandle]] = None, -) -> Optional[DisplayHandle]: + styles: str | Iterable[str] | None = None, + disp_id: bool | DisplayHandle | None = None, +) -> DisplayHandle | None: """ Display a string as Markdown with optional style. @@ -73,7 +75,7 @@ def md( @export -def md_warn(string: str, disp_id: Optional[DisplayHandle] = None): +def md_warn(string: str, disp_id: DisplayHandle | None = None): """ Return string as a warning - orange text prefixed by "Warning". @@ -99,7 +101,7 @@ def md_warn(string: str, disp_id: Optional[DisplayHandle] = None): @export -def md_error(string: str, disp_id: Optional[DisplayHandle] = None): +def md_error(string: str, disp_id: DisplayHandle | None = None): """ Return string as an error - red text prefixed by "Error". diff --git a/msticpy/common/utility/package.py b/msticpy/common/utility/package.py index 1ca2666aa..1fbb9531b 100644 --- a/msticpy/common/utility/package.py +++ b/msticpy/common/utility/package.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Packaging utility functions.""" + import importlib import os import re @@ -13,7 +14,6 @@ from importlib.metadata import PackageNotFoundError, version from pathlib import Path from platform import python_version -from typing import Dict, List, Optional, Tuple, Union from IPython.core.display import HTML from IPython.core.getipython import get_ipython @@ -60,13 +60,13 @@ def resolve_pkg_path(part_path: str): ) if not searched_paths or len(searched_paths) > 1: - warnings.warn(f"No path or ambiguous match for {part_path} not found") + warnings.warn(f"No path or ambiguous match for {part_path} not found", stacklevel=2) return None return str(searched_paths[0]) @export -def check_py_version(min_ver: Tuple = (3, 6)): +def check_py_version(min_ver: tuple = (3, 6)): """ Check that the current python version is not less than `min_ver`. 
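The `isinstance` rewrites in this patch (for example `isinstance(min_ver, float | str)` below and `isinstance(config_setting, str | int | float)` earlier) rely on Python 3.10 accepting PEP 604 union objects directly in `isinstance` checks, equivalent to the older tuple form. A self-contained sketch:

```python
def stringify(value: object) -> str:
    # Python 3.10+ accepts a union object where a tuple of types
    # was previously required: isinstance(value, (str, int, float)).
    if isinstance(value, str | int | float):
        return str(value)
    raise TypeError(f"unsupported type: {type(value).__name__}")


assert stringify(3.5) == "3.5"
```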
@@ -76,7 +76,7 @@ def check_py_version(min_ver: Tuple = (3, 6)): Minimum required version, by default (3,6) """ - if isinstance(min_ver, (float, str)): + if isinstance(min_ver, float | str): min_ver_list = str(min_ver).split(".") min_ver = (int(min_ver_list[0]), int(min_ver_list[1])) if sys.version_info < min_ver: @@ -86,9 +86,9 @@ def check_py_version(min_ver: Tuple = (3, 6)): # pylint: disable=not-an-iterable, too-many-branches -@export # noqa: MC0001 -def check_and_install_missing_packages( # noqa: MC0001 - required_packages: List[str], +@export +def check_and_install_missing_packages( + required_packages: list[str], force_notebook: bool = False, user: bool = False, upgrade: bool = False, @@ -161,8 +161,7 @@ def check_and_install_missing_packages( # noqa: MC0001 subprocess.run( # nosec pkg_command + [package], check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, ) except subprocess.CalledProcessError as proc_err: print(f"An Error has occurred while installing {package}.") @@ -187,15 +186,13 @@ def _get_mp_ua(): @export -def mp_ua_header() -> Dict[str, str]: +def mp_ua_header() -> dict[str, str]: """Return headers dict for MSTICPy User Agent.""" return {"UserAgent": _get_mp_ua()} @export -def search_for_file( - pattern: str, paths: List[Union[str, Path]] = None -) -> Optional[str]: +def search_for_file(pattern: str, paths: list[str | Path] = None) -> str | None: """Search `paths` for file `pattern`.""" paths = paths or [".", ".."] for start_path in paths: @@ -206,7 +203,7 @@ def search_for_file( @export -def search_module(pattern: str) -> Dict[str, str]: +def search_module(pattern: str) -> dict[str, str]: """ Return MSTICPy modules that match `pattern`. @@ -313,7 +310,7 @@ def set_unit_testing(on: bool = True): os.environ.pop(_U_TEST_ENV, None) -def init_getattr(module_name: str, dynamic_imports: Dict[str, str], attrib: str): +def init_getattr(module_name: str, dynamic_imports: dict[str, str], attrib: str): """Import and return dynamic attribute.""" if attrib in dynamic_imports: module = importlib.import_module(dynamic_imports[attrib]) @@ -321,7 +318,7 @@ def init_getattr(module_name: str, dynamic_imports: Dict[str, str], attrib: str) raise AttributeError(f"{module_name} has no attribute {attrib}") -def init_dir(static_attribs: List[str], dynamic_imports: Dict[str, str]): +def init_dir(static_attribs: list[str], dynamic_imports: dict[str, str]): """Return list of available attributes.""" return sorted(set(static_attribs + list(dynamic_imports))) @@ -336,8 +333,6 @@ def import_item(*args, **kwargs): if attribute is None: imp_module = importlib.import_module(module) attribute = getattr(imp_module, attrib) - return ( - attribute(*args, **kwargs) if (call and callable(attribute)) else attribute - ) + return attribute(*args, **kwargs) if (call and callable(attribute)) else attribute return import_item diff --git a/msticpy/common/utility/types.py b/msticpy/common/utility/types.py index 48915816d..9f79dd1d9 100644 --- a/msticpy/common/utility/types.py +++ b/msticpy/common/utility/types.py @@ -4,15 +4,17 @@ # license information. 
# -------------------------------------------------------------------------- """Utility classes and functions.""" + from __future__ import annotations import difflib import inspect import sys +from collections.abc import Callable, Iterable from enum import Enum from functools import wraps from types import ModuleType -from typing import Any, Callable, Iterable, TypeVar, overload +from typing import Any, TypeVar, overload from typing_extensions import Self @@ -25,22 +27,22 @@ @overload -def export(obj: type[T]) -> type[T]: ... # noqa: E704 +def export(obj: type[T]) -> type[T]: ... @overload -def export(obj: Callable) -> Callable: ... # noqa: E704 +def export(obj: Callable) -> Callable: ... def export(obj: type | Callable) -> type | Callable: """Decorate function or class to export to __all__.""" mod: ModuleType = sys.modules[obj.__module__] if hasattr(mod, "__all__"): - all_list: list[str] = getattr(mod, "__all__") + all_list: list[str] = mod.__all__ all_list.append(obj.__name__) else: all_list = [obj.__name__] - setattr(mod, "__all__", all_list) + mod.__all__ = all_list # type: ignore[attr-defined] return obj diff --git a/msticpy/common/wsconfig.py b/msticpy/common/wsconfig.py index 36cd7b7b6..7e720624a 100644 --- a/msticpy/common/wsconfig.py +++ b/msticpy/common/wsconfig.py @@ -4,12 +4,13 @@ # license information. # -------------------------------------------------------------------------- """Module for Log Analytics-related configuration.""" + import contextlib import json import os import re from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any import ipywidgets as widgets from IPython.display import display @@ -118,10 +119,10 @@ class WorkspaceConfig: def __init__( self, - workspace: Optional[str] = None, - config_file: Optional[str] = None, + workspace: str | None = None, + config_file: str | None = None, interactive: bool = True, - config: Optional[Dict[str, str]] = None, + config: dict[str, str] | None = None, ): """ Load current Azure Notebooks configuration for Log Analytics. @@ -144,11 +145,11 @@ def __init__( Workspace configuration as dictionary. """ - self._config: Dict[str, Any] = {} + self._config: dict[str, Any] = {} self._interactive = interactive self._config_file = config_file self.workspace_key = workspace or "Default" - self.settings_key: Optional[str] = None + self.settings_key: str | None = None # If config file specified, use that if config: @@ -162,9 +163,7 @@ def __getattr__(self, attribute: str): """Return attribute from configuration.""" with contextlib.suppress(KeyError): return self[attribute] - raise AttributeError( - f"{self.__class__.__name__} has no attribute '{attribute}'" - ) + raise AttributeError(f"{self.__class__.__name__} has no attribute '{attribute}'") def __getitem__(self, key: str): """Allow property get using dictionary key syntax.""" @@ -212,7 +211,7 @@ def config_loaded(self) -> bool: """ ws_value = self._config.get(self.CONF_WS_ID, None) ten_value = self._config.get(self.CONF_TENANT_ID, None) - return is_valid_uuid(ws_value) and is_valid_uuid(ten_value) # type: ignore + return is_valid_uuid(ws_value) and is_valid_uuid(ten_value) @property def code_connect_str(self) -> str: @@ -229,8 +228,7 @@ def code_connect_str(self) -> str: ws_id = self[self.CONF_WS_ID] if not ten_id: raise KeyError( - f"Configuration setting for {self.CONF_TENANT_ID} " - + "could not be found." + f"Configuration setting for {self.CONF_TENANT_ID} " + "could not be found." 
) if not ws_id: raise KeyError( @@ -251,27 +249,27 @@ def mp_settings(self): } @property - def args(self) -> Dict[str, str]: + def args(self) -> dict[str, str]: """Return any additional arguments.""" return self._config.get(self.CONF_ARGS, {}) @property - def settings_path(self) -> Optional[str]: + def settings_path(self) -> str | None: """Return the path to the settings in the MSTICPY config.""" if self.settings_key: return f"AzureSentinel.Workspaces.{self.settings_key}" return None @property - def settings(self) -> Dict[str, Any]: + def settings(self) -> dict[str, Any]: """Return the current settings dictionary.""" return get_config(self.settings_path, {}) @classmethod - def from_settings(cls, settings: Dict[str, Any]) -> "WorkspaceConfig": + def from_settings(cls, settings: dict[str, Any]) -> "WorkspaceConfig": """Create a WorkspaceConfig from MSTICPY Workspace settings.""" return cls( - config={ # type: ignore + config={ cls.CONF_WS_NAME: settings.get(cls.CONF_WS_NAME), # type: ignore cls.CONF_SUB_ID: settings.get(cls.CONF_SUB_ID), # type: ignore cls.CONF_WS_ID: settings.get(cls.CONF_WS_ID), # type: ignore @@ -298,9 +296,7 @@ def from_connection_string(cls, connection_str: str) -> "WorkspaceConfig": tenant_id = match.groupdict()["tenant_id"] else: raise ValueError("Could not find tenant ID in connection string.") - if match := re.match( - workspace_regex, connection_str, re.IGNORECASE | re.VERBOSE - ): + if match := re.match(workspace_regex, connection_str, re.IGNORECASE | re.VERBOSE): workspace_id = match.groupdict()["workspace_id"] else: raise ValueError("Could not find workspace ID in connection string.") @@ -308,19 +304,19 @@ def from_connection_string(cls, connection_str: str) -> "WorkspaceConfig": workspace_name = match.groupdict()["workspace_name"] return cls( config={ - cls.CONF_WS_ID: workspace_id, # type: ignore[dict-item] - cls.CONF_TENANT_ID: tenant_id, # type: ignore[dict-item] + cls.CONF_WS_ID: workspace_id, + cls.CONF_TENANT_ID: tenant_id, cls.CONF_WS_NAME: workspace_name, # type: ignore[dict-item] } ) @classmethod - def _read_config_values(cls, file_path: str) -> Dict[str, str]: + def _read_config_values(cls, file_path: str) -> dict[str, str]: """Read configuration file.""" if not file_path: return {} with contextlib.suppress(json.JSONDecodeError): - with open(file_path, "r", encoding="utf-8") as json_file: + with open(file_path, encoding="utf-8") as json_file: if json_file: config_ws = json.load(json_file) return { @@ -331,7 +327,7 @@ def _read_config_values(cls, file_path: str) -> dict: return {} @classmethod - def list_workspaces(cls) -> Dict: + def list_workspaces(cls) -> dict: """ Return list of available workspaces. 
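For context, `from_connection_string` above keeps its `if match := re.match(...)` walrus pattern while collapsing the call onto one line: the assignment expression binds and tests the match in a single step. A simplified sketch; the connection string and regex here are illustrative only, not the patterns the class actually uses:

```python
import re

conn_str = "loganalytics://code().tenant('...').workspace('52b1ab41')"

# := binds the match object and tests it for truthiness in one expression
if match := re.search(r"workspace\('(?P<workspace_id>[^']+)'\)", conn_str):
    workspace_id = match["workspace_id"]
else:
    raise ValueError("Could not find workspace ID in connection string.")

print(workspace_id)  # 52b1ab41
```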
@@ -418,12 +414,12 @@ def _determine_config_source(self, workspace): ) ) - def _read_pkg_config_values(self, workspace_name: Optional[str] = None): + def _read_pkg_config_values(self, workspace_name: str | None = None): """Try to find a usable config from the MSTICPy config file.""" - ws_settings = get_config("AzureSentinel", {}).get("Workspaces") # type: ignore + ws_settings = get_config("AzureSentinel", {}).get("Workspaces") if not ws_settings: return - selected_workspace: Dict[str, str] = {} + selected_workspace: dict[str, str] = {} if workspace_name: selected_workspace, self.settings_key = self._lookup_ws_name_and_id( workspace_name, ws_settings @@ -449,7 +445,7 @@ def _lookup_ws_name_and_id(self, ws_name: str, ws_configs: dict): return ws_config, name return {}, None - def _search_for_file(self, pattern: str) -> Optional[str]: + def _search_for_file(self, pattern: str) -> str | None: config_file = None for start_path in (".", ".."): searched_configs = list(Path(start_path).glob(pattern)) diff --git a/msticpy/config/__init__.py b/msticpy/config/__init__.py index 3390bbea6..ba6c29d6f 100644 --- a/msticpy/config/__init__.py +++ b/msticpy/config/__init__.py @@ -12,6 +12,7 @@ It use the ipywidgets package. """ + from ..lazy_importer import lazy_import _LAZY_IMPORTS = { diff --git a/msticpy/config/ce_azure.py b/msticpy/config/ce_azure.py index 329ce151b..7ac8d4343 100644 --- a/msticpy/config/ce_azure.py +++ b/msticpy/config/ce_azure.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Azure component edit.""" + from .._version import VERSION from .ce_simple_settings import CESimpleSettings diff --git a/msticpy/config/ce_azure_sentinel.py b/msticpy/config/ce_azure_sentinel.py index 313444778..928452a1d 100644 --- a/msticpy/config/ce_azure_sentinel.py +++ b/msticpy/config/ce_azure_sentinel.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Module docstring.""" + from datetime import datetime import ipywidgets as widgets @@ -192,9 +193,7 @@ def _save_item(self, btn): ) self.select_item.options = self._get_select_opts() self.select_item.label = edited_provider_name - valid, status = _validate_ws( - edited_provider_name, self.mp_controls, self._COMP_PATH - ) + valid, status = _validate_ws(edited_provider_name, self.mp_controls, self._COMP_PATH) if not valid: self.set_status(status) @@ -244,9 +243,7 @@ def _resolve_settings(self, btn): workspace_name = _get_named_control(self.edit_ctrls, "WorkspaceName").value resource_group = _get_named_control(self.edit_ctrls, "ResourceGroup").value if not (workspace_id or workspace_name): - self.set_status( - "Need at least WorkspaceId or WorkspaceName to lookup settings." - ) + self.set_status("Need at least WorkspaceId or WorkspaceName to lookup settings.") return if workspace_id: self._update_settings( diff --git a/msticpy/config/ce_common.py b/msticpy/config/ce_common.py index a1768cdb5..1157ffdcc 100644 --- a/msticpy/config/ce_common.py +++ b/msticpy/config/ce_common.py @@ -4,7 +4,8 @@ # license information. 
# -------------------------------------------------------------------------- """Component edit utility functions.""" -from typing import Any, Dict, List, Optional, Tuple, Union + +from typing import Any import httpx import ipywidgets as widgets @@ -60,8 +61,8 @@ def print_debug(*args): # pylint: disable=too-many-return-statements -def py_to_widget( - value: Any, ctrl: Optional[widgets.Widget] = None, val_type: Optional[str] = None +def py_to_widget( # noqa: PLR0911 + value: Any, ctrl: widgets.Widget | None = None, val_type: str | None = None ) -> Any: """ Adjust type and format to suit target widget. @@ -94,11 +95,7 @@ def py_to_widget( """ if ctrl is None and val_type is None: raise ValueError("Must specify either a target control or expected val_type.") - if ( - isinstance(ctrl, widgets.Checkbox) - or val_type == "bool" - or isinstance(value, bool) - ): + if isinstance(ctrl, widgets.Checkbox) or val_type == "bool" or isinstance(value, bool): if isinstance(value, str): return value.casefold() == "true" return bool(value) @@ -117,7 +114,7 @@ def py_to_widget( return value -def widget_to_py(ctrl: Union[widgets.Widget, SettingsControl]) -> Any: +def widget_to_py(ctrl: widgets.Widget | SettingsControl) -> Any: # noqa: PLR0911 """ Adjust type and format of value returned from `ctrl.value`. @@ -172,9 +169,7 @@ def get_subscription_metadata(sub_id: str) -> dict: """ az_cloud_config = AzureCloudConfig() res_mgmt_uri = az_cloud_config.resource_manager - get_sub_url = ( - f"{res_mgmt_uri}/subscriptions/{{subscriptionid}}?api-version=2021-04-01" - ) + get_sub_url = f"{res_mgmt_uri}/subscriptions/{{subscriptionid}}?api-version=2021-04-01" headers = mp_ua_header() sub_url = get_sub_url.format(subscriptionid=sub_id) resp = httpx.get(sub_url, headers=headers) @@ -184,8 +179,7 @@ def get_subscription_metadata(sub_id: str) -> dict: return {} hdr_dict = { - item.split("=")[0]: item.split("=")[1].strip('"') - for item in www_header.split(", ") + item.split("=")[0]: item.split("=")[1].strip('"') for item in www_header.split(", ") } tenant_path = hdr_dict.get("Bearer authorization_uri", "").split("/") @@ -204,7 +198,7 @@ def get_subscription_metadata(sub_id: str) -> dict: return {"tenantId": tenant_id} -def get_def_tenant_id(sub_id: str) -> Optional[str]: +def get_def_tenant_id(sub_id: str) -> str | None: """ Get the tenant ID for a subscription. @@ -227,7 +221,7 @@ def get_def_tenant_id(sub_id: str) -> Optional[str]: return sub_metadata.get("tenantId", None) -def get_managed_tenant_id(sub_id: str) -> Optional[List[str]]: # type: ignore +def get_managed_tenant_id(sub_id: str) -> list[str] | None: """ Get the tenant IDs that are managing a subscription. @@ -247,7 +241,7 @@ def get_managed_tenant_id(sub_id: str) -> Optional[List[str]]: # type: ignore return tenant_ids if tenant_ids else None -def txt_to_dict(txt_val: str) -> Dict[str, Any]: +def txt_to_dict(txt_val: str) -> dict[str, Any]: """ Return dict from string of "key:val; key2:val2" pairs. @@ -270,12 +264,10 @@ def txt_to_dict(txt_val: str) -> Dict[str, Any]: for kv_pair in txt_val.split("\n") if kv_pair.strip() ] - return { - kval[0].strip(): kval[1].strip() if len(kval) > 1 else None for kval in kvpairs - } + return {kval[0].strip(): kval[1].strip() if len(kval) > 1 else None for kval in kvpairs} -def dict_to_txt(dict_val: Union[str, Dict[str, Any]]) -> str: +def dict_to_txt(dict_val: str | dict[str, Any]) -> str: """ Return string as "key:val; key2:val2" pairs from `dict_val`. 
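For readability, a simplified re-implementation of the `txt_to_dict` comprehension reformatted above: one pair per line, a bare key mapping to None. The split-on-first-`:` detail is an assumption inferred from the `kval[0]`/`kval[1]` handling and the docstring, not copied from code shown here:

```python
def txt_to_dict_demo(txt_val: str) -> dict[str, str | None]:
    # Assumed behavior: split each non-blank line on the first ":".
    kvpairs = [
        kv_pair.strip().split(":", maxsplit=1)
        for kv_pair in txt_val.split("\n")
        if kv_pair.strip()
    ]
    return {
        kval[0].strip(): kval[1].strip() if len(kval) > 1 else None
        for kval in kvpairs
    }


assert txt_to_dict_demo("ApiID: 1234\nAuthKey: abcd") == {
    "ApiID": "1234",
    "AuthKey": "abcd",
}
```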
@@ -310,7 +302,7 @@ def get_wgt_ctrl( setting_path: str, var_name: str, mp_controls: "MpConfigControls", # type: ignore - wgt_style: Optional[Dict[str, Any]] = None, + wgt_style: dict[str, Any] | None = None, instance_name: str = None, ) -> widgets.Widget: """ @@ -390,7 +382,7 @@ def get_wgt_ctrl( ctrl = widgets.Textarea( description=var_name, value=dict_to_txt(curr_val) or "", **wgt_style ) - setattr(ctrl, "tag", "txt_dict") + ctrl.tag = "txt_dict" elif st_type == "list": ctrl = widgets.Textarea( description=var_name, @@ -398,7 +390,7 @@ def get_wgt_ctrl( **(wgt_style or TEXT_AREA_LAYOUT), # tooltip="Enter each item as 'key:value'. Separate items with new lines.", ) - setattr(ctrl, "tag", "list") + ctrl.tag = "list" else: ctrl = widgets.Text( description=var_name, @@ -412,7 +404,7 @@ def get_wgt_ctrl( return ctrl -def get_defn_or_default(defn: Union[Tuple[str, Any], Any]) -> Tuple[str, Dict]: +def get_defn_or_default(defn: tuple[str, Any] | Any) -> tuple[str, dict]: """ Return the type and options (or a default) for the setting definition. @@ -437,7 +429,7 @@ def get_defn_or_default(defn: Union[Tuple[str, Any], Any]) -> Tuple[str, Dict]: def get_or_create_mpc_section( mp_controls: "MpConfigControls", # type: ignore[name-defined] section: str, - subkey: Optional[str] = None, # type: ignore + subkey: str | None = None, ) -> Any: """ Return (and create if it doesn't exist) a settings section. diff --git a/msticpy/config/ce_data_providers.py b/msticpy/config/ce_data_providers.py index 775002e87..1d7334b04 100644 --- a/msticpy/config/ce_data_providers.py +++ b/msticpy/config/ce_data_providers.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Data Providers Component Edit.""" + import re -from typing import Optional import ipywidgets as widgets @@ -63,12 +63,14 @@ def __init__(self, mp_controls: MpConfigControls): **TEXT_LAYOUT, ) super().__init__(mp_controls) - self._last_instance_path: Optional[str] = None + self._last_instance_path: str | None = None @property def _current_path(self): if self._form_current_instance_name: - return f"{self._COMP_PATH}.{self._prov_ctrl_name}-{self._form_current_instance_name}" + return ( + f"{self._COMP_PATH}.{self._prov_ctrl_name}-{self._form_current_instance_name}" + ) return f"{self._COMP_PATH}.{self._prov_ctrl_name}" @property @@ -98,13 +100,11 @@ def _form_current_instance_name(self): def _populate_edit_ctrls( self, - control_name: Optional[str] = None, + control_name: str | None = None, new_provider: bool = False, ): """Retrieve and populate form controls for the provider to display.""" - super()._populate_edit_ctrls( - control_name=control_name, new_provider=new_provider - ) + super()._populate_edit_ctrls(control_name=control_name, new_provider=new_provider) # add the instance text box self.edit_ctrls.children = [ self.text_prov_instance, @@ -120,15 +120,11 @@ def _select_provider(self, change): def _save_provider(self, btn): if self._form_current_instance_name: if not re.match(r"^[\w._:]+$", self._form_current_instance_name): - self.set_status( - "Error: instance name can only contain alphanumeric and '._:'" - ) + self.set_status("Error: instance name can only contain alphanumeric and '._:'") return # The instance name may have changed, which alters the path if self._last_instance_path != self._current_path: - self.mp_controls.rename_path( - self._last_instance_path, self._current_path - ) + self.mp_controls.rename_path(self._last_instance_path, self._current_path) 
super()._save_provider(btn) # refresh the item list and re-select the current item edited_provider = self._prov_name diff --git a/msticpy/config/ce_keyvault.py b/msticpy/config/ce_keyvault.py index 9c1721d38..b36a24a1f 100644 --- a/msticpy/config/ce_keyvault.py +++ b/msticpy/config/ce_keyvault.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Key Vault component edit.""" + from .._version import VERSION from .ce_simple_settings import CESimpleSettings diff --git a/msticpy/config/ce_msticpy.py b/msticpy/config/ce_msticpy.py index c620f9899..be02973ee 100644 --- a/msticpy/config/ce_msticpy.py +++ b/msticpy/config/ce_msticpy.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Key Vault component edit.""" + from .._version import VERSION from .ce_simple_settings import CESimpleSettings diff --git a/msticpy/config/ce_other_providers.py b/msticpy/config/ce_other_providers.py index 62990e770..3ce2493ce 100644 --- a/msticpy/config/ce_other_providers.py +++ b/msticpy/config/ce_other_providers.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Other Providers Component Edit.""" + from .._version import VERSION from .ce_provider_base import HELP_URIS, CEProviders @@ -20,8 +21,7 @@ class CEOtherProviders(CEProviders): # _HELP_TEXT inherited from base _HELP_URI = { "GeoIP Providers": ( - "https://msticpy.readthedocs.io/en/latest/" - + "data_acquisition/GeoIPLookups.html" + "https://msticpy.readthedocs.io/en/latest/" + "data_acquisition/GeoIPLookups.html" ), **HELP_URIS, } diff --git a/msticpy/config/ce_provider_base.py b/msticpy/config/ce_provider_base.py index f9fb66dac..800e90cbd 100644 --- a/msticpy/config/ce_provider_base.py +++ b/msticpy/config/ce_provider_base.py @@ -4,8 +4,8 @@ # license information. 
# -------------------------------------------------------------------------- """Module docstring.""" + from abc import ABC -from typing import List, Optional import ipywidgets as widgets @@ -58,8 +58,7 @@ + "msticpyconfig.html#specifying-secrets-as-key-vault-secrets" ), "MSTICPy Configuration": ( - "https://msticpy.readthedocs.io/en/latest/" - + "getting_started/msticpyconfig.html" + "https://msticpy.readthedocs.io/en/latest/" + "getting_started/msticpyconfig.html" ), "Help on this tab": ( "https://msticpy.readthedocs.io/en/latest/getting_started/" @@ -137,7 +136,7 @@ def _get_select_opts(self): def _populate_edit_ctrls( self, - control_name: Optional[str] = None, + control_name: str | None = None, new_provider: bool = False, ): """Retrieve and populate form controls for the provider to display.""" @@ -145,9 +144,7 @@ def _populate_edit_ctrls( prov_name=control_name or self._prov_ctrl_name, mp_controls=self.mp_controls, conf_path=self._COMP_PATH, - prov_instance_name=( - self._select_prov_instance_name if not new_provider else "" - ), + prov_instance_name=(self._select_prov_instance_name if not new_provider else ""), ) self.edit_frame.children = [self.edit_ctrls] @@ -155,9 +152,7 @@ def _select_provider(self, change): """Update based on new selection in current providers.""" del change self._populate_edit_ctrls() - self.mp_controls.populate_ctrl_values( - f"{self._COMP_PATH}.{self.select_item.label}" - ) + self.mp_controls.populate_ctrl_values(f"{self._COMP_PATH}.{self.select_item.label}") def _add_provider(self, btn): """Add a new provider from prov_options.""" @@ -168,12 +163,8 @@ def _add_provider(self, btn): if not self.prov_options.label: self.set_status("Error: please select a provider name to add.") return - self._populate_edit_ctrls( - control_name=self.prov_options.label, new_provider=True - ) - self.mp_controls.save_ctrl_values( - f"{self._COMP_PATH}.{self.prov_options.label}" - ) + self._populate_edit_ctrls(control_name=self.prov_options.label, new_provider=True) + self.mp_controls.save_ctrl_values(f"{self._COMP_PATH}.{self.prov_options.label}") self.select_item.options = self._get_select_opts() self.select_item.label = self.prov_options.label @@ -201,7 +192,7 @@ def _save_provider(self, btn): def _get_prov_ctrls(prov_name, mp_controls, conf_path, prov_instance_name: str = None): - ctrls: List[widgets.Widget] = [] + ctrls: list[widgets.Widget] = [] if not prov_name: return widgets.VBox(ctrls, layout=CompEditDisplayMixin.no_border_layout("95%")) # prov_path = f"{conf_path}.{prov_name}" diff --git a/msticpy/config/ce_simple_settings.py b/msticpy/config/ce_simple_settings.py index cee8e2e67..b74d071b9 100644 --- a/msticpy/config/ce_simple_settings.py +++ b/msticpy/config/ce_simple_settings.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Simple settings component edit base class.""" + import ipywidgets as widgets from .._version import VERSION @@ -23,8 +24,7 @@ class CESimpleSettings(CompEditSimple): _HELP_TEXT = "" _HELP_URI = { "MSTICPy Configuration": ( - "https://msticpy.readthedocs.io/en/latest/" - + "getting_started/msticpyconfig.html" + "https://msticpy.readthedocs.io/en/latest/" + "getting_started/msticpyconfig.html" ) } @@ -33,9 +33,7 @@ def __init__(self, mp_controls: MpConfigControls): super().__init__(description=self._DESCRIPTION) self.mp_controls = mp_controls - self.comp_defn = self._get_settings_path( - mp_controls.config_defn, self._COMP_PATH - ) + self.comp_defn = self._get_settings_path(mp_controls.config_defn, self._COMP_PATH) self.settings = self._get_settings_path(mp_controls.mp_config, self._COMP_PATH) self.help.set_help(self._HELP_TEXT, self._HELP_URI) diff --git a/msticpy/config/ce_ti_providers.py b/msticpy/config/ce_ti_providers.py index 0ff9e8875..f23e22a7d 100644 --- a/msticpy/config/ce_ti_providers.py +++ b/msticpy/config/ce_ti_providers.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """TI Providers Component Edit.""" + from .._version import VERSION from .ce_provider_base import HELP_URIS, CEProviders diff --git a/msticpy/config/ce_user_defaults.py b/msticpy/config/ce_user_defaults.py index dcfd457fd..c46e68c5b 100644 --- a/msticpy/config/ce_user_defaults.py +++ b/msticpy/config/ce_user_defaults.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Module docstring.""" + import ipywidgets as widgets from .._version import VERSION @@ -68,7 +69,7 @@ def __init__(self, mp_controls: MpConfigControls): """ super().__init__(mp_controls) # pylint: disable=import-outside-toplevel - from ..data.core.query_defns import DataEnvironment + from ..data.core.query_defns import DataEnvironment # noqa: PLC0415 self._data_env_enum = DataEnvironment @@ -182,12 +183,11 @@ def _get_settings_ctrls(self, prov_name, conf_path): curr_val = self.mp_controls.get_value(setting_path) if curr_val is None: curr_val = self._get_default_values(prov_name, conf_path) + elif "." in prov_name: + prov, child = prov_name.split(".", maxsplit=1) + curr_val = {prov: {child: curr_val}} else: - if "." in prov_name: - prov, child = prov_name.split(".", maxsplit=1) - curr_val = {prov: {child: curr_val}} - else: - curr_val = {prov_name: curr_val} + curr_val = {prov_name: curr_val} prov_ctrl.value = curr_val @@ -276,9 +276,7 @@ def _get_settings_ctrls(self, prov_name, conf_path): setting_path = f"{conf_path}.{prov_name}" prov_ctrl = self.mp_controls.get_control(setting_path) if not isinstance(prov_ctrl, UserDefLoadComponent): - prov_ctrl = UserDefLoadComponent( - self.mp_controls, prov_name, self._COMP_PATH - ) + prov_ctrl = UserDefLoadComponent(self.mp_controls, prov_name, self._COMP_PATH) self.mp_controls.set_control(setting_path, prov_ctrl) curr_val = self.mp_controls.get_value(setting_path) diff --git a/msticpy/config/comp_edit.py b/msticpy/config/comp_edit.py index 48dbf920e..121d9c8c9 100644 --- a/msticpy/config/comp_edit.py +++ b/msticpy/config/comp_edit.py @@ -4,9 +4,10 @@ # license information. 
# -------------------------------------------------------------------------- """Component Edit base and mixin classes.""" + from abc import ABC, abstractmethod from time import sleep -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any import ipywidgets as widgets from IPython.display import display @@ -82,13 +83,12 @@ class CompEditHelp: _DEFAULT_URI = { "MSTICPy Config": ( - "https://msticpy.readthedocs.io/en/latest/" - + "getting_started/msticpyconfig.html" + "https://msticpy.readthedocs.io/en/latest/" + "getting_started/msticpyconfig.html" ) } _HELP_STYLE = "color: blue; text-decoration: underline;" - def __init__(self, help_text: str = "", help_uri: Dict[str, str] = None): + def __init__(self, help_text: str = "", help_uri: dict[str, str] = None): """ Create help sub-component. @@ -106,7 +106,7 @@ def __init__(self, help_text: str = "", help_uri: Dict[str, str] = None): self.accdn_help.selected_index = None self.set_help(help_text, help_uri) - def set_help(self, help_text: str = "", help_uri: Dict[str, str] = None): + def set_help(self, help_text: str = "", help_uri: dict[str, str] = None): """Set the help string (HTML) and URIs.""" if not help_uri: help_uri = self._DEFAULT_URI @@ -173,9 +173,7 @@ class CompEditItems(CompEditFrame): def __init__(self, description: str): """Initialize the class. Set a label with `description` as content.""" super().__init__(description=description) - self.select_item = widgets.Select( - layout=widgets.Layout(height="200px", width="99%") - ) + self.select_item = widgets.Select(layout=widgets.Layout(height="200px", width="99%")) self.edit_frame = widgets.VBox(layout=self.border_layout("99%")) self.edit_buttons = CompEditItemButtons() self.items_frame = widgets.VBox( @@ -218,8 +216,7 @@ class CEItemsBase(CompEditItems, ABC): _HELP_TEXT = """""" _HELP_URI = { "MSTICPy Configuration": ( - "https://msticpy.readthedocs.io/en/latest/" - + "getting_started/msticpyconfig.html" + "https://msticpy.readthedocs.io/en/latest/" + "getting_started/msticpyconfig.html" ) } @@ -241,21 +238,21 @@ class SettingsControl(ABC): @property @abstractmethod - def value(self) -> Union[str, Dict[str, Optional[str]]]: + def value(self) -> str | dict[str, str | None]: """Return the current value of the control.""" - @value.setter - def value(self, value: Union[str, Dict[str, Optional[str]]]): + @value.setter # noqa: B027 + def value(self, value: str | dict[str, str | None]): """Set value of controls from dict.""" -CETabControlDef = Tuple[type, Union[List[Any], Dict[str, Any]]] +CETabControlDef = tuple[type, list[Any] | dict[str, Any]] class CompEditTabs: """Tab class.""" - def __init__(self, tabs: Optional[Dict[str, CETabControlDef]] = None): + def __init__(self, tabs: dict[str, CETabControlDef] | None = None): """ Initialize the CompEditTabs class. 
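Aside: CETabControlDef pairs a control class with its constructor arguments so that CompEditTabs can defer building a tab until it is first shown. A hedged sketch of that lazy pattern (class and method names below are illustrative, not msticpy's implementation):

from typing import Any

CETabControlDefSketch = tuple[type, list[Any] | dict[str, Any]]


class LazyTabs:
    """Instantiate tab controls on first access, then cache them."""

    def __init__(self, tabs: dict[str, CETabControlDefSketch]):
        self._defs = tabs
        self._built: dict[str, Any] = {}

    def get(self, name: str) -> Any:
        if name not in self._built:
            ctrl_cls, args = self._defs[name]
            # args may be positional (list) or keyword (dict)
            self._built[name] = (
                ctrl_cls(**args) if isinstance(args, dict) else ctrl_cls(*args)
            )
        return self._built[name]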
@@ -270,10 +267,10 @@ def __init__(self, tabs: Optional[Dict[str, CETabControlDef]] = None): self.tab = widgets.Tab() self.layout = self.tab tabs = tabs or {} - self._tab_state: List[widgets.Widget] = [] - self._tab_lazy_load: Dict[int, CETabControlDef] = {} - self._tab_names: List[str] = [] - self.controls: Dict[str, Any] = {} + self._tab_state: list[widgets.Widget] = [] + self._tab_lazy_load: dict[int, CETabControlDef] = {} + self._tab_names: list[str] = [] + self.controls: dict[str, Any] = {} if tabs: for tab_name, tab_ctrl in tabs.items(): if isinstance(tab_ctrl, CEItemsBase): @@ -330,7 +327,7 @@ def _add_lazy_tab(self, tab_name: str, control_def: CETabControlDef): self.tab.children = self._tab_state self.tab.set_title(new_idx, tab_name) - def set_tab(self, tab_name: Optional[str], index: int = 0): + def set_tab(self, tab_name: str | None, index: int = 0): """Programatically set the tab by name or index.""" if tab_name: tab_index = [ @@ -344,11 +341,11 @@ def set_tab(self, tab_name: Optional[str], index: int = 0): self.tab.selected_index = index @property - def tab_names(self) -> List[str]: + def tab_names(self) -> list[str]: """Return a list of current tabs.""" return self._tab_names @property - def tab_controls(self) -> Dict[str, Any]: + def tab_controls(self) -> dict[str, Any]: """Return a list of current tab names and controls.""" return self.controls diff --git a/msticpy/config/compound_ctrls.py b/msticpy/config/compound_ctrls.py index 071a9b751..7e41e802c 100644 --- a/msticpy/config/compound_ctrls.py +++ b/msticpy/config/compound_ctrls.py @@ -4,9 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Compound control classes.""" + import os from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any import ipywidgets as widgets @@ -37,7 +38,7 @@ class ArgControl(SettingsControl, CompEditStatusMixin): def __init__( self, - setting_path: Optional[str], + setting_path: str | None, name: str, store_type: str = STORE_TEXT, item_value: Any = None, @@ -63,9 +64,7 @@ def __init__( self.name = name self.kv_client = None - self.lbl_setting = widgets.Label( - value=self.name, layout=widgets.Layout(width="130px") - ) + self.lbl_setting = widgets.Label(value=self.name, layout=widgets.Layout(width="130px")) self.rb_store_type = widgets.RadioButtons( options=[STORE_TEXT, STORE_ENV_VAR, STORE_KEYVAULT], description="Storage:", @@ -108,7 +107,7 @@ def __init__( self.rb_store_type.observe(self._change_store, names="value") @property - def value(self) -> Union[str, Dict[str, Optional[str]]]: + def value(self) -> str | dict[str, str | None]: """ Return the value of the control. @@ -125,7 +124,7 @@ def value(self) -> Union[str, Dict[str, Optional[str]]]: return {self.rb_store_type.value: widget_to_py(self.txt_val)} @value.setter - def value(self, value: Union[str, Dict[str, Optional[str]]]): + def value(self, value: str | dict[str, str | None]): """ Set control to value. @@ -218,7 +217,7 @@ def _set_kv_secret(self, btn): def _get_args_val(arg_setting): """Return a dict whether the value is a str or a dict.""" _, arg_val = next(iter(arg_setting.items())) - if isinstance(arg_val, (str, int, bool)): + if isinstance(arg_val, str | int | bool): return STORE_TEXT, arg_val return next(iter(arg_val.items())) @@ -254,7 +253,7 @@ def _set_kv_secret_value( item_name: str, value: str, kv_client: Any = None, - ) -> Tuple[bool, str, Any]: + ) -> tuple[bool, str, Any]: """ Set the Key Vault secret to `value`. 
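Aside: as _get_args_val above suggests, an Args entry is either a literal str/int/bool or a one-key dict naming the storage backend (Text, EnvironmentVar, KeyVault). A sketch of resolving the environment-variable case — resolve_arg is a hypothetical helper, not msticpy API:

import os
from typing import Any

STORE_ENV_VAR = "EnvironmentVar"


def resolve_arg(arg_val: Any) -> str | None:
    """Return the literal value, or fetch it from the named backend."""
    # PEP 604 unions are valid in isinstance() on Python 3.10+
    if isinstance(arg_val, str | int | bool):
        return str(arg_val)
    store, value = next(iter(arg_val.items()))
    if store == STORE_ENV_VAR:
        return os.environ.get(value)
    return None  # Key Vault retrieval omitted from this sketch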
@@ -295,7 +294,7 @@ def _set_kv_secret_value( item_name: str, value: str, kv_client: Any = None, - ) -> Tuple[bool, str, Any]: + ) -> tuple[bool, str, Any]: """Return empty response function if Key Vault cannot be initialized.""" del setting_path, item_name, value, kv_client return False, "Azure keyvault libraries are not installed", None @@ -344,7 +343,7 @@ def _set_prov_name(self, prov_name): ) @property - def value(self) -> Union[str, Dict[str, Optional[str]]]: + def value(self) -> str | dict[str, str | None]: """ Return the current value of the control. @@ -357,13 +356,11 @@ def value(self) -> Union[str, Dict[str, Optional[str]]]: """ alias = {"alias": self.txt_alias.value} if self.txt_alias.value else {} - connect = ( - {"connect": self.cb_connect.value} if not self.cb_connect.value else {} - ) + connect = {"connect": self.cb_connect.value} if not self.cb_connect.value else {} return {**alias, **connect} @value.setter - def value(self, value: Union[str, Dict[str, Optional[str]]]): + def value(self, value: str | dict[str, str | None]): """ Set the value of the component from settings. @@ -385,9 +382,7 @@ class UserDefLoadComponent(SettingsControl): _W_STYLE = {"description_width": "100px"} # pylint: disable=line-too-long - def __init__( - self, mp_controls: MpConfigControls, comp_name: str, setting_path: str - ): + def __init__(self, mp_controls: MpConfigControls, comp_name: str, setting_path: str): """ Initialize the control. @@ -435,15 +430,11 @@ def _create_controls(self, path, mp_controls): ctrl_path = f"{path}.{name}" if isinstance(settings, str): # Simple case of a string value - self.controls[name] = widgets.Text( - description="Value", value=curr_value or "" - ) + self.controls[name] = widgets.Text(description="Value", value=curr_value or "") self._add_control_to_map(ctrl_path, self.controls[name]) if isinstance(settings, tuple): # if tuple then the second elem of the tuple is the type defn - self.controls[name] = self._create_select_ctrl( - settings, name, curr_value - ) + self.controls[name] = self._create_select_ctrl(settings, name, curr_value) self._add_control_to_map(ctrl_path, self.controls[name]) elif isinstance(settings, dict): self.controls[name] = widgets.Text(value=name, disabled=True) @@ -460,7 +451,7 @@ def _add_control_to_map(self, path, ctrl): ctrl_map = ctrl_map.get(elem) @property - def value(self) -> Union[str, Dict[str, Optional[str]]]: + def value(self) -> str | dict[str, str | None]: """ Return the current value of the control. 
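Aside: UserDefLoadComponent keeps a nested control map that mirrors the settings tree, with widgets at the leaves; reading `value` collapses it back to plain data. A simplified sketch, where leaf detection via a `value` attribute is an assumption of this sketch:

from typing import Any


class FakeWidget:
    """Stand-in for an ipywidgets control exposing a .value attribute."""

    def __init__(self, value: Any):
        self.value = value


def ctrl_tree_to_py(node: Any) -> Any:
    """Recursively collapse a nested control map to plain Python values."""
    if isinstance(node, dict):
        return {key: ctrl_tree_to_py(child) for key, child in node.items()}
    if hasattr(node, "value"):  # widget-like leaf
        return node.value
    return node


tree = {"LoadComponents": {"TILookup": FakeWidget("connect: False")}}
assert ctrl_tree_to_py(tree) == {"LoadComponents": {"TILookup": "connect: False"}}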
@@ -473,7 +464,7 @@ def value(self) -> Union[str, Dict[str, Optional[str]]]: return self._get_val_from_ctrl(self.control_map) @value.setter - def value(self, value: Union[str, Dict[str, Optional[str]]]): + def value(self, value: str | dict[str, str | None]): """Set value of controls from dict.""" if isinstance(value, dict): self._set_ctrl_from_val(path="", value=value) @@ -501,9 +492,7 @@ def _set_ctrl_from_val(self, path, value): sub_path = f"{path}.{key}" if path else key if isinstance(val, dict): if isinstance(self.controls[key], widgets.Textarea): - self.controls[key].value = py_to_widget( - val, ctrl=self.controls[key] - ) + self.controls[key].value = py_to_widget(val, ctrl=self.controls[key]) else: self._set_ctrl_from_val(sub_path, val) elif key in self.controls: @@ -545,7 +534,7 @@ def _create_select_ctrl(self, ctrl_defn, name, curr_value): description=name, style=self._W_STYLE, ) - setattr(wgt, "tag", "txt_dict") + wgt.tag = "txt_dict" wgt.value = py_to_widget(curr_value, ctrl=wgt) or "" return wgt raise TypeError(f"Unknown definition type {val_type} for {name}/{ctrl_defn}") diff --git a/msticpy/config/file_browser.py b/msticpy/config/file_browser.py index d3e2c1cc0..ae7092a6a 100644 --- a/msticpy/config/file_browser.py +++ b/msticpy/config/file_browser.py @@ -6,8 +6,9 @@ """File Browser class.""" import contextlib +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any import ipywidgets as widgets @@ -18,7 +19,7 @@ __author__ = "Ian Hellen" -StrOrPath = Union[str, Path] +StrOrPath = str | Path # pylint: disable=too-many-instance-attributes @@ -42,7 +43,7 @@ def __init__(self, path: StrOrPath = ".", select_cb: Callable[[str], Any] = None """ self.current_folder = Path(path).resolve() - self.file: Optional[str] = None + self.file: str | None = None self.action = select_cb file_layout = widgets.Layout(height="200px", width="45%") @@ -119,9 +120,7 @@ def _open_folder(self, btn=None, tgt_folder=None): if tgt_folder == self.PARENT: tgt_folder = self.current_folder.parent if tgt_folder: - self.current_folder = ( - Path(self.current_folder).joinpath(tgt_folder).resolve() - ) + self.current_folder = Path(self.current_folder).joinpath(tgt_folder).resolve() self.txt_path.value = str(self.current_folder) folders, files = self.read_folder(self.current_folder) self.select_folder.options = self.get_folder_list(folders) @@ -143,7 +142,7 @@ def _return_file(self, btn): self.action(self.file) @staticmethod - def read_folder(folder: StrOrPath) -> Tuple[List[StrOrPath], List[StrOrPath]]: + def read_folder(folder: StrOrPath) -> tuple[list[StrOrPath], list[StrOrPath]]: """ Return folder contents. @@ -169,7 +168,7 @@ def read_folder(folder: StrOrPath) -> Tuple[List[StrOrPath], List[StrOrPath]]: folders.append(file) return folders, files # type: ignore[return-value] - def get_folder_list(self, folders: List[StrOrPath]) -> List[StrOrPath]: + def get_folder_list(self, folders: list[StrOrPath]) -> list[StrOrPath]: """Return sorted list of folders with '..' 
inserted if not root.""" if self.current_folder != Path(self.current_folder.parts[0]): return [self.PARENT, *(sorted(folders))] @@ -179,10 +178,8 @@ def _search(self, btn): """Handle event for search button.""" del btn if self.txt_search.value: - found_files: Optional[List[Path]] = None + found_files: list[Path] | None = None while found_files is None: with contextlib.suppress(FileNotFoundError): found_files = list(self.current_folder.rglob(self.txt_search.value)) - self.select_search.options = [ - str(file) for file in found_files if file.exists() - ] + self.select_search.options = [str(file) for file in found_files if file.exists()] diff --git a/msticpy/config/mp_config_control.py b/msticpy/config/mp_config_control.py index 28bfd391f..243c8a2b8 100644 --- a/msticpy/config/mp_config_control.py +++ b/msticpy/config/mp_config_control.py @@ -4,10 +4,11 @@ # license information. # -------------------------------------------------------------------------- """MP Config Control Class.""" + import pkgutil import re from collections import namedtuple -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any import ipywidgets as widgets import yaml @@ -33,7 +34,7 @@ class MpConfigControls: """Msticpy configuration and settings database.""" - def __init__(self, mp_config_def: Dict[str, Any], mp_config: Dict[str, Any]): + def __init__(self, mp_config_def: dict[str, Any], mp_config: dict[str, Any]): """ Return an instance of MpConfigControls. @@ -56,7 +57,7 @@ def __init__(self, mp_config_def: Dict[str, Any], mp_config: Dict[str, Any]): self.mp_config["DataProviders"]["AzureCLI"] = self.mp_config["AzureCLI"] @staticmethod - def _get_elem_from_path(path, member_dict: Dict[str, Any]): + def _get_elem_from_path(path, member_dict: dict[str, Any]): """Return an item at the path from `member_dict`.""" paths = path.split(".") current_elem = member_dict @@ -68,7 +69,7 @@ def _get_elem_from_path(path, member_dict: Dict[str, Any]): return current_elem def _set_elem_at_path( - self, path: str, member_dict: Dict[str, Any], value: Any, create: bool = True + self, path: str, member_dict: dict[str, Any], value: Any, create: bool = True ): """Set item at the path from `member_dict` to `value`.""" path_elems = path.rsplit(".", maxsplit=1) @@ -95,7 +96,7 @@ def _set_elem_at_path( current_elem[tgt_key] = value print_debug("set", parent_path, tgt_key, value) - def _del_elem_at_path(self, path: str, member_dict: Dict[str, Any]): + def _del_elem_at_path(self, path: str, member_dict: dict[str, Any]): """Delete an item at `path`.""" parent_path, tgt_key = path.rsplit(".", maxsplit=1) parent = self._get_elem_from_path(parent_path, member_dict) @@ -129,15 +130,13 @@ def del_control(self, path: str): """Delete the control stored at `path`.""" self._del_elem_at_path(path, self.controls) - def get_defn(self, path: str) -> Union[Dict[str, Any], Tuple[str, Any]]: + def get_defn(self, path: str) -> dict[str, Any] | tuple[str, Any]: """Return the setting definition at `path`.""" defn = self._get_elem_from_path(path, self.config_defn) if defn is not None: return defn if path.startswith("AzureSentinel.Workspaces"): - path = re.sub( - r"(?PAzureSentinel\.Workspaces\.)([^.]+)", r"\1Default", path - ) + path = re.sub(r"(?PAzureSentinel\.Workspaces\.)([^.]+)", r"\1Default", path) return self._get_elem_from_path(path, self.config_defn) def rename_path(self, old_path: str, new_path: str): @@ -149,9 +148,7 @@ def rename_path(self, old_path: str, new_path: str): or len(old_path_elems) == 1 or len(new_path_elems) == 1 
): - raise ValueError( - "Can only rename the bottom element of paths", old_path, new_path - ) + raise ValueError("Can only rename the bottom element of paths", old_path, new_path) path_root = old_path_elems[0] src_key = old_path_elems[1] tgt_key = new_path_elems[1] @@ -199,9 +196,9 @@ def _get_ctrl_values(self, path: str): print_debug( type(ctrl_tree), "instance check", - isinstance(ctrl_tree, (widgets.Widget, SettingsControl)), + isinstance(ctrl_tree, widgets.Widget | SettingsControl), ) - if isinstance(ctrl_tree, (widgets.Widget, SettingsControl)): + if isinstance(ctrl_tree, widgets.Widget | SettingsControl): return widget_to_py(ctrl_tree) if isinstance(ctrl_tree, dict): return {key: self._get_ctrl_values(f"{path}.{key}") for key in ctrl_tree} @@ -221,7 +218,7 @@ def _create_ctrl_dict(self, config_dict): ctrl_dict[name] = None return ctrl_dict - def validate_all_settings(self, show_all: bool = False) -> List[ValidationResult]: + def validate_all_settings(self, show_all: bool = False) -> list[ValidationResult]: """ Validate settings against definitions. @@ -244,8 +241,8 @@ def validate_all_settings(self, show_all: bool = False) -> List[ValidationResult return results def validate_setting( - self, path: str, defn_path: Optional[str] = None, show_all: bool = False - ) -> List[ValidationResult]: + self, path: str, defn_path: str | None = None, show_all: bool = False + ) -> list[ValidationResult]: """ Validate settings against definitions for a specific path. @@ -276,7 +273,7 @@ def validate_setting( return [res for res in up_results if not res[0] or show_all] return [ValidationResult(True, "No validation results found")] - def _unpack_lists(self, res_list: List[Any]) -> List[ValidationResult]: + def _unpack_lists(self, res_list: list[Any]) -> list[ValidationResult]: """Unpack nested lists into a single list.""" results = [] for item in res_list: @@ -287,9 +284,9 @@ def _unpack_lists(self, res_list: List[Any]) -> List[ValidationResult]: return results # pylint: disable=too-many-return-statements - def _validate_setting_at_path( - self, path: str, defn_path: Optional[str] = None, index: Optional[int] = None - ) -> Union[ValidationResult, List[Union[ValidationResult, List[Any]]]]: + def _validate_setting_at_path( # noqa: PLR0911 + self, path: str, defn_path: str | None = None, index: int | None = None + ) -> ValidationResult | list[ValidationResult | list[Any]]: """Recursively validate settings at path.""" defn_path = defn_path or path conf_defn = self.get_defn(defn_path) @@ -354,9 +351,7 @@ def _yml_extract_type(self, conf_val): """Extract type and options from definition.""" if not conf_val or "(" not in conf_val or ")" not in conf_val: return "unknown", {} - val_type_match = re.match( - r"(?P[^()]+)\((?P.*)\)$", conf_val.strip() - ) + val_type_match = re.match(r"(?P[^()]+)\((?P.*)\)$", conf_val.strip()) val_type = val_type_match.groupdict().get("type") val_param_str = val_type_match.groupdict().get("params", "") @@ -374,8 +369,7 @@ def _yml_extract_type(self, conf_val): val_params = {} if "options" in val_params: val_params["options"] = [ - val.strip("'\"") - for val in val_params["options"].strip()[1:-1].split("; ") + val.strip("'\"") for val in val_params["options"].strip()[1:-1].split("; ") ] if "mp_defn_path" in val_params: defn_path = val_params.pop("mp_defn_path").strip(" /\"'").replace("/", ".") @@ -428,7 +422,7 @@ def _convert_mp_config_list(self, mp_conf_list): return out_list -def get_mpconfig_definitions() -> Dict[str, Any]: +def get_mpconfig_definitions() -> dict[str, Any]: """ 
Return the current msticpyconfig definition dictionary. @@ -499,7 +493,7 @@ def _validate_m_enum(value, path, val_type, val_opts): mssg = _get_mssg(value, path) if _is_none_and_not_required(value, val_type, val_opts): return ValidationResult(True, f"{_VALID_SUCCESS} {mssg}") - if not isinstance(value, (str, list)): + if not isinstance(value, str | list): return ValidationResult( False, f"Value '{value}' of type {type(value)} should be type {val_type} - {mssg}", @@ -537,7 +531,7 @@ def _validate_txt_dict(value, path, val_type, val_opts): False, f"Key {d_key} of {value} must be a string - {mssg}", ) - if not isinstance(d_val, (str, int, bool)): + if not isinstance(d_val, str | int | bool): return ValidationResult( False, f"Value {d_val} of key {d_key} in {value} must be a" diff --git a/msticpy/config/mp_config_edit.py b/msticpy/config/mp_config_edit.py index 3fd6e01a9..96ed1d6cd 100644 --- a/msticpy/config/mp_config_edit.py +++ b/msticpy/config/mp_config_edit.py @@ -4,7 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Module docstring.""" -from typing import Any, Dict, Optional, Union, cast + +from typing import Any, cast import ipywidgets as widgets from IPython.display import display @@ -46,7 +47,7 @@ class MpConfigEdit(CompEditDisplayMixin): def __init__( self, - settings: Optional[Union[Dict[str, Any], MpConfigFile, str]] = None, + settings: dict[str, Any] | MpConfigFile | str | None = None, conf_filepath: str = None, ): """ @@ -74,10 +75,8 @@ def __init__( self._lbl_loading = widgets.Label(value="Loading. Please wait.") display(self._lbl_loading) if isinstance(settings, MpConfigFile): - self.mp_conf_file = MpConfigFile( - settings=settings.settings, file=conf_filepath - ) - elif isinstance(settings, (dict, SettingsDict)): + self.mp_conf_file = MpConfigFile(settings=settings.settings, file=conf_filepath) + elif isinstance(settings, dict | SettingsDict): self.mp_conf_file = MpConfigFile(settings=settings, file=conf_filepath) elif isinstance(settings, str): self.mp_conf_file = MpConfigFile(file=settings) @@ -85,13 +84,13 @@ def __init__( # This is the default if neither settings nor conf_filepath are passed. 
self.mp_conf_file = MpConfigFile(file=conf_filepath) self.mp_conf_file.load_default() - self.tool_buttons: Dict[str, widgets.Widget] = {} + self.tool_buttons: dict[str, widgets.Widget] = {} self._inc_loading_label() # Get the settings definitions and Config controls object mp_def_dict = get_mpconfig_definitions() self.mp_controls = MpConfigControls( - mp_def_dict, cast(Dict[str, Any], self.mp_conf_file.settings) + mp_def_dict, cast(dict[str, Any], self.mp_conf_file.settings) ) self._inc_loading_label() @@ -141,7 +140,7 @@ def controls(self): """Return a list of current tab names and controls.""" return self.tab_ctrl.tab_controls - def set_tab(self, tab_name: Optional[str], index: int = 0): + def set_tab(self, tab_name: str | None, index: int = 0): """Programmatically set the tab by name or index.""" self.tab_ctrl.set_tab(tab_name, index) @@ -181,12 +180,9 @@ def _create_data_tabs(self): # Set these controls as named attributes on the object setattr(self, name.replace(" ", "_"), ctrl) - def _get_tab_definitions(self) -> Dict[str, CETabControlDef]: + def _get_tab_definitions(self) -> dict[str, CETabControlDef]: """Return tab definitions and arguments.""" - return { - name: (cls, [self.mp_controls]) - for name, cls in self._TAB_DEFINITIONS.items() - } + return {name: (cls, [self.mp_controls]) for name, cls in self._TAB_DEFINITIONS.items()} @property def current_config_file(self): diff --git a/msticpy/config/mp_config_file.py b/msticpy/config/mp_config_file.py index 432f7a989..11d6660ff 100644 --- a/msticpy/config/mp_config_file.py +++ b/msticpy/config/mp_config_file.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Msticpy Config class.""" + from __future__ import annotations import io @@ -11,7 +12,7 @@ from contextlib import redirect_stdout, suppress from datetime import datetime from pathlib import Path -from typing import Any, Dict, Union +from typing import Any import ipywidgets as widgets import yaml @@ -70,8 +71,8 @@ class MpConfigFile(CompEditStatusMixin, CompEditDisplayMixin): def __init__( self, - file: Union[str, Path, None] = None, - settings: Union[Dict[str, Any], SettingsDict, None] = None, + file: str | Path | None = None, + settings: dict[str, Any] | SettingsDict | None = None, ): """ Create an instance of the MSTICPy Configuration helper class. @@ -101,24 +102,18 @@ def __init__( # Set up controls self.file_browser = FileBrowser(select_cb=self.load_from_file) - self.txt_viewer = widgets.Textarea( - layout=widgets.Layout(width="99%", height="300px") - ) + self.txt_viewer = widgets.Textarea(layout=widgets.Layout(width="99%", height="300px")) self.btn_close = widgets.Button(description="Close viewer") self.btn_close.on_click(self._close_view) self.html_title = widgets.HTML("
<h3>MSTICPy settings</h3>
") - self.txt_current_config_path = widgets.Text( - description="Current file", **_TXT_STYLE - ) + self.txt_current_config_path = widgets.Text(description="Current file", **_TXT_STYLE) self.txt_default_config_path = widgets.Text( description="Default Config path", **_TXT_STYLE ) - self._txt_import_url = widgets.Text( - description="MS Sentinel Portal URL", **_TXT_STYLE - ) - self._last_workspace: Dict[str, Dict[str, str]] - self.buttons: Dict[str, widgets.Button] = {} + self._txt_import_url = widgets.Text(description="MS Sentinel Portal URL", **_TXT_STYLE) + self._last_workspace: dict[str, dict[str, str]] + self.buttons: dict[str, widgets.Button] = {} self.btn_pane = self._setup_buttons() self.info_pane = widgets.VBox( [ @@ -149,7 +144,7 @@ def __init__( # set the default location even if user supplied file parameter self.mp_config_def_path = current_config_path() or self.current_file - if settings is not None and isinstance(settings, (dict, SettingsDict)): + if settings is not None and isinstance(settings, dict | SettingsDict): # If caller has supplied settings, we don't want to load # anything from a file self.settings = SettingsDict(settings) @@ -168,7 +163,7 @@ def current_file(self): return self.txt_current_config_path.value @current_file.setter - def current_file(self, file_name: Union[str, Path]): + def current_file(self, file_name: str | Path): """Set currently loaded file path.""" self.txt_current_config_path.value = str(file_name) @@ -178,7 +173,7 @@ def default_config_file(self): return self.txt_default_config_path.value @default_config_file.setter - def default_config_file(self, file_name: Union[str, Path]): + def default_config_file(self, file_name: str | Path): """Set default msticpyconfig path.""" self.txt_default_config_path.value = file_name @@ -195,7 +190,7 @@ def browse_for_file(self, show: bool = True): if show: display(self.viewer) - def load_from_file(self, file: Union[str, Path]): + def load_from_file(self, file: str | Path): """Load settings from `file`.""" self.settings = self._read_mp_config(file) self.current_file = file @@ -235,9 +230,7 @@ def save_to_file(self, file: str, backup: bool = True): """ # remove empty settings sections before saving - empty_items = [ - section for section, settings in self.settings.items() if not settings - ] + empty_items = [section for section, settings in self.settings.items() if not settings] for empty_section in empty_items: del self.settings[empty_section] # create a backup, if required @@ -292,7 +285,7 @@ def show_kv_secrets(self, show: bool = True): display(self.viewer) @staticmethod - def get_workspace_from_url(url: str) -> Dict[str, Dict[str, str]]: + def get_workspace_from_url(url: str) -> dict[str, dict[str, str]]: """ Return workspace settings from Sentinel portal URL. 
@@ -320,8 +313,7 @@ def _show_sentinel_workspace(self, show: bool = True): self.txt_viewer.value = "\n".join( [ workspace_settings, - "\n" - "Use 'Import into settings' button to import into current settings.", + "\nUse 'Import into settings' button to import into current settings.", ] ) self.viewer.children = [self.txt_viewer, self.btn_close] @@ -340,7 +332,7 @@ def _import_sentinel_settings(self): def _read_mp_config(self, file): if Path(file).is_file(): - with open(file, "r", encoding="utf-8") as mp_hdl: + with open(file, encoding="utf-8") as mp_hdl: try: return SettingsDict(yaml.safe_load(mp_hdl)) except yaml.scanner.ScannerError as err: @@ -441,39 +433,25 @@ def _btn_exec(*args): def _setup_buttons(self): btn_style = {"layout": widgets.Layout(width="200px")} - self.buttons["load"] = widgets.Button( - **(self._BUTTON_DEFS["load"]), **btn_style - ) + self.buttons["load"] = widgets.Button(**(self._BUTTON_DEFS["load"]), **btn_style) self.buttons["load_def"] = widgets.Button( **(self._BUTTON_DEFS["load_def"]), **btn_style ) - self.buttons["reload"] = widgets.Button( - **(self._BUTTON_DEFS["reload"]), **btn_style - ) - self.buttons["view"] = widgets.Button( - **(self._BUTTON_DEFS["view"]), **btn_style - ) + self.buttons["reload"] = widgets.Button(**(self._BUTTON_DEFS["reload"]), **btn_style) + self.buttons["view"] = widgets.Button(**(self._BUTTON_DEFS["view"]), **btn_style) self.buttons["validate"] = widgets.Button( **(self._BUTTON_DEFS["validate"]), **btn_style ) - self.buttons["convert"] = widgets.Button( - **(self._BUTTON_DEFS["convert"]), **btn_style - ) - self.buttons["save"] = widgets.Button( - **(self._BUTTON_DEFS["save"]), **btn_style - ) - self.buttons["showkv"] = widgets.Button( - **(self._BUTTON_DEFS["showkv"]), **btn_style - ) + self.buttons["convert"] = widgets.Button(**(self._BUTTON_DEFS["convert"]), **btn_style) + self.buttons["save"] = widgets.Button(**(self._BUTTON_DEFS["save"]), **btn_style) + self.buttons["showkv"] = widgets.Button(**(self._BUTTON_DEFS["showkv"]), **btn_style) self.buttons["get_workspace"] = widgets.Button( **(self._BUTTON_DEFS["get_workspace"]), **btn_style ) self.buttons["import_workspace"] = widgets.Button( **(self._BUTTON_DEFS["import_workspace"]), **btn_style ) - self._btn_view_setting = widgets.Button( - description="Get Workspace", **btn_style - ) + self._btn_view_setting = widgets.Button(description="Get Workspace", **btn_style) self._btn_import_settings = widgets.Button( description="Import into settings", disabled=True, **btn_style ) @@ -486,12 +464,8 @@ def _setup_buttons(self): self.buttons["save"].on_click(self._save_file) self.buttons["reload"].on_click(self._btn_func("refresh_mp_config")) self.buttons["showkv"].on_click(self._btn_func_no_disp("show_kv_secrets")) - self.buttons["get_workspace"].on_click( - self._btn_func("_show_sentinel_workspace") - ) - self.buttons["import_workspace"].on_click( - self._btn_func("_import_sentinel_settings") - ) + self.buttons["get_workspace"].on_click(self._btn_func("_show_sentinel_workspace")) + self.buttons["import_workspace"].on_click(self._btn_func("_import_sentinel_settings")) btns1 = widgets.VBox(list(self.buttons.values())[: len(self.buttons) // 2]) # flake8: noqa: E203 diff --git a/msticpy/config/query_editor.py b/msticpy/config/query_editor.py index 6007c40d6..8fedca037 100644 --- a/msticpy/config/query_editor.py +++ b/msticpy/config/query_editor.py @@ -5,12 +5,14 @@ # license information. 
# -------------------------------------------------------------------------- """Query Editor.""" + from __future__ import annotations +from collections.abc import Callable from dataclasses import asdict, dataclass from pathlib import Path from types import TracebackType -from typing import Any, Callable, Literal, cast +from typing import Any, Literal, cast import ipywidgets as widgets import yaml @@ -174,9 +176,7 @@ class QueryParameterEditWidget(IPyDisplayMixin): """ - def __init__( - self: QueryParameterEditWidget, container: Query | QueryDefaults - ) -> None: + def __init__(self: QueryParameterEditWidget, container: Query | QueryDefaults) -> None: """Initialize the class.""" self._changed_data: bool = False self.param_container: Query | QueryDefaults = container @@ -272,9 +272,7 @@ def set_param_container(self: Self, container: Query | QueryDefaults) -> None: """Set the parameter container.""" self.param_container = container if self.param_container and self.param_container.parameters: - self.parameter_dropdown.options = list( - self.param_container.parameters.keys() - ) + self.parameter_dropdown.options = list(self.param_container.parameters.keys()) init_change = CustomChange(new=next(iter(self.param_container.parameters))) self.populate_widgets(init_change) else: @@ -430,12 +428,8 @@ def __init__(self: QueryEditWidget, query_collection: QueryCollection) -> None: self.query_opts_widget.set_title(1, "Query metadata") self.query_opts_widget.selected_index = None self.add_query_button: widgets.Button = widgets.Button(description="New Query") - self.save_query_button: widgets.Button = widgets.Button( - description="Save Query" - ) - self.delete_query_button: widgets.Button = widgets.Button( - description="Delete Query" - ) + self.save_query_button: widgets.Button = widgets.Button(description="Save Query") + self.delete_query_button: widgets.Button = widgets.Button(description="Delete Query") self.queries_widget: widgets.VBox = widgets.VBox( children=[ widgets.Label(value="Query"), @@ -569,9 +563,7 @@ def _fmt_query(query: str) -> str: The formatted query string. """ - return "\n|".join( - line.strip() for line in query.strip().split("|") if line.strip() - ) + return "\n|".join(line.strip() for line in query.strip().split("|") if line.strip()) def add_query(self: Self, button: widgets.Button) -> None: """ @@ -639,9 +631,7 @@ def delete_query(self: Self, button: widgets.Button) -> None: class MetadataEditWidget(IPyDisplayMixin): """A class for editing Metadata properties.""" - def __init__( - self: MetadataEditWidget, metadata: QueryMetadata | None = None - ) -> None: + def __init__(self: MetadataEditWidget, metadata: QueryMetadata | None = None) -> None: """ Initialize a MetadataEditWidget object. 
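Aside: the reflowed _fmt_query earlier in this file's diff keeps the same normalization — each piped KQL stage starts on its own line. A usage sketch of that behavior:

def fmt_query(query: str) -> str:
    """Put one piped stage per line, stripping surrounding whitespace."""
    return "\n|".join(line.strip() for line in query.strip().split("|") if line.strip())


raw = "SecurityAlert | where TimeGenerated > ago(1d)   |  take 10"
print(fmt_query(raw))
# SecurityAlert
# |where TimeGenerated > ago(1d)
# |take 10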
@@ -700,9 +690,7 @@ def __init__( placeholder="(optional)", **txt_fmt(), ) - self.save_metadata_widget: widgets.Button = widgets.Button( - description="Save metadata" - ) + self.save_metadata_widget: widgets.Button = widgets.Button(description="Save metadata") self.save_metadata_widget.on_click(self.save_metadata) self.layout: widgets.VBox = widgets.VBox( @@ -747,22 +735,16 @@ def populate_widgets(self: Self) -> None: self.version_widget.value = self.metadata.version or "" self.description_widget.value = self.metadata.description or "" self.data_env_widget.value = ( - tuple(self.metadata.data_environments) - if self.metadata.data_environments - else () + tuple(self.metadata.data_environments) if self.metadata.data_environments else () ) self.data_families_widget.value = ( - ", ".join(self.metadata.data_families) - if self.metadata.data_families - else "" + ", ".join(self.metadata.data_families) if self.metadata.data_families else "" ) self.database_widget.value = self.metadata.database or "" self.cluster_widget.value = self.metadata.cluster or "" self.clusters_widget.value = "\n".join(self.metadata.clusters or []) self.cluster_groups_widget.value = "\n".join(self.metadata.cluster_groups or []) - self.tags_widget.value = ( - ", ".join(self.metadata.tags) if self.metadata.tags else "" - ) + self.tags_widget.value = ", ".join(self.metadata.tags) if self.metadata.tags else "" self.data_source_widget.value = self.metadata.data_source or "" def save_metadata(self: Self, button: widgets.Button) -> None: @@ -772,9 +754,7 @@ def save_metadata(self: Self, button: widgets.Button) -> None: self.metadata.description = self.description_widget.value self.metadata.data_environments = list(self.data_env_widget.value) self.metadata.data_families = [ - fam.strip() - for fam in self.data_families_widget.value.split(",") - if fam.strip() + fam.strip() for fam in self.data_families_widget.value.split(",") if fam.strip() ] self.metadata.database = self.database_widget.value self.metadata.cluster = self.cluster_widget.value @@ -782,8 +762,7 @@ def save_metadata(self: Self, button: widgets.Button) -> None: cluster.strip() for cluster in self.clusters_widget.value.split("\n") ] self.metadata.cluster_groups = [ - cluster_grp.strip() - for cluster_grp in self.cluster_groups_widget.value.split("\n") + cluster_grp.strip() for cluster_grp in self.cluster_groups_widget.value.split("\n") ] self.metadata.tags = [ tag.strip() for tag in self.tags_widget.value.split(",") if tag.strip() @@ -830,16 +809,12 @@ def __init__( description="Current file", layout=widgets.Layout(width="70%"), ) - if isinstance(query_file, (Path, str)): + if isinstance(query_file, Path | str): self.filename_widget.value = str(query_file) self._open_initial_file() else: - self.query_collection: QueryCollection = ( - query_file or self._new_collection() - ) - self.filename_widget.value = ( - self.query_collection.file_name or _DEF_FILENAME - ) + self.query_collection: QueryCollection = query_file or self._new_collection() + self.filename_widget.value = self.query_collection.file_name or _DEF_FILENAME self.query_editor: QueryEditWidget = QueryEditWidget(self.query_collection) self.metadata_editor: MetadataEditWidget = MetadataEditWidget( self.query_collection.metadata @@ -931,9 +906,7 @@ def _open_file(self: Self, button: widgets.Button) -> None: """Open a new query collection.""" del button if self._unsaved_changes() and not self.ignore_changes.value: - print( - "Please save or check 'Ignore changes' before opening a different file." 
- ) + print("Please save or check 'Ignore changes' before opening a different file.") return self._reset_change_state() self.query_collection = load_queries_from_yaml(self.current_file) @@ -995,9 +968,7 @@ def __enter__(self: Self) -> Self: def str_presenter(dumper: yaml.SafeDumper, data: str) -> yaml.ScalarNode: if "\n" in data: - data = "\n".join( - line.rstrip() for line in data.splitlines(keepends=True) - ) + data = "\n".join(line.rstrip() for line in data.splitlines(keepends=True)) return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") return dumper.represent_scalar("tag:yaml.org,2002:str", data) @@ -1037,7 +1008,7 @@ def load_queries_from_yaml(yaml_file: str | Path) -> QueryCollection: A QueryCollection object containing the loaded queries. """ - with open(yaml_file, "r", encoding="utf-8") as f_handle: + with open(yaml_file, encoding="utf-8") as f_handle: yaml_data: dict[str, Any] = yaml.safe_load(f_handle) metadata = QueryMetadata(**yaml_data.get("metadata", {})) @@ -1072,22 +1043,15 @@ def save_queries_to_yaml( del query_dict["file_name"] _rename_data_type(query_dict) with YamlLiteralBlockContext(): - yaml_data: str = yaml.safe_dump( - _remove_none_values(query_dict), sort_keys=False - ) + yaml_data: str = yaml.safe_dump(_remove_none_values(query_dict), sort_keys=False) Path(yaml_file).write_text(yaml_data, encoding="utf-8") def _create_query_defaults(defaults: dict[str, Any]) -> QueryDefaults: """Create a QueryDefaults object.""" - def_metadata: dict[str, Any] = ( - defaults["metadata"] if "metadata" in defaults else {} - ) + def_metadata: dict[str, Any] = defaults["metadata"] if "metadata" in defaults else {} def_params: dict[str, QueryParameter] = ( - { - name: _create_parameter(param) - for name, param in defaults["parameters"].items() - } + {name: _create_parameter(param) for name, param in defaults["parameters"].items()} if "parameters" in defaults and defaults["parameters"] else {} ) @@ -1098,9 +1062,7 @@ def _create_query(query_data: dict[str, Any]) -> Query: """Create a Query object.""" parameters: dict[str, Any] = query_data.get("parameters", {}) if parameters: - parameters = { - name: _create_parameter(param) for name, param in parameters.items() - } + parameters = {name: _create_parameter(param) for name, param in parameters.items()} return Query( description=query_data.get("description", ""), metadata=query_data.get("metadata", {}), @@ -1119,14 +1081,14 @@ def _create_parameter(param_data: dict[str, Any]) -> QueryParameter: def _remove_none_values( - source_obj: dict[str, Any] | list[Any] | tuple[Any, ...] 
+ source_obj: dict[str, Any] | list[Any] | tuple[Any, ...], ) -> dict[str, Any] | list[Any] | tuple[Any, ...]: """Recursively remove any item with a None value from a nested dictionary.""" if isinstance(source_obj, dict): return { key: _remove_none_values(val) for key, val in source_obj.items() - if val is not None or (isinstance(val, (list, dict)) and len(val) > 0) + if val is not None or (isinstance(val, list | dict) and len(val) > 0) } if isinstance(source_obj, list): return [_remove_none_values(val) for val in source_obj if val is not None] @@ -1144,6 +1106,6 @@ def _rename_data_type(source_obj: dict[str, Any] | list[Any] | tuple[Any]) -> No source_obj["type"] = val for value in source_obj.values(): _rename_data_type(value) - if isinstance(source_obj, (list, tuple)): + if isinstance(source_obj, list | tuple): for value in source_obj: _rename_data_type(value) diff --git a/msticpy/context/__init__.py b/msticpy/context/__init__.py index 8bb52c998..42fb0b2e6 100644 --- a/msticpy/context/__init__.py +++ b/msticpy/context/__init__.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Context Providers Subpackage.""" + from __future__ import annotations from typing import Any diff --git a/msticpy/context/azure/azure_data.py b/msticpy/context/azure/azure_data.py index 29375b707..19d227ae8 100644 --- a/msticpy/context/azure/azure_data.py +++ b/msticpy/context/azure/azure_data.py @@ -5,17 +5,19 @@ # license information. # -------------------------------------------------------------------------- """Uses the Azure Python SDK to collect and return details related to Azure.""" + from __future__ import annotations import datetime import logging +from collections.abc import Callable, Iterable from dataclasses import asdict, dataclass, field from importlib.metadata import version -from typing import TYPE_CHECKING, Any, Callable, Iterable +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd -from packaging.version import Version, parse +from packaging.version import Version, parse # pylint: disable=no-name-in-module from typing_extensions import Self from ..._version import VERSION @@ -249,9 +251,7 @@ def get_subscriptions(self: Self) -> pd.DataFrame: """ if self.connected is False: - err_msg: str = ( - "You need to connect to the service before using this function." - ) + err_msg: str = "You need to connect to the service before using this function." raise MsticpyNotConnectedError( err_msg, help_uri=MsticpyAzureConfigError.DEF_HELP_URI, @@ -306,9 +306,7 @@ def get_subscription_info(self: Self, sub_id: str) -> dict: """ if self.connected is False: - err_msg: str = ( - "You need to connect to the service before using this function." - ) + err_msg: str = "You need to connect to the service before using this function." raise MsticpyNotConnectedError( err_msg, help_uri=MsticpyAzureConfigError.DEF_HELP_URI, @@ -412,9 +410,7 @@ def get_resources( """ # Check if connection and client required are already present if self.connected is False: - err_msg: str = ( - "You need to connect to the service before using this function." - ) + err_msg: str = "You need to connect to the service before using this function." raise MsticpyNotConnectedError( err_msg, help_uri=MsticpyAzureConfigError.DEF_HELP_URI, @@ -519,9 +515,7 @@ def get_resource_details( """ # Check if connection and client required are already present if self.connected is False: - err_msg: str = ( - "You need to connect to the service before using this function." 
- ) + err_msg: str = "You need to connect to the service before using this function." raise MsticpyNotConnectedError( err_msg, help_uri=MsticpyAzureConfigError.DEF_HELP_URI, @@ -671,9 +665,7 @@ def _get_api( """ # Check if connection and client required are already present if self.connected is False: - err_msg: str = ( - "You need to connect to the service before using this function." - ) + err_msg: str = "You need to connect to the service before using this function." raise MsticpyNotConnectedError( err_msg, help_uri=MsticpyAzureConfigError.DEF_HELP_URI, @@ -710,9 +702,7 @@ def _get_api( err_msg = "Resource provider not found" raise MsticpyResourceError(err_msg) - api_version = [ - v for v in resource_types.api_versions if "preview" not in v.lower() - ] + api_version = [v for v in resource_types.api_versions if "preview" not in v.lower()] if api_version is None or not api_version: api_ver = resource_types.api_versions[0] else: @@ -742,9 +732,7 @@ def get_network_details( """ # Check if connection and client required are already present if self.connected is False: - err_msg: str = ( - "You need to connect to the service before using this function." - ) + err_msg: str = "You need to connect to the service before using this function." raise MsticpyNotConnectedError( err_msg, help_uri=MsticpyAzureConfigError.DEF_HELP_URI, @@ -793,9 +781,7 @@ def get_network_details( ), subnet=ip_addr.subnet.name if ip_addr.subnet else None, subnet_nsg=( - ip_addr.subnet.network_security_group - if ip_addr.subnet - else None + ip_addr.subnet.network_security_group if ip_addr.subnet else None ), subnet_route_table=( ip_addr.subnet.route_table if ip_addr.subnet else None @@ -818,7 +804,7 @@ def get_network_details( ) nsg_rules = [] if nsg_details is not None: - for nsg in nsg_details.default_security_rules: # type: ignore + for nsg in nsg_details.default_security_rules: # type: ignore[union-attr] rules = asdict( NsgItems( rule_name=nsg.name, @@ -951,9 +937,7 @@ def _get_compute_state( """ if self.connected is False: - err_msg: str = ( - "You need to connect to the service before using this function." - ) + err_msg: str = "You need to connect to the service before using this function." raise MsticpyNotConnectedError( err_msg, help_uri=MsticpyAzureConfigError.DEF_HELP_URI, @@ -1050,9 +1034,7 @@ def _legacy_auth(self: Self, client_name: str, sub_id: str | None = None) -> Non """ if not self.credentials: - err_msg: str = ( - "Credentials must be provided for legacy authentication to work." - ) + err_msg: str = "Credentials must be provided for legacy authentication to work." raise ValueError(err_msg) client: type[ SubscriptionClient diff --git a/msticpy/context/azure/sentinel_analytics.py b/msticpy/context/azure/sentinel_analytics.py index f0c94b8f3..398a09872 100644 --- a/msticpy/context/azure/sentinel_analytics.py +++ b/msticpy/context/azure/sentinel_analytics.py @@ -4,10 +4,12 @@ # license information. 
# -------------------------------------------------------------------------- """Mixin Classes for Sentinel Analytics Features.""" + from __future__ import annotations import logging -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from uuid import UUID, uuid4 import httpx @@ -48,9 +50,7 @@ def list_hunting_queries(self: Self) -> pd.DataFrame: item_type="ss_path", api_version="2020-08-01", ) - return saved_query_df[ - saved_query_df["properties.category"] == "Hunting Queries" - ] + return saved_query_df[saved_query_df["properties.category"] == "Hunting Queries"] get_hunting_queries: Callable[..., pd.DataFrame] = list_hunting_queries @@ -198,9 +198,7 @@ def create_analytic_rule( # pylint: disable=too-many-arguments, too-many-locals if template: template_id: str = self._get_template_id(template) templates: pd.DataFrame = self.list_analytic_templates() - template_details: pd.Series = templates[ - templates["name"] == template_id - ].iloc[0] + template_details: pd.Series = templates[templates["name"] == template_id].iloc[0] name = template_details["properties.displayName"] query = template_details["properties.query"] query_frequency = template_details["properties.queryFrequency"] diff --git a/msticpy/context/azure/sentinel_bookmarks.py b/msticpy/context/azure/sentinel_bookmarks.py index 07660ae9f..770a29335 100644 --- a/msticpy/context/azure/sentinel_bookmarks.py +++ b/msticpy/context/azure/sentinel_bookmarks.py @@ -4,10 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Mixin Classes for Sentinel Bookmark Features.""" + from __future__ import annotations import logging -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from uuid import UUID, uuid4 import httpx @@ -179,10 +181,7 @@ def _get_bookmark_id(self: Self, bookmark: str) -> str: display(filtered_bookmarks[["name", "properties.displayName"]]) err_msg: str = "More than one incident found, please specify by GUID" raise MsticpyUserError(err_msg) from bkmark_name - if ( - not isinstance(filtered_bookmarks, pd.DataFrame) - or filtered_bookmarks.empty - ): + if not isinstance(filtered_bookmarks, pd.DataFrame) or filtered_bookmarks.empty: err_msg = f"Incident {bookmark} not found" raise MsticpyUserError(err_msg) from bkmark_name return filtered_bookmarks["name"].iloc[0] diff --git a/msticpy/context/azure/sentinel_core.py b/msticpy/context/azure/sentinel_core.py index 97c9ed15d..6491b19bf 100644 --- a/msticpy/context/azure/sentinel_core.py +++ b/msticpy/context/azure/sentinel_core.py @@ -4,12 +4,14 @@ # license information. 
# -------------------------------------------------------------------------- """Uses the Microsoft Sentinel APIs to interact with Microsoft Sentinel Workspaces.""" + from __future__ import annotations import logging import warnings +from collections.abc import Callable from functools import partial -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from typing_extensions import Self @@ -285,10 +287,8 @@ def connect( # noqa:PLR0913 sentinel_instance: SentinelInstanceDetails = SentinelInstanceDetails( subscription_id=connect_kwargs.get(_SUB_ID) or self.default_subscription_id, - resource_group=connect_kwargs.get(_RES_GRP) - or self.default_resource_group, - workspace_name=connect_kwargs.get(_WS_NAME) - or self.default_workspace_name, + resource_group=connect_kwargs.get(_RES_GRP) or self.default_resource_group, + workspace_name=connect_kwargs.get(_WS_NAME) or self.default_workspace_name, ) except TypeError as err: raise MsticpyUserConfigError( @@ -318,9 +318,7 @@ def connect( # noqa:PLR0913 ) logger.info("Using tenant id %s", tenant_id) az_connect_kwargs: dict[str, Any] = { - key: value - for key, value in connect_kwargs.items() - if key not in _WS_PARAMETERS + key: value for key, value in connect_kwargs.items() if key not in _WS_PARAMETERS } if tenant_id: az_connect_kwargs["tenant_id"] = tenant_id diff --git a/msticpy/context/azure/sentinel_dynamic_summary.py b/msticpy/context/azure/sentinel_dynamic_summary.py index c081dde72..fcc75f52c 100644 --- a/msticpy/context/azure/sentinel_dynamic_summary.py +++ b/msticpy/context/azure/sentinel_dynamic_summary.py @@ -4,11 +4,13 @@ # license information. # -------------------------------------------------------------------------- """Sentinel Dynamic Summary Mixin class.""" + from __future__ import annotations import logging +from collections.abc import Callable, Iterable from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any, Callable, Iterable +from typing import TYPE_CHECKING, Any import httpx from typing_extensions import Self @@ -141,9 +143,7 @@ def get_dynamic_summary( if summary_items: if not self.sent_data_query: try: - self.sent_data_query: ( - SentinelQueryProvider | None - ) = SentinelQueryProvider( + self.sent_data_query: SentinelQueryProvider | None = SentinelQueryProvider( self.default_workspace_name, # type: ignore[attr-defined] ) logger.info( @@ -315,9 +315,7 @@ def _create_dynamic_summary( "_create_dynamic_summary (DynamicSummary) failure %s", response.content.decode("utf-8"), ) - err_msg = ( - f"Dynamic summary create/update failed with status {response.status_code}" - ) + err_msg = f"Dynamic summary create/update failed with status {response.status_code}" raise MsticpyAzureConnectionError( err_msg, "Text response:", @@ -507,9 +505,7 @@ def update_dynamic_summary( # pylint:disable=too-many-arguments # noqa:PLR0913 If API returns an error. """ - if (summary and not summary.summary_id) or ( - data is not None and not summary_id - ): + if (summary and not summary.summary_id) or (data is not None and not summary_id): err_msg: str = "You must supply a summary ID to update" raise MsticpyParameterError( err_msg, diff --git a/msticpy/context/azure/sentinel_dynamic_summary_types.py b/msticpy/context/azure/sentinel_dynamic_summary_types.py index f0540b4e7..15acc1515 100644 --- a/msticpy/context/azure/sentinel_dynamic_summary_types.py +++ b/msticpy/context/azure/sentinel_dynamic_summary_types.py @@ -4,15 +4,17 @@ # license information. 
# -------------------------------------------------------------------------- """Sentinel Dynamic Summary classes.""" + from __future__ import annotations import dataclasses import json import logging import uuid +from collections.abc import Callable from datetime import datetime from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import numpy as np import pandas as pd @@ -78,9 +80,7 @@ def __init__(self: FieldList, fieldnames: Iterable[str]) -> None: def __repr__(self: Self) -> str: """Return list of field attributes and values.""" - field_names: str = "\n ".join( - f"{key}='{val}'" for key, val in vars(self).items() - ) + field_names: str = "\n ".join(f"{key}='{val}'" for key, val in vars(self).items()) return f"Fields:\n {field_names}" @@ -231,9 +231,7 @@ def __init__( # pylint:disable=too-many-arguments #noqa:PLR0913 if summary_items is not None: self.add_summary_items(summary_items) self.source_info: dict[str, Any] = ( - source_info - if isinstance(source_info, dict) - else {"user_source": source_info} + source_info if isinstance(source_info, dict) else {"user_source": source_info} ) self.source_info["source_pkg"] = f"MSTICPy {VERSION}" @@ -369,9 +367,7 @@ def df_to_dynamic_summaries(data: pd.DataFrame) -> list[DynamicSummary]: dyn_summaries = df_to_dynamic_summaries(data) """ - return [ - df_to_dynamic_summary(ds_data) for _, ds_data in data.groupby("SummaryId") - ] + return [df_to_dynamic_summary(ds_data) for _, ds_data in data.groupby("SummaryId")] @staticmethod def df_to_dynamic_summary(data: pd.DataFrame) -> DynamicSummary: @@ -819,8 +815,4 @@ def _convert_data_types( def _match_tactics(tactics: Iterable[str]) -> list[str]: """Return case-insensitive matches for tactics list.""" - return [ - _TACTICS_DICT[tactic.casefold()] - for tactic in tactics - if tactic in _TACTICS_DICT - ] + return [_TACTICS_DICT[tactic.casefold()] for tactic in tactics if tactic in _TACTICS_DICT] diff --git a/msticpy/context/azure/sentinel_incidents.py b/msticpy/context/azure/sentinel_incidents.py index 03b6609e8..dce3f57e0 100644 --- a/msticpy/context/azure/sentinel_incidents.py +++ b/msticpy/context/azure/sentinel_incidents.py @@ -4,10 +4,12 @@ # license information. 
# -------------------------------------------------------------------------- """Mixin Classes for Sentinel Incident Features.""" + from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 import httpx @@ -238,9 +240,9 @@ def get_incident_bookmarks(self: Self, incident: str) -> list: ): bkmark_id: str = relationship["properties"]["relatedResourceName"] bookmarks_df: pd.DataFrame = self.list_bookmarks() - bookmark: pd.Series = bookmarks_df[ - bookmarks_df["name"] == bkmark_id - ].iloc[0] + bookmark: pd.Series = bookmarks_df[bookmarks_df["name"] == bkmark_id].iloc[ + 0 + ] bookmarks_list.append( { "Bookmark ID": bkmark_id, @@ -431,10 +433,7 @@ def _get_incident_id(self: Self, incident: str) -> str: display(filtered_incidents[["name", "properties.title"]]) err_msg: str = "More than one incident found, please specify by GUID" raise MsticpyUserError(err_msg) from incident_name - if ( - not isinstance(filtered_incidents, pd.DataFrame) - or filtered_incidents.empty - ): + if not isinstance(filtered_incidents, pd.DataFrame) or filtered_incidents.empty: err_msg = f"Incident {incident} not found" raise MsticpyUserError(err_msg) from incident_name return filtered_incidents["name"].iloc[0] @@ -462,9 +461,7 @@ def post_comment( """ self.check_connected() - comment_url: str = ( - self.sent_urls["incidents"] + f"/{incident_id}/comments/{uuid4()}" - ) + comment_url: str = self.sent_urls["incidents"] + f"/{incident_id}/comments/{uuid4()}" params: dict[str, str] = {"api-version": "2020-01-01"} data: dict[str, Any] = extract_sentinel_response({"message": comment}) if not self._token: diff --git a/msticpy/context/azure/sentinel_search.py b/msticpy/context/azure/sentinel_search.py index 08c08d793..20af405d3 100644 --- a/msticpy/context/azure/sentinel_search.py +++ b/msticpy/context/azure/sentinel_search.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Mixin Classes for Sentinel Search Features.""" + from __future__ import annotations import datetime as dt @@ -71,8 +72,7 @@ def create_search( # noqa: PLR0913 search_start: dt.datetime = start or (search_end - dt.timedelta(days=90)) search_name = (search_name or str(uuid4())).replace("_", "") search_url: str = ( - self.sent_urls["search"] - + f"/{search_name}_SRCH?api-version=2021-12-01-preview" + self.sent_urls["search"] + f"/{search_name}_SRCH?api-version=2021-12-01-preview" ) search_items: dict[str, dict[str, Any]] = { "searchResults": { @@ -118,8 +118,7 @@ def check_search_status(self: Self, search_name: str) -> bool: """ search_name = search_name.strip("_SRCH") search_url: str = ( - self.sent_urls["search"] - + f"/{search_name}_SRCH?api-version=2021-12-01-preview" + self.sent_urls["search"] + f"/{search_name}_SRCH?api-version=2021-12-01-preview" ) if not self._token: err_msg = "Token not found, can't check search status." 
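
One thing the search hunks above keep unchanged is `search_name.strip("_SRCH")`. Worth flagging as a possible follow-up (not part of this diff): str.strip() removes any of the listed characters from both ends, not the literal suffix, so names built from the letters S, R, C, H can be mangled. An illustration with made-up values:

name = "CRASH_SRCH"
print(name.strip("_SRCH"))         # 'A'     - strips the characters _, S, R, C, H
print(name.removesuffix("_SRCH"))  # 'CRASH' - Python 3.9+, removes the suffix only
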
@@ -131,9 +130,7 @@ def check_search_status(self: Self, search_name: str) -> bool: if not search_check_response.is_success: raise CloudError(response=search_check_response) - check_result: str = search_check_response.json()["properties"][ - "provisioningState" - ] + check_result: str = search_check_response.json()["properties"]["provisioningState"] logger.info("%s_SRCH status is '%s'", search_name, check_result) return check_result == "Succeeded" @@ -154,8 +151,7 @@ def delete_search(self: Self, search_name: str) -> None: """ search_name = search_name.strip("_SRCH") search_url: str = ( - self.sent_urls["search"] - + f"/{search_name}_SRCH?api-version=2021-12-01-preview" + self.sent_urls["search"] + f"/{search_name}_SRCH?api-version=2021-12-01-preview" ) if not self._token: err_msg = "Token not found, can't delete search." diff --git a/msticpy/context/azure/sentinel_ti.py b/msticpy/context/azure/sentinel_ti.py index cb96269b4..99a61d6b9 100644 --- a/msticpy/context/azure/sentinel_ti.py +++ b/msticpy/context/azure/sentinel_ti.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Mixin Classes for Sentinel Analytics Features.""" + from __future__ import annotations import datetime as dt diff --git a/msticpy/context/azure/sentinel_utils.py b/msticpy/context/azure/sentinel_utils.py index 004808783..c759030e3 100644 --- a/msticpy/context/azure/sentinel_utils.py +++ b/msticpy/context/azure/sentinel_utils.py @@ -4,12 +4,13 @@ # license information. # -------------------------------------------------------------------------- """Mixin Classes for Sentinel Utilties.""" + from __future__ import annotations import logging from collections import Counter from dataclasses import dataclass -from typing import Any, Dict, cast +from typing import Any, cast import httpx import pandas as pd @@ -223,8 +224,7 @@ def check_connected(self: Self) -> None: """Check that Sentinel workspace is connected.""" if not self.connected: err_msg: str = ( - "Not connected to Sentinel, ensure you run `.connect`" - "before calling functions." + "Not connected to Sentinel, ensure you run `.connect`before calling functions." ) raise MsticpyAzureConnectionError(err_msg) @@ -314,7 +314,7 @@ def extract_sentinel_response( """ data_body: dict[str, dict[str, str]] = {"properties": {}} - for key in items: + for key in items: # noqa: PLC0206 if key in ["severity", "status", "title", "message", "searchResults"] or props: data_body["properties"].update({key: items[key]}) else: @@ -341,9 +341,7 @@ def parse_resource_id(res_id: str) -> dict[str, Any]: """Extract components from workspace resource ID.""" if not res_id.startswith("/"): res_id = f"/{res_id}" - res_id_parts: dict[str, str] = cast( - Dict[str, str], az_tools.parse_resource_id(res_id) - ) + res_id_parts: dict[str, str] = cast(dict[str, str], az_tools.parse_resource_id(res_id)) workspace_name: str | None = None if ( res_id_parts.get("namespace") == "Microsoft.OperationalInsights" diff --git a/msticpy/context/azure/sentinel_watchlists.py b/msticpy/context/azure/sentinel_watchlists.py index a730366e3..614f586cd 100644 --- a/msticpy/context/azure/sentinel_watchlists.py +++ b/msticpy/context/azure/sentinel_watchlists.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Mixin Classes for Sentinel Watchlist Features.""" + from __future__ import annotations import logging @@ -224,17 +225,17 @@ def add_watchlist_item( axis=1, copy=False, ) - if (current_df == item_series).all( - axis=1 - ).any() and overwrite: # type: ignore[attr-defined] + if (current_df == item_series).all(axis=1).any() and overwrite: watchlist_id: str = current_items[ current_items.isin(list(new_item.values())).any(axis=1) ]["properties.watchlistItemId"].iloc[0] # If not in watchlist already generate new ID - elif not (current_df == item_series).all(axis=1).any(): # type: ignore[attr-defined] + elif not (current_df == item_series).all(axis=1).any(): watchlist_id = str(uuid4()) else: - err_msg = "Item already exists in the watchlist. Set overwrite = True to replace." + err_msg = ( + "Item already exists in the watchlist. Set overwrite = True to replace." + ) raise MsticpyUserError(err_msg) watchlist_url: str = ( diff --git a/msticpy/context/azure/sentinel_workspaces.py b/msticpy/context/azure/sentinel_workspaces.py index d5d74f584..f914c59ff 100644 --- a/msticpy/context/azure/sentinel_workspaces.py +++ b/msticpy/context/azure/sentinel_workspaces.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Mixin Class for Sentinel Workspaces.""" + from __future__ import annotations import logging diff --git a/msticpy/context/contextlookup.py b/msticpy/context/contextlookup.py index 631f2f352..e6ad41d45 100644 --- a/msticpy/context/contextlookup.py +++ b/msticpy/context/contextlookup.py @@ -12,9 +12,11 @@ requests per minute for the account type that you have. """ + from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Iterable, Mapping +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING, ClassVar from typing_extensions import Self @@ -36,9 +38,7 @@ class ContextLookup(Lookup): """Observable lookup from providers.""" - _NO_PROVIDERS_MSG: ClassVar[ - str - ] = """ + _NO_PROVIDERS_MSG: ClassVar[str] = """ No Context Providers are loaded - please check that you have correctly configured your msticpyconfig.yaml settings. """ @@ -171,7 +171,7 @@ async def _lookup_observables_async( # pylint:disable=too-many-arguments # noqa ) -> pd.DataFrame: """Lookup items async.""" return await self._lookup_items_async( - data, # type: ignore[arg-type] + data, item_col=obs_col, item_type_col=obs_type_col, query_type=query_type, diff --git a/msticpy/context/contextproviders/context_provider_base.py b/msticpy/context/contextproviders/context_provider_base.py index 4b6f9b1bf..8e208957d 100644 --- a/msticpy/context/contextproviders/context_provider_base.py +++ b/msticpy/context/contextproviders/context_provider_base.py @@ -12,12 +12,14 @@ requests per minute for the account type that you have. 
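
The import churn in these context-provider modules follows one rule applied repeatedly across this diff: container ABCs such as Callable, Iterable, and Mapping now come from collections.abc, while typing retains only names with no abc equivalent. A minimal sketch of the convention:

from __future__ import annotations

from collections.abc import Callable, Iterable, Mapping
from typing import Any

Handler = Callable[[str], Any]  # callable alias built on the abc type


def first_key(mapping: Mapping[str, Any], keys: Iterable[str]) -> str | None:
    """Return the first key from `keys` present in `mapping`, if any."""
    return next((key for key in keys if key in mapping), None)
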
""" + from __future__ import annotations import re from abc import abstractmethod +from collections.abc import Iterable from ipaddress import IPv4Address, IPv6Address, ip_address -from typing import TYPE_CHECKING, Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar from typing_extensions import Self diff --git a/msticpy/context/contextproviders/http_context_provider.py b/msticpy/context/contextproviders/http_context_provider.py index 7efb9f86c..1f92d6ed7 100644 --- a/msticpy/context/contextproviders/http_context_provider.py +++ b/msticpy/context/contextproviders/http_context_provider.py @@ -11,6 +11,7 @@ It inherits from ContextProvider and HttpProvider """ + from __future__ import annotations from functools import lru_cache @@ -126,9 +127,7 @@ def _run_context_lookup_query( result["RawResult"] = response.json().copy() result["Result"], result["Details"] = self.parse_results(result) except JSONDecodeError: - result[ - "RawResult" - ] = f"""There was a problem parsing results from this lookup: + result["RawResult"] = f"""There was a problem parsing results from this lookup: {response.text}""" result["Result"] = False result["Details"] = {} @@ -139,7 +138,7 @@ def _run_context_lookup_query( result["Details"] = self._response_message(result["Status"]) return result - @lru_cache(maxsize=256) + @lru_cache(maxsize=256) # noqa: B019 def lookup_observable( # noqa:PLR0913 self: Self, observable: str, diff --git a/msticpy/context/contextproviders/servicenow.py b/msticpy/context/contextproviders/servicenow.py index 8f52f8c3e..6c4739e3b 100644 --- a/msticpy/context/contextproviders/servicenow.py +++ b/msticpy/context/contextproviders/servicenow.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations import datetime as dt @@ -155,9 +156,7 @@ def parse_results(self: Self, response: dict[str, Any]) -> tuple[bool, Any]: if result.get("sys_created_on") else "" ), - **( - getattr(self, f"_parse_result_{response['ObservableType']}")(result) - ), + **(getattr(self, f"_parse_result_{response['ObservableType']}")(result)), } for result in results ] diff --git a/msticpy/context/domain_utils.py b/msticpy/context/domain_utils.py index 8d7e3689f..068d9636d 100644 --- a/msticpy/context/domain_utils.py +++ b/msticpy/context/domain_utils.py @@ -10,6 +10,7 @@ with a domain or url, such as getting a screenshot or validating the TLD. """ + from __future__ import annotations import datetime as dt @@ -17,9 +18,10 @@ import logging import ssl import time +from collections.abc import Callable from dataclasses import asdict from enum import Enum -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from urllib.error import HTTPError, URLError import httpx diff --git a/msticpy/context/geoip.py b/msticpy/context/geoip.py index 8fe719ae6..1f91eb49e 100644 --- a/msticpy/context/geoip.py +++ b/msticpy/context/geoip.py @@ -20,6 +20,7 @@ an online lookup (API key required). 
""" + from __future__ import annotations import contextlib @@ -30,11 +31,12 @@ import warnings from abc import ABCMeta, abstractmethod from collections import abc +from collections.abc import Iterable, Mapping from datetime import datetime, timedelta, timezone from json import JSONDecodeError from pathlib import Path from time import sleep -from typing import Any, ClassVar, Iterable, Mapping +from typing import Any, ClassVar import geoip2.database import httpx @@ -206,15 +208,11 @@ class IPStackLookup(GeoIpLookup): """ - _LICENSE_HTML: ClassVar[ - str - ] = """ + _LICENSE_HTML: ClassVar[str] = """ This library uses services provided by ipstack. https://ipstack.com""" - _LICENSE_TXT: ClassVar[ - str - ] = """ + _LICENSE_TXT: ClassVar[str] = """ This library uses services provided by ipstack (https://ipstack.com)""" _IPSTACK_API: ClassVar[str] = ( @@ -384,9 +382,7 @@ def _submit_request( # Please upgrade your subscription."}} if "success" in results and not results["success"]: - err_msg: str = ( - f"Service unable to complete request. Error: {results['error']}" - ) + err_msg: str = f"Service unable to complete request. Error: {results['error']}" raise PermissionError(err_msg) return [(item, response.status_code) for item in results.values()] @@ -436,8 +432,7 @@ class GeoLiteLookup(GeoIpLookup): """ _MAXMIND_DOWNLOAD: ClassVar[str] = ( - "https://download.maxmind.com/geoip/databases" - "/GeoLite2-City/download?suffix=tar.gz" + "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz" ) _DB_HOME: ClassVar[str] = str( @@ -446,23 +441,17 @@ class GeoLiteLookup(GeoIpLookup): _DB_ARCHIVE: ClassVar[str] = "GeoLite2-City.mmdb.{rand}.tar.gz" _DB_FILE: ClassVar[str] = "GeoLite2-City.mmdb" - _LICENSE_HTML: ClassVar[ - str - ] = """ + _LICENSE_HTML: ClassVar[str] = """ This product includes GeoLite2 data created by MaxMind, available from https://www.maxmind.com. """ - _LICENSE_TXT: ClassVar[ - str - ] = """ + _LICENSE_TXT: ClassVar[str] = """ This product includes GeoLite2 data created by MaxMind, available from https://www.maxmind.com. """ - _NO_API_KEY_MSSG: ClassVar[ - str - ] = """ + _NO_API_KEY_MSSG: ClassVar[str] = """ You need both an API Key and an Account ID to download the Maxmind GeoIPLite database. If you do not have an account, go here to create one and obtain and API key and your account ID. @@ -589,7 +578,7 @@ def lookup_ip( try: geo_match_object = self._reader.city(ip_input) if hasattr(geo_match_object, "raw"): - geo_match = geo_match_object.raw # type: ignore + geo_match = geo_match_object.raw elif hasattr(geo_match_object, "to_dict"): geo_match = geo_match_object.to_dict() else: @@ -705,8 +694,7 @@ def _check_and_update_db(self: Self) -> None: db_updated = False elif self._force_update: logger.info( - "force_update is set to True. " - "Attempting to download new database to %s", + "force_update is set to True. 
Attempting to download new database to %s", self._db_folder, ) if not self._download_and_extract_archive(): @@ -751,7 +739,8 @@ def _download_and_extract_archive(self: Self) -> bool: return True # Create a basic auth object for the request basic_auth = httpx.BasicAuth( - username=self._account_id, password=self._api_key # type: ignore[arg-type] + username=self._account_id, # type: ignore[arg-type] + password=self._api_key, # type: ignore[arg-type] ) # Stream download and write to file logger.info( @@ -969,9 +958,7 @@ def entity_distance(ip_src: IpAddress, ip_dest: IpAddress) -> float: """ if not ip_src.Location or not ip_dest.Location: - err_msg: str = ( - "Source and destination entities must have defined Location properties." - ) + err_msg: str = "Source and destination entities must have defined Location properties." raise AttributeError(err_msg) return geo_distance( diff --git a/msticpy/context/http_provider.py b/msticpy/context/http_provider.py index 81e5663de..d4ee9ed86 100644 --- a/msticpy/context/http_provider.py +++ b/msticpy/context/http_provider.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations import traceback @@ -151,9 +152,7 @@ def __init__( self._request_params["Instance"] = Instance.strip() missing_params: list[str] = [ - param - for param in self._REQUIRED_PARAMS - if param not in self._request_params + param for param in self._REQUIRED_PARAMS if param not in self._request_params ] missing_params = [] diff --git a/msticpy/context/ip_utils.py b/msticpy/context/ip_utils.py index 41ef1e892..54da8d019 100644 --- a/msticpy/context/ip_utils.py +++ b/msticpy/context/ip_utils.py @@ -12,23 +12,23 @@ Designed to support any data source containing IP address entity. """ + from __future__ import annotations import ipaddress import logging import re import socket -import warnings +from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass, field from functools import lru_cache from time import sleep -from typing import Any, Callable, Iterator +from typing import Any import httpx import pandas as pd from bs4 import BeautifulSoup from deprecated.sphinx import deprecated -from typing_extensions import Self from .._version import VERSION from ..common.exceptions import MsticpyConnectionError, MsticpyException @@ -80,9 +80,9 @@ def _get_asns_dict() -> dict[str, str]: raise MsticpyConnectionError(err_msg) from err asns_soup = BeautifulSoup(asns_resp.content, features="lxml") asns_dict = { - str(asn.next_element) - .strip(): str(asn.next_element.next_element if asn.next_element else "") - .strip() + str(asn.next_element).strip(): str( + asn.next_element.next_element if asn.next_element else "" + ).strip() for asn in asns_soup.find_all("a") } return asns_dict @@ -132,7 +132,7 @@ def convert_to_ip_entities( """ # locally imported to prevent cyclic import # pylint: disable=import-outside-toplevel, cyclic-import - from .geoip import GeoLiteLookup + from .geoip import GeoLiteLookup # noqa: PLC0415 geo_lite_lookup: GeoLiteLookup = GeoLiteLookup() @@ -391,64 +391,6 @@ def get_whois_df( # noqa: PLR0913 return data.assign(ASNDescription="No data returned") -@pd.api.extensions.register_dataframe_accessor("mp_whois") -@export -class IpWhoisAccessor: - """Pandas api extension for IP Whois lookup.""" - - def __init__(self: IpWhoisAccessor, pandas_obj: pd.DataFrame) -> None: - """Instantiate pandas extension class.""" - self._df: pd.DataFrame = pandas_obj - - def lookup( - self: Self, - ip_column: str, - *, - 
asn_col: str = "ASNDescription", - whois_col: str = "WhoIsData", - show_progress: bool = False, - ) -> pd.DataFrame: - """ - Extract IoCs from either a pandas DataFrame. - - Parameters - ---------- - ip_column : str - Column name of IP Address to look up. - asn_col : str, optional - Name of the output column for ASN description, - by default "ASNDescription" - whois_col : str, optional - Name of the output column for full whois data, - by default "WhoIsData" - show_progress : bool, optional - Show progress for each query, by default False - - Returns - ------- - pd.DataFrame - Output DataFrame with results in added columns. - - """ - warn_message = ( - "This accessor method has been deprecated.\n" - "Please use IpAddress.util.whois() pivot function." - "This will be removed in MSTICPy v2.2.0" - ) - warnings.warn( - warn_message, - category=DeprecationWarning, - stacklevel=1, - ) - return get_whois_df( - data=self._df, - ip_column=ip_column, - asn_col=asn_col, - whois_col=whois_col, - show_progress=show_progress, - ) - - def ip_whois( ip: IpAddress | str | list | pd.Series | None = None, ip_address: IpAddress | str | list[str] | pd.Series | None = None, @@ -489,7 +431,7 @@ def ip_whois( if ip is None: err_msg: str = "One of ip or ip_address parameters must be supplied." raise ValueError(err_msg) - if isinstance(ip, (list, pd.Series)): + if isinstance(ip, list | pd.Series): rate_limit: bool = len(ip) > RATE_LIMIT_THRESHOLD if rate_limit: logger.info("Large number of lookups, this may take some time.") @@ -497,13 +439,13 @@ def ip_whois( for ip_addr in ip: if rate_limit: sleep(query_rate) - whois_results[ip_addr] = _whois_lookup( # type: ignore[index] + whois_results[ip_addr] = _whois_lookup( ip_addr, raw=raw, retry_count=retry_count, ).properties return _whois_result_to_pandas(whois_results) - if isinstance(ip, (str, IpAddress)): + if isinstance(ip, str | IpAddress): return _whois_lookup(ip, raw=raw) return pd.DataFrame() @@ -525,9 +467,7 @@ def get_asn_details(asns: str | list[str]) -> pd.DataFrame | dict[str, Any]: """ if isinstance(asns, list): - asn_detail_results: list[dict[str, Any]] = [ - _asn_results(str(asn)) for asn in asns - ] + asn_detail_results: list[dict[str, Any]] = [_asn_results(str(asn)) for asn in asns] return pd.DataFrame(asn_detail_results) return _asn_results(str(asns)) @@ -598,7 +538,7 @@ def get_asn_from_ip( ip_response: str = _cymru_query(query) keys: list[str] = ip_response.split("\n", maxsplit=1)[0].split("|") values: list[str] = ip_response.split("\n")[1].split("|") - return {key.strip(): value.strip() for key, value in zip(keys, values)} + return {key.strip(): value.strip() for key, value in zip(keys, values, strict=False)} @dataclass @@ -695,8 +635,7 @@ def _rdap_lookup(url: str, retry_count: int = 5) -> httpx.Response: retry_count -= 1 if not rdap_data: err_msg: str = ( - "Rate limit exceeded - try adjusting query_rate parameter " - "to slow down requests" + "Rate limit exceeded - try adjusting query_rate parameter to slow down requests" ) raise MsticpyException(err_msg) return rdap_data diff --git a/msticpy/context/lookup.py b/msticpy/context/lookup.py index 18c8065e9..21a2e8bce 100644 --- a/msticpy/context/lookup.py +++ b/msticpy/context/lookup.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. 
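
The `strict=False` arguments appearing in the zip() calls here (and in _combine_results later in this diff) make the truncation semantics explicit for Ruff's B905 check: zip() gained the flag in Python 3.10, and strict=False, the long-standing default, silently stops at the shorter input. A two-line demonstration with made-up values:

keys = ["AS", "IP", "BGP Prefix"]
values = ["13335", "1.1.1.1"]  # one element short
print(dict(zip(keys, values, strict=False)))  # {'AS': '13335', 'IP': '1.1.1.1'}
# dict(zip(keys, values, strict=True)) would raise ValueError instead
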
""" + from __future__ import annotations import asyncio @@ -19,7 +20,8 @@ import logging import warnings from collections import ChainMap -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ClassVar import nest_asyncio import pandas as pd @@ -81,16 +83,12 @@ async def get_remaining(self: Self) -> int: class Lookup: """Item lookup from providers.""" - _NO_PROVIDERS_MSG: ClassVar[ - str - ] = """ + _NO_PROVIDERS_MSG: ClassVar[str] = """ No Providers are loaded - please check that you have correctly configured your msticpyconfig.yaml settings. """ - _HELP_URI: ClassVar[str] = ( - "https://msticpy.readthedocs.io/en/latest/DataEnrichment.html" - ) + _HELP_URI: ClassVar[str] = "https://msticpy.readthedocs.io/en/latest/DataEnrichment.html" PROVIDERS: ClassVar[dict[str, tuple[str, str]]] = {} CUSTOM_PROVIDERS: ClassVar[dict[str, type[Provider]]] @@ -128,14 +126,6 @@ def __init__( if primary_providers: for prov in primary_providers: self.add_provider(prov, primary=True) - if secondary_providers: - warnings.warn( - "'secondary_providers' is a deprecated parameter", - DeprecationWarning, - stacklevel=1, - ) - for prov in secondary_providers: - self.add_provider(prov, primary=False) if not (primary_providers or secondary_providers): self._load_providers() @@ -221,13 +211,10 @@ def enable_provider(self: Self, providers: str | Iterable[str]) -> None: as_list=True, ) if not available_providers: - err_msg: str = ( - f"Unknown provider '{provider}'. No available providers." - ) + err_msg: str = f"Unknown provider '{provider}'. No available providers." else: - err_msg = ( - f"Unknown provider '{provider}'. Available providers:" - ", ".join(available_providers) + err_msg = f"Unknown provider '{provider}'. Available providers:, ".join( + available_providers ) raise ValueError(err_msg) @@ -259,13 +246,10 @@ def disable_provider(self: Self, providers: str | Iterable[str]) -> None: as_list=True, ) if not available_providers: - err_msg: str = ( - f"Unknown provider '{provider}'. No available providers." - ) + err_msg: str = f"Unknown provider '{provider}'. No available providers." else: - err_msg = ( - f"Unknown provider '{provider}'. Available providers:" - ", ".join(available_providers) + err_msg = f"Unknown provider '{provider}'. Available providers:, ".join( + available_providers ) raise ValueError(err_msg) @@ -771,9 +755,8 @@ def import_provider(cls: type[Self], provider: str) -> type[Provider]: if not (mod_name and cls_name): if hasattr(cls, "CUSTOM_PROVIDERS") and provider in cls.CUSTOM_PROVIDERS: return cls.CUSTOM_PROVIDERS[provider] - err_msg: str = ( - f"No provider named '{provider}'. Possible values are: " - ", ".join(list(cls.PROVIDERS) + list(cls.CUSTOM_PROVIDERS)) + err_msg: str = f"No provider named '{provider}'. Possible values are: , ".join( + list(cls.PROVIDERS) + list(cls.CUSTOM_PROVIDERS) ) raise LookupError(err_msg) @@ -833,9 +816,7 @@ def _load_providers( # set the description from settings, if one is provided, otherwise # use class docstring. 
- provider_instance.description = ( - settings.description or provider_instance.__doc__ - ) + provider_instance.description = settings.description or provider_instance.__doc__ self.add_provider( provider=provider_instance, @@ -887,7 +868,7 @@ def _combine_results( ) -> pd.DataFrame: """Combine dataframe results into single DF.""" result_list: list[pd.DataFrame] = [] - for prov_name, provider_result in zip(provider_names, results): + for prov_name, provider_result in zip(provider_names, results, strict=False): if provider_result is None or provider_result.empty: continue result: pd.DataFrame = provider_result.copy() diff --git a/msticpy/context/lookup_result.py b/msticpy/context/lookup_result.py index 91220a694..565830a10 100644 --- a/msticpy/context/lookup_result.py +++ b/msticpy/context/lookup_result.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Lookup Status class.""" + from __future__ import annotations from enum import Enum diff --git a/msticpy/context/preprocess_observable.py b/msticpy/context/preprocess_observable.py index 7f297933c..9b12a6a6b 100644 --- a/msticpy/context/preprocess_observable.py +++ b/msticpy/context/preprocess_observable.py @@ -12,15 +12,17 @@ requests per minute for the account type that you have. """ + from __future__ import annotations import contextlib import math import re from collections import Counter +from collections.abc import Callable from functools import partial from ipaddress import IPv4Address, IPv6Address, ip_address -from typing import Callable, ClassVar +from typing import ClassVar from urllib.parse import quote_plus from typing_extensions import Self @@ -371,6 +373,4 @@ def preprocess_observable( def _entropy(input_str: str) -> float: """Compute entropy of input string.""" str_len = float(len(input_str)) - return -sum( - (a / str_len) * math.log2(a / str_len) for a in Counter(input_str).values() - ) + return -sum((a / str_len) * math.log2(a / str_len) for a in Counter(input_str).values()) diff --git a/msticpy/context/provider_base.py b/msticpy/context/provider_base.py index 1f5bed6c5..eb5510fe8 100644 --- a/msticpy/context/provider_base.py +++ b/msticpy/context/provider_base.py @@ -12,15 +12,16 @@ requests per minute for the account type that you have. 
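
The _entropy one-liner collapsed above computes Shannon entropy in bits, H = -sum(p_i * log2(p_i)), over character frequencies. A quick self-check that the reformatted expression still implements the formula:

import math
from collections import Counter


def entropy(input_str: str) -> float:
    str_len = float(len(input_str))
    return -sum((a / str_len) * math.log2(a / str_len) for a in Counter(input_str).values())


print(entropy("abab"))  # 1.0  - two equiprobable symbols carry one bit per character
print(entropy("aaaa"))  # -0.0 (== 0.0) - a single repeated symbol carries no information
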
""" + from __future__ import annotations import asyncio import logging from abc import ABC, abstractmethod from asyncio import get_event_loop -from collections.abc import Iterable as C_Iterable +from collections.abc import Coroutine, Generator, Iterable from functools import lru_cache, partial, singledispatch -from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generator, Iterable, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast import pandas as pd from typing_extensions import Self @@ -441,7 +442,7 @@ def register_pivots( @singledispatch def generate_items( - data: pd.DataFrame | dict | C_Iterable, + data: pd.DataFrame | dict | Iterable, item_col: str | None = None, item_type_col: str | None = None, ) -> Generator[tuple[str | None, str | None], Any, None]: @@ -464,7 +465,7 @@ def generate_items( """ del item_col, item_type_col - if isinstance(data, C_Iterable): + if isinstance(data, Iterable): for item in data: yield cast(str, item), Provider.resolve_item_type(item) else: @@ -500,7 +501,7 @@ def _( def _make_sync(future: Coroutine) -> pd.DataFrame: """Wait for an async call, making it sync.""" try: - event_loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() except RuntimeError: # Generate an event loop if there isn't any. event_loop = asyncio.new_event_loop() diff --git a/msticpy/context/tilookup.py b/msticpy/context/tilookup.py index 869975ea0..5c62d7f26 100644 --- a/msticpy/context/tilookup.py +++ b/msticpy/context/tilookup.py @@ -12,9 +12,11 @@ requests per minute for the account type that you have. """ + from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Iterable, Mapping +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING, ClassVar from typing_extensions import Self @@ -39,9 +41,7 @@ class TILookup(Lookup): """Threat Intel observable lookup from providers.""" - _NO_PROVIDERS_MSG: ClassVar[ - str - ] = """ + _NO_PROVIDERS_MSG: ClassVar[str] = """ No TI Providers are loaded - please check that you have correctly configured your msticpyconfig.yaml settings. """ @@ -233,7 +233,7 @@ async def _lookup_iocs_async( # pylint: disable=too-many-arguments #noqa:PLR091 ) -> pd.DataFrame: """Lookup IoCs async.""" return await self._lookup_items_async( - data, # type: ignore[arg-type] + data, item_col=ioc_col, item_type_col=ioc_type_col, query_type=ioc_query_type, diff --git a/msticpy/context/tiproviders/__init__.py b/msticpy/context/tiproviders/__init__.py index 53cfe6eb9..c0efbab71 100644 --- a/msticpy/context/tiproviders/__init__.py +++ b/msticpy/context/tiproviders/__init__.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """TI Providers sub-package.""" + from __future__ import annotations from ..._version import VERSION diff --git a/msticpy/context/tiproviders/abuseipdb.py b/msticpy/context/tiproviders/abuseipdb.py index 72445524d..127bcd7b5 100644 --- a/msticpy/context/tiproviders/abuseipdb.py +++ b/msticpy/context/tiproviders/abuseipdb.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. 
""" + from __future__ import annotations from typing import Any, ClassVar diff --git a/msticpy/context/tiproviders/alienvault_otx.py b/msticpy/context/tiproviders/alienvault_otx.py index 7f00a067d..01d3c92b8 100644 --- a/msticpy/context/tiproviders/alienvault_otx.py +++ b/msticpy/context/tiproviders/alienvault_otx.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations from dataclasses import dataclass @@ -121,9 +122,7 @@ def parse_results(self: Self, response: dict) -> tuple[bool, ResultSeverity, Any "sections_available": response["RawResult"]["sections"], }, ) - severity = ( - ResultSeverity.warning if pulse_count == 1 else ResultSeverity.high - ) + severity = ResultSeverity.warning if pulse_count == 1 else ResultSeverity.high return ( True, severity, diff --git a/msticpy/context/tiproviders/azure_sent_byoti.py b/msticpy/context/tiproviders/azure_sent_byoti.py index 5b1ea6d1d..8031e60af 100644 --- a/msticpy/context/tiproviders/azure_sent_byoti.py +++ b/msticpy/context/tiproviders/azure_sent_byoti.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations from typing import Any, ClassVar diff --git a/msticpy/context/tiproviders/binaryedge.py b/msticpy/context/tiproviders/binaryedge.py index a198eff1c..3c7eb4257 100644 --- a/msticpy/context/tiproviders/binaryedge.py +++ b/msticpy/context/tiproviders/binaryedge.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations from typing import Any, ClassVar @@ -63,9 +64,7 @@ def parse_results(self: Self, response: dict) -> tuple[bool, ResultSeverity, Any service_details: dict[str, Any] = {} for result in data_point["results"]: if "service" in result["result"]["data"]: - service_details["Banner"] = result["result"]["data"]["service"][ - "banner" - ] + service_details["Banner"] = result["result"]["data"]["service"]["banner"] if "cert_info" in result["result"]["data"]: service_details["Cert Info"] = result["result"]["data"]["cert_info"] result_dict[data_point["port"]] = service_details diff --git a/msticpy/context/tiproviders/crowdsec.py b/msticpy/context/tiproviders/crowdsec.py index ed39de2db..9351425dd 100644 --- a/msticpy/context/tiproviders/crowdsec.py +++ b/msticpy/context/tiproviders/crowdsec.py @@ -74,10 +74,7 @@ def parse_results(self: Self, response: dict) -> tuple[bool, ResultSeverity, Any ], ), "Behaviors": ",".join( - [ - behavior["name"] - for behavior in response["RawResult"]["behaviors"] - ], + [behavior["name"] for behavior in response["RawResult"]["behaviors"]], ), }, ) diff --git a/msticpy/context/tiproviders/cyberint.py b/msticpy/context/tiproviders/cyberint.py index c14bfa61b..21c001734 100644 --- a/msticpy/context/tiproviders/cyberint.py +++ b/msticpy/context/tiproviders/cyberint.py @@ -10,6 +10,7 @@ multiple observables. Processing requires an API key. https://cyberint.com/ """ + from __future__ import annotations from typing import Any, ClassVar diff --git a/msticpy/context/tiproviders/greynoise.py b/msticpy/context/tiproviders/greynoise.py index 97001a4c8..8f2a96e53 100644 --- a/msticpy/context/tiproviders/greynoise.py +++ b/msticpy/context/tiproviders/greynoise.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. 
""" + from __future__ import annotations from typing import Any, ClassVar diff --git a/msticpy/context/tiproviders/ibm_xforce.py b/msticpy/context/tiproviders/ibm_xforce.py index ba50a88c1..3371dd67b 100644 --- a/msticpy/context/tiproviders/ibm_xforce.py +++ b/msticpy/context/tiproviders/ibm_xforce.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations from dataclasses import dataclass diff --git a/msticpy/context/tiproviders/intsights.py b/msticpy/context/tiproviders/intsights.py index 01dee64fd..46d9cce63 100644 --- a/msticpy/context/tiproviders/intsights.py +++ b/msticpy/context/tiproviders/intsights.py @@ -11,6 +11,7 @@ processing performance may be limited to a specific number of requests per minute for the account type that you have. """ + from __future__ import annotations import datetime as dt @@ -163,7 +164,9 @@ def parse_results(self: Self, response: dict) -> tuple[bool, ResultSeverity, Any else ( ResultSeverity.warning if sev == "Medium" - else ResultSeverity.high if sev == "High" else ResultSeverity.unknown + else ResultSeverity.high + if sev == "High" + else ResultSeverity.unknown ) ) diff --git a/msticpy/context/tiproviders/ip_quality_score.py b/msticpy/context/tiproviders/ip_quality_score.py index bc0388598..9911c7886 100644 --- a/msticpy/context/tiproviders/ip_quality_score.py +++ b/msticpy/context/tiproviders/ip_quality_score.py @@ -9,6 +9,7 @@ This provider offers contextual lookup services and fraud scoring for IP addresses. https://www.ipqualityscore.com/ """ + from __future__ import annotations from typing import Any, ClassVar diff --git a/msticpy/context/tiproviders/kql_base.py b/msticpy/context/tiproviders/kql_base.py index 54c351b7b..a5e39b7c7 100644 --- a/msticpy/context/tiproviders/kql_base.py +++ b/msticpy/context/tiproviders/kql_base.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations import abc @@ -19,8 +20,9 @@ import logging import warnings from collections import defaultdict +from collections.abc import Callable, Iterable from functools import lru_cache -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar import pandas as pd from typing_extensions import Self @@ -36,8 +38,6 @@ if TYPE_CHECKING: import datetime as dt - - from Kqlmagic.results import ResultSet logger: logging.Logger = logging.getLogger(__name__) __version__ = VERSION __author__ = "Ian Hellen" @@ -85,7 +85,7 @@ def __init__( def _connected(self: Self) -> bool: return self._query_provider.connected - @lru_cache(maxsize=256) + @lru_cache(maxsize=256) # noqa: B019 def lookup_ioc( self: Self, ioc: str, @@ -254,32 +254,49 @@ def _add_failure_status( """Add status info, if query produced no results.""" src_ioc_frame["Result"] = False src_ioc_frame["Details"] = ( - "Query failure" - if lookup_status == LookupStatus.QUERY_FAILED - else "Not found" + "Query failure" if lookup_status == LookupStatus.QUERY_FAILED else "Not found" ) src_ioc_frame["Status"] = lookup_status.value src_ioc_frame["Severity"] = ResultSeverity.information.name @staticmethod - def _check_result_status(data_result: pd.DataFrame | ResultSet) -> LookupStatus: - """Check the return value from the query.""" + def _check_result_status(data_result: pd.DataFrame | Any) -> LookupStatus: + """ + Check the return value from the query. 
+ + Parameters + ---------- + data_result : pd.DataFrame | Any + Query result - normally a DataFrame from azure-monitor-query driver, + but can be other types in test mocks or legacy code. + + Returns + ------- + LookupStatus + Status of the query execution. + + """ if isinstance(data_result, pd.DataFrame): return LookupStatus.NO_DATA if data_result.empty else LookupStatus.OK - if ( - hasattr(data_result, "completion_query_info") - and data_result.completion_query_info["StatusCode"] == 0 - and data_result.records_count == 0 + + # Handle legacy/mock objects with completion_query_info + if hasattr(data_result, "completion_query_info") and hasattr( + data_result, "records_count" ): - logger.info("No results return from data provider.") - return LookupStatus.NO_DATA - if data_result and hasattr(data_result, "completion_query_info"): + if ( + data_result.completion_query_info.get("StatusCode") == 0 + and data_result.records_count == 0 + ): + logger.info("No results returned from data provider.") + return LookupStatus.NO_DATA logger.info( - "No results returned from data provider. %s", + "Query failed. Status: %s", data_result.completion_query_info, ) - else: - logger.info("Unknown response from provider: %s", data_result) + return LookupStatus.QUERY_FAILED + + # Unknown result type + logger.warning("Unknown response type from provider: %s", type(data_result)) return LookupStatus.QUERY_FAILED @abc.abstractmethod diff --git a/msticpy/context/tiproviders/mblookup.py b/msticpy/context/tiproviders/mblookup.py index 9b203910b..5f6369f22 100644 --- a/msticpy/context/tiproviders/mblookup.py +++ b/msticpy/context/tiproviders/mblookup.py @@ -5,6 +5,7 @@ # Author: Thomas Roccia - @fr0gger_ # -------------------------------------------------------------------------- """MalwareBazaar TI Provider.""" + from __future__ import annotations from enum import Enum @@ -122,10 +123,8 @@ def lookup_ioc( """ if MBEntityType(mb_type) not in self._SUPPORTED_MB_TYPES: - err_msg: str = ( - f"Property type {mb_type} not supported." - " Valid types are " - ", ".join(x.value for x in MBEntityType.__members__.values()) + err_msg: str = f"Property type {mb_type} not supported. Valid types are , ".join( + x.value for x in MBEntityType.__members__.values() ) raise KeyError(err_msg) diff --git a/msticpy/context/tiproviders/open_page_rank.py b/msticpy/context/tiproviders/open_page_rank.py index ef1e9c0da..8764a1502 100644 --- a/msticpy/context/tiproviders/open_page_rank.py +++ b/msticpy/context/tiproviders/open_page_rank.py @@ -12,10 +12,12 @@ requests per minute for the account type that you have. """ + from __future__ import annotations +from collections.abc import Iterable from json import JSONDecodeError -from typing import Any, ClassVar, Iterable +from typing import Any, ClassVar import httpx import pandas as pd @@ -242,7 +244,8 @@ def _lookup_bulk_request(self: Self, ioc_list: Iterable[str]) -> Iterable[dict]: def _lookup_batch(self: Self, ioc_list: list) -> Iterable[dict]: # build the query string manually - of the form domains[N]=domN&domains[N+1]... 
qry_elements: list[str] = [ - f"domains[{idx}]={dom}" for idx, dom in zip(range(len(ioc_list)), ioc_list) + f"domains[{idx}]={dom}" + for idx, dom in zip(range(len(ioc_list)), ioc_list, strict=False) ] qry_str: str = "&".join(qry_elements) diff --git a/msticpy/context/tiproviders/pulsedive.py b/msticpy/context/tiproviders/pulsedive.py index dab0422f1..68c499eeb 100644 --- a/msticpy/context/tiproviders/pulsedive.py +++ b/msticpy/context/tiproviders/pulsedive.py @@ -5,6 +5,7 @@ # Author: Thomas Roccia - @fr0gger_ # -------------------------------------------------------------------------- """Pulsedive TI Provider.""" + from __future__ import annotations from enum import Enum @@ -30,6 +31,7 @@ __version__ = VERSION __author__ = "Thomas Roccia | @fr0gger_" +# pylint: disable=invalid-name _QUERY_OBJECTS_MAPPINGS: dict[str, dict[str, str]] = { "indicator": {"indicator": "observable"}, "threat": {"threat": "observable"}, @@ -346,9 +348,9 @@ class Pulsedive(HttpTIProvider): _BASE_URL = PDlookup.BASE_URL - _QUERIES: ClassVar[dict[str, APILookupParams]] = { - ioc_type: _QUERY_DEF for ioc_type in ("ipv4", "ipv6", "dns", "hostname", "url") - } + _QUERIES: ClassVar[dict[str, APILookupParams]] = dict.fromkeys( + ("ipv4", "ipv6", "dns", "hostname", "url"), _QUERY_DEF + ) _REQUIRED_PARAMS: ClassVar[list[str]] = ["API_KEY"] _RISK_MAP: ClassVar[dict[str, ResultSeverity]] = { diff --git a/msticpy/context/tiproviders/result_severity.py b/msticpy/context/tiproviders/result_severity.py index 90a1fb201..48d56be2d 100644 --- a/msticpy/context/tiproviders/result_severity.py +++ b/msticpy/context/tiproviders/result_severity.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Result Severity enumeration.""" + from __future__ import annotations from enum import Enum @@ -19,7 +20,7 @@ @total_ordering -class ResultSeverity(Enum): +class ResultSeverity(Enum): # noqa: PLW1641 """Item report severity.""" # pylint: disable=invalid-name @@ -50,9 +51,7 @@ def parse(cls: type[Self], value: object) -> ResultSeverity: return value if isinstance(value, str) and value.lower() in cls.__members__: return cls[value.lower()] - if isinstance(value, int) and value in [ - v.value for v in cls.__members__.values() - ]: + if isinstance(value, int) and value in [v.value for v in cls.__members__.values()]: return cls(value) return ResultSeverity.unknown diff --git a/msticpy/context/tiproviders/riskiq.py b/msticpy/context/tiproviders/riskiq.py index 5fed85bfc..0cf2c647b 100644 --- a/msticpy/context/tiproviders/riskiq.py +++ b/msticpy/context/tiproviders/riskiq.py @@ -12,10 +12,12 @@ requests per minute for the account type that you have. 
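
The dict.fromkeys rewrites, for Pulsedive's _QUERIES above and the Tor node list later in this diff, are behavior-preserving: fromkeys stores the same value object under every key, exactly as the comprehensions they replace did, since those also reused a single _QUERY_DEF or node_dict object. The shared reference is worth keeping in mind whenever the value is mutable:

shared = {"ExitNode": True}
nodes = dict.fromkeys(["1.2.3.4", "5.6.7.8"], shared)
nodes["1.2.3.4"]["LastStatus"] = "2024-01-01"
print(nodes["5.6.7.8"])  # {'ExitNode': True, 'LastStatus': '2024-01-01'} - same object
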
""" + from __future__ import annotations +from collections.abc import Callable from functools import partial -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import pandas as pd from typing_extensions import Self @@ -250,9 +252,7 @@ def _parse_result_all_props( "reputation": pt_result.reputation.as_dict, } ti_result["RawResult"] = ti_result["Details"] - ti_result["Result"] = ( - pt_result.summary.total != 0 or pt_result.reputation.score != 0 - ) + ti_result["Result"] = pt_result.summary.total != 0 or pt_result.reputation.score != 0 rep_severity: ResultSeverity = self._severity_rep( pt_result.reputation.classification, @@ -321,14 +321,8 @@ def _set_pivot_timespan( """ changed = False - start = ( - start or self._pivot_get_timespan().start - if self._pivot_get_timespan - else None - ) - end = ( - end or self._pivot_get_timespan().end if self._pivot_get_timespan else None - ) + start = start or self._pivot_get_timespan().start if self._pivot_get_timespan else None + end = end or self._pivot_get_timespan().end if self._pivot_get_timespan else None if ( start and end diff --git a/msticpy/context/tiproviders/ti_http_provider.py b/msticpy/context/tiproviders/ti_http_provider.py index 2d64db138..bf76b5034 100644 --- a/msticpy/context/tiproviders/ti_http_provider.py +++ b/msticpy/context/tiproviders/ti_http_provider.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations from functools import lru_cache @@ -133,9 +134,7 @@ def _run_ti_lookup_query( result, ) except JSONDecodeError: - result[ - "RawResult" - ] = f"""There was a problem parsing results from this lookup: + result["RawResult"] = f"""There was a problem parsing results from this lookup: {response.text}""" result["Result"] = False severity = ResultSeverity.information @@ -150,7 +149,7 @@ def _run_ti_lookup_query( result["Details"] = self._response_message(result["Status"]) return result - @lru_cache(maxsize=256) + @lru_cache(maxsize=256) # noqa: B019 def lookup_ioc( # noqa: PLR0913 self: Self, ioc: str, diff --git a/msticpy/context/tiproviders/ti_provider_base.py b/msticpy/context/tiproviders/ti_provider_base.py index cbaab75bb..b7d2c629e 100644 --- a/msticpy/context/tiproviders/ti_provider_base.py +++ b/msticpy/context/tiproviders/ti_provider_base.py @@ -12,11 +12,13 @@ requests per minute for the account type that you have. """ + from __future__ import annotations import logging from abc import abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, ClassVar from typing_extensions import Self diff --git a/msticpy/context/tiproviders/tor_exit_nodes.py b/msticpy/context/tiproviders/tor_exit_nodes.py index 5554e7627..55e37ec07 100644 --- a/msticpy/context/tiproviders/tor_exit_nodes.py +++ b/msticpy/context/tiproviders/tor_exit_nodes.py @@ -12,12 +12,14 @@ requests per minute for the account type that you have. 
""" + from __future__ import annotations import contextlib +from collections.abc import Iterable from datetime import datetime, timezone from threading import Lock -from typing import Any, ClassVar, Iterable +from typing import Any, ClassVar import httpx import pandas as pd @@ -75,9 +77,7 @@ def _check_and_get_nodelist(cls: type[Self]) -> None: tor_raw_list = resp.content.decode() with cls._cache_lock: node_dict: dict[str, Any] = {"ExitNode": True, "LastStatus": now} - cls._nodelist = { - node: node_dict for node in tor_raw_list.split("\n") - } + cls._nodelist = dict.fromkeys(tor_raw_list.split("\n"), node_dict) cls._last_cached = datetime.now(timezone.utc) @staticmethod diff --git a/msticpy/context/tiproviders/virustotal.py b/msticpy/context/tiproviders/virustotal.py index c5e95d3d4..4ecb42be6 100644 --- a/msticpy/context/tiproviders/virustotal.py +++ b/msticpy/context/tiproviders/virustotal.py @@ -12,6 +12,7 @@ requests per minute for the account type that you have. """ + from __future__ import annotations import datetime as dt @@ -130,7 +131,7 @@ def parse_results(self: Self, response: dict) -> tuple[bool, ResultSeverity, Any if "positives" in result_dict: positives = result_dict.get("positives", 0) - if not isinstance(positives, (int, float)): + if not isinstance(positives, int | float): positives = 0 elif isinstance(positives, str): # sometimes the API returns a string with a number in it diff --git a/msticpy/context/vtlookupv3/__init__.py b/msticpy/context/vtlookupv3/__init__.py index ecb82d330..3bd063112 100644 --- a/msticpy/context/vtlookupv3/__init__.py +++ b/msticpy/context/vtlookupv3/__init__.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """VirusTotal V3 Subpackage.""" + import contextlib from ..._version import VERSION diff --git a/msticpy/context/vtlookupv3/vtfile_behavior.py b/msticpy/context/vtlookupv3/vtfile_behavior.py index 6d10b1924..dd2afe979 100644 --- a/msticpy/context/vtlookupv3/vtfile_behavior.py +++ b/msticpy/context/vtlookupv3/vtfile_behavior.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """VirusTotal File Behavior functions.""" + from __future__ import annotations import logging @@ -30,9 +31,7 @@ try: import vt except ImportError as imp_err: - ERR_MSG: str = ( - "Cannot use this feature without vt-py and vt-graph-api packages installed." - ) + ERR_MSG: str = "Cannot use this feature without vt-py and vt-graph-api packages installed." 
raise MsticpyImportExtraError( ERR_MSG, title="Error importing VirusTotal modules.", @@ -320,8 +319,7 @@ def _build_process_tree(fb_categories: dict[str, Any]) -> pd.DataFrame: """Top level function to create displayable DataFrame.""" proc_tree_raw: list[dict[str, Any]] = deepcopy(fb_categories["processes_tree"]) procs_created: dict[str, Any] = { - Path(proc).parts[-1].lower(): proc - for proc in fb_categories["processes_created"] + Path(proc).parts[-1].lower(): proc for proc in fb_categories["processes_created"] } si_procs: list[SIProcess] = _extract_processes(proc_tree_raw, procs_created) @@ -371,9 +369,9 @@ def _create_si_proc( """Return an SIProcess Object from a raw VT proc definition.""" name: str = raw_proc["name"] raw_proc["cmd_line"] = name - for proc in procs_created: + for proc, proc_name in procs_created.items(): if name.lower().endswith(proc): - raw_proc["name"] = procs_created[proc] + raw_proc["name"] = proc_name break raw_proc["proc_key"] = raw_proc["process_id"] + "|" + raw_proc["name"] return SIProcess(**raw_proc) @@ -408,7 +406,7 @@ def _try_match_commandlines( and np.isnan(row["cmd_line"]) and row["name"] in cmd ): - procs_cmd.loc[idx, "cmd_line"] = cmd # type: ignore + procs_cmd.loc[idx, "cmd_line"] = cmd break for cmd in command_executions: for idx, row in procs_cmd.iterrows(): @@ -418,7 +416,7 @@ def _try_match_commandlines( and Path(row["name"]).stem.lower() in cmd.lower() ): weak_matches += 1 - procs_cmd.loc[idx, "cmd_line"] = cmd # type: ignore + procs_cmd.loc[idx, "cmd_line"] = cmd break if weak_matches: diff --git a/msticpy/context/vtlookupv3/vtlookup.py b/msticpy/context/vtlookupv3/vtlookup.py index e32cf8730..5d670c56f 100644 --- a/msticpy/context/vtlookupv3/vtlookup.py +++ b/msticpy/context/vtlookupv3/vtlookup.py @@ -19,6 +19,7 @@ - IPv4 Address """ + from __future__ import annotations import contextlib @@ -430,9 +431,7 @@ def _lookup_ioc_type( # 2. Or we have reached the end of our row iteration # AND # 3. 
The batch is not empty - if ( - len(obs_batch) == vt_param.batch_size or row_num == row_count - ) and obs_batch: + if (len(obs_batch) == vt_param.batch_size or row_num == row_count) and obs_batch: obs_submit: str = vt_param.batch_delimiter.join(obs_batch) self._print_status( @@ -505,11 +504,7 @@ def _parse_vt_results( # noqa:PLR0913 with contextlib.suppress(JSONDecodeError, TypeError): vt_results = json.loads(vt_results, strict=False) - if ( - isinstance(vt_results, list) - and vt_param is not None - and vt_param.batch_size > 1 - ): + if isinstance(vt_results, list) and vt_param is not None and vt_param.batch_size > 1: # multiple results results_to_parse = vt_results elif isinstance(vt_results, dict): @@ -562,9 +557,7 @@ def _parse_vt_results( # noqa:PLR0913 ] else: df_dict_vtresults["Observable"] = observables[result_idx] - df_dict_vtresults["SourceIndex"] = source_row_index[ - observables[result_idx] - ] + df_dict_vtresults["SourceIndex"] = source_row_index[observables[result_idx]] new_results: pd.DataFrame = pd.concat( objs=[self.results, df_dict_vtresults], @@ -637,9 +630,7 @@ def _parse_single_result( df_dict_vtresults["ResolvedIPs"] = ", ".join(item_list) if "detected_urls" in results_dict: item_list = [ - item["url"] - for item in results_dict["detected_urls"] - if "url" in item + item["url"] for item in results_dict["detected_urls"] if "url" in item ] df_dict_vtresults["DetectedUrls"] = ", ".join(item_list) # positives are listed per detected_url so we need to @@ -756,9 +747,7 @@ def _check_duplicate_submission( return DuplicateStatus(is_dup=False, status="ok") # Note duplicate var here can be multiple rows of past results - duplicate: pd.DataFrame = self.results[ - self.results["Observable"] == observable - ].copy() + duplicate: pd.DataFrame = self.results[self.results["Observable"] == observable].copy() # if this is a file hash we should check for previous results in # all of the hash columns if duplicate.shape[0] == 0 and ioc_type in [ @@ -766,9 +755,7 @@ def _check_duplicate_submission( "sha1_hash", "sh256_hash", ]: - dup_query = ( - "MD5 == @observable or SHA1 == @observable or SHA256 == @observable" - ) + dup_query = "MD5 == @observable or SHA1 == @observable or SHA256 == @observable" duplicate = self.results.query(dup_query).copy() # In these cases we want to set the observable to the source value # but keep the rest of the results @@ -778,9 +765,7 @@ def _check_duplicate_submission( # if we found a duplicate so add the copies of the duplicated requests # to the results if duplicate.shape[0] > 0: - original_indices: list = [ - v[0] for v in duplicate[["SourceIndex"]].to_numpy() - ] + original_indices: list = [v[0] for v in duplicate[["SourceIndex"]].to_numpy()] duplicate["SourceIndex"] = source_index duplicate["Status"] = "Duplicate" new_results: pd.DataFrame = pd.concat( @@ -826,7 +811,7 @@ def _add_invalid_input_result( new_row["Status"] = status new_row["SourceIndex"] = source_idx new_results: pd.DataFrame = self.results.append( - new_row.to_dict(), # type: ignore[operator] + new_row.to_dict(), ignore_index=True, ) @@ -899,9 +884,7 @@ def _get_vt_api_url(cls: type[Self], api_type: str) -> str: @classmethod def _get_supported_vt_ioc_types(cls: type[VTLookup]) -> list[str]: """Return the subset of IoC types supported by VT.""" - return [ - t for t in cls._SUPPORTED_INPUT_TYPES if cls._VT_TYPE_MAP[t] is not None - ] + return [t for t in cls._SUPPORTED_INPUT_TYPES if cls._VT_TYPE_MAP[t] is not None] def _print_status(self: Self, message: str, verbosity_level: int) -> None: 
""" diff --git a/msticpy/context/vtlookupv3/vtlookupv3.py b/msticpy/context/vtlookupv3/vtlookupv3.py index abdc5c255..1d30ab511 100644 --- a/msticpy/context/vtlookupv3/vtlookupv3.py +++ b/msticpy/context/vtlookupv3/vtlookupv3.py @@ -5,8 +5,9 @@ import asyncio import logging +from collections.abc import Coroutine, Iterable from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Iterable +from typing import TYPE_CHECKING, Any, ClassVar import pandas as pd from IPython.core.display import HTML @@ -88,7 +89,7 @@ class VTObjectProperties(Enum): def _ensure_eventloop(*, force_nest_asyncio: bool = False) -> asyncio.AbstractEventLoop: """Ensure that we have an event loop available.""" try: - event_loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() except RuntimeError: # Generate an event loop if there isn't any. event_loop = asyncio.new_event_loop() @@ -277,10 +278,8 @@ async def _lookup_ioc_async( """ if VTEntityType(vt_type) not in self._SUPPORTED_VT_TYPES: - error_msg: str = ( - f"Property type {vt_type} not supported" - "Valid types are" - ", ".join(x.value for x in VTEntityType.__members__.values()) + error_msg: str = f"Property type {vt_type} not supportedValid types are, ".join( + x.value for x in VTEntityType.__members__.values() ) raise KeyError(error_msg) @@ -379,7 +378,7 @@ async def _lookup_iocs_async( observable_type, all_props=all_props, ) - for observable, observable_type in zip(observables_list, types_list) + for observable, observable_type in zip(observables_list, types_list, strict=False) ] dfs: list[pd.DataFrame] = await asyncio.gather(*dfs_futures) @@ -526,8 +525,7 @@ async def _lookup_ioc_relationships_async( # pylint: disable=too-many-locals #n add_columns = pd.DataFrame( { ColumnNames.SOURCE.value: [observable] * rows, - ColumnNames.SOURCE_TYPE.value: [VTEntityType(vt_type).value] - * rows, + ColumnNames.SOURCE_TYPE.value: [VTEntityType(vt_type).value] * rows, ColumnNames.RELATIONSHIP_TYPE.value: [relationship] * rows, }, ) @@ -710,7 +708,7 @@ async def _lookup_iocs_relationships_async( # noqa: PLR0913 limit=limit, all_props=all_props, ) - for observable, observable_type in zip(observables_list, types_list) + for observable, observable_type in zip(observables_list, types_list, strict=False) ] dfs: list[pd.DataFrame] = await asyncio.gather(*dfs_futures) @@ -868,9 +866,9 @@ def get_object(self: Self, vt_id: str, vt_type: str) -> pd.DataFrame: # pylint: disable=no-member error_msg: str = ( - f"Property type {vt_type} not supported. " - "Valid types are: " - ", ".join(x.value for x in VTEntityType.__members__.values()) + f"Property type {vt_type} not supported. 
Valid types are: , ".join( + x.value for x in VTEntityType.__members__.values() + ) ) raise KeyError(error_msg) @@ -962,9 +960,7 @@ def search( params={"query": query}, limit=limit, ) - response_list: list[dict[str, Any]] = [ - item.to_dict() for item in response_itr - ] + response_list: list[dict[str, Any]] = [item.to_dict() for item in response_itr] except vt.APIError as api_err: error_msg: str = ( f"The provided query returned 0 results because of an APIError: {api_err}" @@ -1051,9 +1047,7 @@ def _extract_response(self: Self, response_list: list) -> pd.DataFrame: response_rows = [] for response_item in response_list: # flatten nested dictionary and append id, type values - response_item_df = pd.json_normalize( - response_item["attributes"], max_level=0 - ) + response_item_df = pd.json_normalize(response_item["attributes"], max_level=0) response_item_df["id"] = response_item["id"] response_item_df["type"] = response_item["type"] @@ -1099,9 +1093,7 @@ def relationships_to_graph( # Create nodes DF, with source and target sources_df = ( - concatenated_df.groupby(ColumnNames.SOURCE.value)[ - ColumnNames.SOURCE_TYPE.value - ] + concatenated_df.groupby(ColumnNames.SOURCE.value)[ColumnNames.SOURCE_TYPE.value] .first() .reset_index() .rename( @@ -1113,9 +1105,7 @@ def relationships_to_graph( ) target_df = ( - concatenated_df.groupby(ColumnNames.TARGET.value)[ - ColumnNames.TARGET_TYPE.value - ] + concatenated_df.groupby(ColumnNames.TARGET.value)[ColumnNames.TARGET_TYPE.value] .first() .reset_index() .rename( @@ -1192,7 +1182,7 @@ def _item_not_found_df( not_found_dict["status"] = "Unsupported type" else: not_found_dict.update( - {key: "Not found" for key in cls._BASIC_PROPERTIES_PER_TYPE[vte_type]}, + dict.fromkeys(cls._BASIC_PROPERTIES_PER_TYPE[vte_type], "Not found"), ) return pd.DataFrame([not_found_dict]) @@ -1225,9 +1215,7 @@ def _get_vt_api_key() -> str | None: def timestamps_to_utcdate(data: pd.DataFrame) -> pd.DataFrame: """Replace Unix timestamps in VT data with Py/pandas Timestamp.""" columns: pd.Index = data.columns - for date_col in ( - col for col in columns if isinstance(col, str) and col.endswith("_date") - ): + for date_col in (col for col in columns if isinstance(col, str) and col.endswith("_date")): data = ( data.assign(pd_data=pd.to_datetime(data[date_col], unit="s", utc=True)) .drop(columns=date_col) diff --git a/msticpy/data/__init__.py b/msticpy/data/__init__.py index 44ef6502e..1ca32baa8 100644 --- a/msticpy/data/__init__.py +++ b/msticpy/data/__init__.py @@ -20,6 +20,7 @@ - uploaders - loaders for some data services. """ + from .._version import VERSION # from ..common.exceptions import MsticpyImportExtraError diff --git a/msticpy/data/azure/__init__.py b/msticpy/data/azure/__init__.py deleted file mode 100644 index bc0bc1240..000000000 --- a/msticpy/data/azure/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module data.azure has moved. 
- -See :py:mod:`msticpy.context.azure` -""" -import warnings - -from ..._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - -# pylint: disable=unused-import -from ...context.azure.azure_data import AzureData # noqa: F401 -from ...context.azure.sentinel_core import MicrosoftSentinel # noqa: F401 - -WARN_MSSG = ( - "This module has moved to msticpy.context.azure\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/azure/azure_blob_storage.py b/msticpy/data/azure/azure_blob_storage.py deleted file mode 100644 index 4d3b11f4b..000000000 --- a/msticpy/data/azure/azure_blob_storage.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module azure_blob_storage.py has moved. - -See :py:mod:`msticpy.data.azure.azure_blob_storage` -""" -import warnings - -from ..._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..storage.azure_blob_storage import * - -WARN_MSSG = ( - "This module has moved to msticpy.context.azure.azure_blob_storage\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/azure/azure_data.py b/msticpy/data/azure/azure_data.py deleted file mode 100644 index 23744fb7d..000000000 --- a/msticpy/data/azure/azure_data.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module azure_data.py has moved. - -See :py:mod:`msticpy.context.azure.azure_data` -""" -import warnings - -from ..._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ...context.azure.azure_data import * - -WARN_MSSG = ( - "This module has moved to msticpy.context.azure.azure_data\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/azure_blob_storage.py b/msticpy/data/azure_blob_storage.py deleted file mode 100644 index 55d9501d4..000000000 --- a/msticpy/data/azure_blob_storage.py +++ /dev/null @@ -1,22 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module azure_blob_storage.py has moved. 
- -See :py:mod:`msticpy.data.storage.azure_blob_storage` -""" -import warnings - -# flake8: noqa: F403, F401 -# pylint: disable=unused-import -from ..data.storage.azure_blob_storage import AzureBlobStorage - -WARN_MSSG = ( - "This module has moved to msticpy.data.storage.azure_blob_storage\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/azure_data.py b/msticpy/data/azure_data.py deleted file mode 100644 index f84381ece..000000000 --- a/msticpy/data/azure_data.py +++ /dev/null @@ -1,22 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module azure_data.py has moved. - -See :py:mod:`msticpy.context.azure.azure_data` -""" -import warnings - -# flake8: noqa: F403, F401 -# pylint: disable=unused-import -from ..context.azure.azure_data import AzureData - -WARN_MSSG = ( - "This module has moved to msticpy.context.azure.azure_data\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/azure_sentinel.py b/msticpy/data/azure_sentinel.py deleted file mode 100644 index 246e489d7..000000000 --- a/msticpy/data/azure_sentinel.py +++ /dev/null @@ -1,22 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module azure_sentinel.py has moved. - -See :py:mod:`msticpy.context.azure.azure_sentinel` -""" -import warnings - -# flake8: noqa: F403, F401 -# pylint: disable=unused-import -from ..context.azure.sentinel_core import MicrosoftSentinel as AzureSentinel - -WARN_MSSG = ( - "This module has moved to msticpy.context.azure.sentinel_core\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/core/data_providers.py b/msticpy/data/core/data_providers.py index d279da5d5..87ad7f6fa 100644 --- a/msticpy/data/core/data_providers.py +++ b/msticpy/data/core/data_providers.py @@ -4,12 +4,14 @@ # license information. 
# -------------------------------------------------------------------------- """Data provider loader.""" + from __future__ import annotations import logging +from collections.abc import Iterable from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any from typing_extensions import Self @@ -96,10 +98,12 @@ def __init__( """ # import at runtime to prevent circular import # pylint: disable=import-outside-toplevel, cyclic-import - from ...init.pivot_init.pivot_data_queries import add_data_queries_to_entities + from ...init.pivot_init.pivot_data_queries import ( # noqa: PLC0415 + add_data_queries_to_entities, + ) # pylint: enable=import-outside-toplevel - setattr(self.__class__, "_add_pivots", add_data_queries_to_entities) + self.__class__._add_pivots = add_data_queries_to_entities # type: ignore[attr-defined] data_environment, self.environment_name = QueryProvider._check_environment( data_environment, @@ -107,22 +111,17 @@ def __init__( self._driver_kwargs: dict[str, Any] = kwargs.copy() if driver is None: - self.driver_class: type[DriverBase] = drivers.import_driver( - data_environment - ) + self.driver_class: type[DriverBase] = drivers.import_driver(data_environment) if issubclass(self.driver_class, DriverBase): driver = self.driver_class(data_environment=data_environment, **kwargs) else: - err_msg: str = ( - f"Could not find suitable data provider for {data_environment}" - ) + err_msg: str = f"Could not find suitable data provider for {data_environment}" raise LookupError(err_msg) else: self.driver_class = driver.__class__ # allow the driver to override the data environment used for selecting queries self.environment_name = ( - driver.get_driver_property(DriverProps.EFFECTIVE_ENV) - or self.environment_name + driver.get_driver_property(DriverProps.EFFECTIVE_ENV) or self.environment_name ) logger.info("Using data environment %s", self.environment_name) logger.info("Driver class: %s", self.driver_class.__name__) @@ -170,7 +169,9 @@ def _check_environment( elif isinstance(data_environment, DataEnvironment): environment_name = data_environment.name else: - err_msg = f"Unknown data environment type {data_environment} ({type(data_environment)})" + err_msg = ( + f"Unknown data environment type {data_environment} ({type(data_environment)})" + ) raise TypeError(err_msg) return data_environment, environment_name @@ -216,9 +217,7 @@ def connect(self: Self, connection_str: str | None = None, **kwargs) -> None: # Add any built-in or dynamically retrieved queries from driver if self._query_provider.has_driver_queries: logger.info("Adding driver queries to provider") - driver_queries: Iterable[dict[str, Any]] = ( - self._query_provider.driver_queries - ) + driver_queries: Iterable[dict[str, Any]] = self._query_provider.driver_queries self._add_driver_queries(queries=driver_queries) refresh_query_funcs = True @@ -432,9 +431,7 @@ def _get_query_options( query_options: dict[str, Any] = kwargs.pop("query_options", {}) if not query_options: # Any kwargs left over we send to the query provider driver - query_options = { - key: val for key, val in kwargs.items() if key not in params - } + query_options = {key: val for key, val in kwargs.items() if key not in params} query_options["time_span"] = { "start": params.get("start"), "end": params.get("end"), diff --git a/msticpy/data/core/data_query_reader.py b/msticpy/data/core/data_query_reader.py index ce84158de..96448b834 100644 --- a/msticpy/data/core/data_query_reader.py +++ 
b/msticpy/data/core/data_query_reader.py @@ -4,10 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Data query definition reader.""" + import logging +from collections.abc import Iterable from itertools import chain from pathlib import Path -from typing import Any, Dict, Iterable, Tuple +from typing import Any import yaml @@ -48,7 +50,7 @@ def find_yaml_files(source_path: str, recursive: bool = True) -> Iterable[Path]: yield file_path -def read_query_def_file(query_file: str) -> Tuple[Dict, Dict, Dict]: +def read_query_def_file(query_file: str) -> tuple[dict, dict, dict]: """ Read a yaml data query definition file. @@ -67,7 +69,7 @@ def read_query_def_file(query_file: str) -> Tuple[Dict, Dict, Dict]: """ data_map = None - with open(query_file, "r", encoding="utf-8") as f_handle: + with open(query_file, encoding="utf-8") as f_handle: # use safe_load instead load data_map = yaml.safe_load(f_handle) @@ -90,7 +92,7 @@ def read_query_def_file(query_file: str) -> Tuple[Dict, Dict, Dict]: return sources, defaults, metadata -def validate_query_defs(query_def_dict: Dict[str, Any]) -> bool: +def validate_query_defs(query_def_dict: dict[str, Any]) -> bool: """ Validate content of query definition. @@ -126,7 +128,7 @@ def validate_query_defs(query_def_dict: Dict[str, Any]) -> bool: return True -def _validate_data_categories(query_def_dict: Dict): +def _validate_data_categories(query_def_dict: dict): if ( "data_environments" not in query_def_dict["metadata"] or not query_def_dict["metadata"]["data_environments"] diff --git a/msticpy/data/core/param_extractor.py b/msticpy/data/core/param_extractor.py index 5bf37d2b8..ac27987d9 100644 --- a/msticpy/data/core/param_extractor.py +++ b/msticpy/data/core/param_extractor.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Parameter extractor helper functions for use with IPython/Juptyer queries.""" -from typing import Any, Dict, List, Mapping, Tuple + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -18,7 +20,7 @@ @export def extract_query_params( query_source: QuerySource, *args, **kwargs -) -> Tuple[Dict[str, Any], List[str]]: +) -> tuple[dict[str, Any], list[str]]: """ Get the parameters needed for the query. @@ -50,7 +52,7 @@ def extract_query_params( # at least the required params plus any that are extracted from args and # kwargs and have been added dynamically. req_param_names = query_source.required_params.keys() - req_params: Dict[str, Any] = {param: None for param in req_param_names} + req_params: dict[str, Any] = dict.fromkeys(req_param_names) # try to retrieve any parameters as attributes of the args objects _get_object_params(args, all_params, req_params) @@ -65,14 +67,12 @@ def extract_query_params( # Get the names of any params that were required but we didn't # find a value for - missing_params = [ - p_name for p_name, p_value in req_params.items() if p_value is None - ] + missing_params = [p_name for p_name, p_value in req_params.items() if p_value is None] return req_params, missing_params def _get_object_params( - args: Tuple[Any, ...], params: Mapping[str, Any], req_params: Dict[str, Any] + args: tuple[Any, ...], params: Mapping[str, Any], req_params: dict[str, Any] ): """ Get params from attributes of arg objects. 
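The `dict.fromkeys(req_param_names)` rewrite above is behavior-preserving: `fromkeys` defaults every value to `None`, exactly like the removed comprehension. A minimal sketch of the equivalence (parameter names are illustrative, not from the module):

```python
req_param_names = ["start", "end", "host_name"]  # hypothetical query params

old_style = {param: None for param in req_param_names}
new_style = dict.fromkeys(req_param_names)  # value defaults to None
assert old_style == new_style

# Caveat: with a mutable default, every key shares one object, so
# dict.fromkeys is only a drop-in replacement for None/immutable values.
shared = dict.fromkeys(["a", "b"], [])
shared["a"].append(1)
assert shared["b"] == [1]  # both keys reference the same list
```

The same rewrite appears elsewhere in this diff (`tor_exit_nodes.py`, `_item_not_found_df` in `vtlookupv3.py`); in the Tor case the shared-value caveat changes nothing, because the replaced comprehension already reused a single `node_dict`.
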
@@ -89,7 +89,7 @@ def _get_object_params( """ remaining_params = list(params.keys()) for arg_object in args: - if isinstance(arg_object, (str, int, float, bool)): + if isinstance(arg_object, str | int | float | bool): # ignore some common primitive types continue for param in remaining_params: diff --git a/msticpy/data/core/query_container.py b/msticpy/data/core/query_container.py index a988edde5..d812e185e 100644 --- a/msticpy/data/core/query_container.py +++ b/msticpy/data/core/query_container.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Query hierarchy attribute class.""" + from functools import partial from ..._version import VERSION diff --git a/msticpy/data/core/query_defns.py b/msticpy/data/core/query_defns.py index 4d8a7dd58..46a4314eb 100644 --- a/msticpy/data/core/query_defns.py +++ b/msticpy/data/core/query_defns.py @@ -4,9 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Query helper definitions.""" + from abc import ABC, abstractmethod from enum import Enum -from typing import Union from ..._version import VERSION from ...common.utility import export @@ -41,7 +41,7 @@ class DataFamily(Enum): Prismacloud = 21 @classmethod - def parse(cls, value: Union[str, int]) -> "DataFamily": + def parse(cls, value: str | int) -> "DataFamily": """ Convert string or int to enum. @@ -123,7 +123,7 @@ class DataEnvironment(Enum): MSSentinelSearch = 25 @classmethod - def parse(cls, value: Union[str, int]) -> "DataEnvironment": + def parse(cls, value: str | int) -> "DataEnvironment": """ Convert string or int to enum. diff --git a/msticpy/data/core/query_provider_connections_mixin.py b/msticpy/data/core/query_provider_connections_mixin.py index bfdf3a985..ae74d8499 100644 --- a/msticpy/data/core/query_provider_connections_mixin.py +++ b/msticpy/data/core/query_provider_connections_mixin.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Query Provider additional connection methods.""" + from __future__ import annotations import asyncio @@ -193,8 +194,7 @@ def _exec_additional_connections( """ # Add the initial connection query_tasks: dict[str, partial[pd.DataFrame | str | None]] = { - self._query_provider.current_connection - or "0": partial( + self._query_provider.current_connection or "0": partial( self._query_provider.query, query, **kwargs, @@ -287,19 +287,16 @@ def _exec_split_query( logger.warning("Cannot split a query with no 'start' and 'end' parameters") return None - split_queries: dict[tuple[datetime, datetime], str] = ( - self._create_split_queries( - query_source=query_source, - query_params=query_params, - start=start, - end=end, - split_by=split_by, - ) + split_queries: dict[tuple[datetime, datetime], str] = self._create_split_queries( + query_source=query_source, + query_params=query_params, + start=start, + end=end, + split_by=split_by, ) if debug: return "\n\n".join( - f"{start}-{end}\n{query}" - for (start, end), query in split_queries.items() + f"{start}-{end}\n{query}" for (start, end), query in split_queries.items() ) query_tasks: dict[str, partial[pd.DataFrame | str | None]] = ( @@ -444,7 +441,7 @@ async def _exec_queries_threaded( ) else: task_iter = asyncio.as_completed(thread_tasks.values()) - ids_and_tasks = dict(zip(thread_tasks, task_iter)) + ids_and_tasks = dict(zip(thread_tasks, task_iter, strict=False)) for query_id, thread_task in ids_and_tasks.items(): try: result: pd.DataFrame | str | None = await thread_task @@ -461,7 +458,7 @@ async def _exec_queries_threaded( failed_tasks_ids.append(query_id) # Sort the results by the order of the tasks - results = [result for _, result in sorted(zip(thread_tasks, results))] + results = [result for _, result in sorted(zip(thread_tasks, results, strict=False))] if retry and failed_tasks_ids: failed_results: pd.DataFrame = ( @@ -512,7 +509,7 @@ def _calc_split_ranges( # get duplicates in these cases ranges: list[tuple[datetime, datetime]] = [ (s_time, e_time - pd.Timedelta("1ns")) - for s_time, e_time in zip(s_ranges, e_ranges) + for s_time, e_time in zip(s_ranges, e_ranges, strict=False) ] # Since the generated time ranges are based on deltas from 'start' diff --git a/msticpy/data/core/query_provider_utils_mixin.py b/msticpy/data/core/query_provider_utils_mixin.py index ea1d2b4b2..bba53241f 100644 --- a/msticpy/data/core/query_provider_utils_mixin.py +++ b/msticpy/data/core/query_provider_utils_mixin.py @@ -4,9 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Query Provider mixin methods.""" + import re from collections import abc -from typing import Dict, Iterable, List, NamedTuple, Optional, Pattern, Protocol, Union +from collections.abc import Iterable +from re import Pattern +from typing import NamedTuple, Protocol from ..._version import VERSION from ...common.utility.package import delayed_import @@ -43,8 +46,8 @@ class QueryParam(NamedTuple): name: str data_type: str - description: Optional[str] = None - default: Optional[str] = None + description: str | None = None + default: str | None = None # pylint: disable=super-init-not-called @@ -80,7 +83,7 @@ def connection_string(self) -> str: return self._query_provider.current_connection @property - def schema(self) -> Dict[str, Dict]: + def schema(self) -> dict[str, dict]: """ Return current data schema of connection. 
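The typing changes in this file follow the pattern that ruff's pyupgrade (`UP`) rules apply once the floor is Python 3.10: `typing.Optional`/`typing.Dict` give way to PEP 604 unions and PEP 585 builtin generics. A hedged before/after sketch (the function names are invented):

```python
from typing import Dict, Optional  # pre-3.10 spellings


def instance_old(name: Optional[str] = None) -> Dict[str, str]:
    return {"instance": name or "default"}


def instance_new(name: str | None = None) -> dict[str, str]:  # PEP 604 / PEP 585
    return {"instance": name or "default"}


assert instance_old() == instance_new()
```

On 3.10+ both spellings are valid at runtime, so these signatures need no `from __future__ import annotations` guard.
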
@@ -93,7 +96,7 @@ def schema(self) -> Dict[str, Dict]: return self._query_provider.schema @property - def schema_tables(self) -> List[str]: + def schema_tables(self) -> list[str]: """ Return list of tables in the data schema of the connection. @@ -106,7 +109,7 @@ def schema_tables(self) -> List[str]: return list(self._query_provider.schema.keys()) @property - def instance(self) -> Optional[str]: + def instance(self) -> str | None: """ Return instance name, if any for provider. @@ -137,7 +140,7 @@ def driver_help(self): print(self._query_provider.__doc__) @classmethod - def list_data_environments(cls) -> List[str]: + def list_data_environments(cls) -> list[str]: """ Return list of current data environments. @@ -149,13 +152,11 @@ def list_data_environments(cls) -> List[str]: """ # pylint: disable=not-an-iterable return [ - de - for de in dir(DataEnvironment) - if de != "Unknown" and not de.startswith("_") + de for de in dir(DataEnvironment) if de != "Unknown" and not de.startswith("_") ] # pylint: enable=not-an-iterable - def list_queries(self, substring: Optional[str] = None) -> List[str]: + def list_queries(self, substring: str | None = None) -> list[str]: """ Return list of family.query in the store. @@ -174,7 +175,7 @@ def list_queries(self, substring: Optional[str] = None) -> List[str]: if substring: return list( filter( - lambda x: substring in x.lower(), # type: ignore + lambda x: substring in x.lower(), self.query_store.query_names, ) ) @@ -182,11 +183,11 @@ def list_queries(self, substring: Optional[str] = None) -> List[str]: def search( self, - search: Union[str, Iterable[str]] = None, - table: Union[str, Iterable[str]] = None, - param: Union[str, Iterable[str]] = None, + search: str | Iterable[str] | None = None, + table: str | Iterable[str] | None = None, + param: str | Iterable[str] | None = None, ignore_case: bool = True, - ) -> List[str]: + ) -> list[str]: """ Search queries for match properties. @@ -227,7 +228,7 @@ def search( glob_searches = _normalize_to_regex(search, ignore_case) table_searches = _normalize_to_regex(table, ignore_case) param_searches = _normalize_to_regex(param, ignore_case) - search_hits: List[str] = [] + search_hits: list[str] = [] for query, search_data in self.query_store.search_items.items(): glob_match = (not glob_searches) or any( re.search(term, prop) @@ -292,9 +293,9 @@ def add_custom_query( self, name: str, query: str, - family: Union[str, Iterable[str]], - description: Optional[str] = None, - parameters: Optional[Iterable[QueryParam]] = None, + family: str | Iterable[str], + description: str | None = None, + parameters: Iterable[QueryParam] | None = None, ): """ Add a custom function to the provider. 
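`Pattern` is now imported from `re` rather than `typing` near the top of this file's diff, because `typing.Pattern` was deprecated from Python 3.8 and removed in 3.12. A short sketch of the compile-and-search helper style that `search()` relies on (the helper below is illustrative, not the module's `_normalize_to_regex`):

```python
import re
from re import Pattern


def to_patterns(terms: str | list[str], ignore_case: bool = True) -> list[Pattern[str]]:
    """Compile one or more search terms into regexes."""
    if isinstance(terms, str):
        terms = [terms]
    flags = re.IGNORECASE if ignore_case else 0
    return [re.compile(term, flags) for term in terms]


patterns = to_patterns(["logon", "alert"])
assert [p.pattern for p in patterns if p.search("Sign-in ALERT")] == ["alert"]
```
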
@@ -358,16 +359,14 @@ def add_custom_query( "parameters": param_dict, } metadata = {"data_families": [family] if isinstance(family, str) else family} - query_source = QuerySource( - name=name, source=source, defaults={}, metadata=metadata - ) + query_source = QuerySource(name=name, source=source, defaults={}, metadata=metadata) self.query_store.add_data_source(query_source) self._add_query_functions() def _normalize_to_regex( - search_term: Union[str, Iterable[str], None], ignore_case: bool -) -> List[Pattern[str]]: + search_term: str | Iterable[str] | None, ignore_case: bool +) -> list[Pattern[str]]: """Return iterable or str search term as list of compiled reg expressions.""" if not search_term: return [] diff --git a/msticpy/data/core/query_source.py b/msticpy/data/core/query_source.py index 97a01278c..2836e8eaf 100644 --- a/msticpy/data/core/query_source.py +++ b/msticpy/data/core/query_source.py @@ -4,24 +4,29 @@ # license information. # -------------------------------------------------------------------------- """Intake kql driver.""" + from __future__ import annotations import json import re +from collections.abc import Callable # from collections import ChainMap from datetime import datetime, timedelta, timezone from json.decoder import JSONDecodeError from numbers import Number -from typing import Any, Callable +from typing import TYPE_CHECKING, Any -from dateutil.parser import ParserError, parse # type: ignore +from dateutil.parser import ParserError, parse from dateutil.relativedelta import relativedelta from ..._version import VERSION from ...common.utility import collapse_dicts from .query_defns import Formatters +if TYPE_CHECKING: + from .query_store import QueryStore + __version__ = VERSION __author__ = "Ian Hellen" @@ -94,7 +99,7 @@ def __init__( self._source: dict[str, Any] = source or {} self.defaults: dict[str, Any] = defaults or {} self._global_metadata: dict[str, Any] = dict(metadata) if metadata else {} - self.query_store: "QueryStore" | None = None # type: ignore # noqa: F821 + self.query_store: QueryStore | None = None # consolidate source metadata - source-specific # overrides global @@ -175,9 +180,7 @@ def default_params(self) -> dict[str, dict]: """ return { - p_key: p_props - for p_key, p_props in self.params.items() - if "default" in p_props + p_key: p_props for p_key, p_props in self.params.items() if "default" in p_props } @property @@ -241,14 +244,10 @@ def create_query(self, formatters: dict[str, Callable] = None, **kwargs) -> str: parameter defaults (see `default_params` property). 
""" - param_dict = { - name: value.get("default", None) for name, value in self.params.items() - } + param_dict = {name: value.get("default", None) for name, value in self.params.items()} param_dict.update(self.resolve_param_aliases(kwargs)) - missing_params = { - name: value for name, value in param_dict.items() if value is None - } + missing_params = {name: value for name, value in param_dict.items() if value is None} if missing_params: raise ValueError( "These required parameters were not set: ", f"{missing_params.keys()}" @@ -280,9 +279,7 @@ def _format_parameter(self, p_name, param_dict, param_settings, formatters): if fmt_template: # custom formatting template in the query definition param_dict[p_name] = fmt_template.format(param_dict[p_name]) - elif param_settings["type"] == "datetime" and isinstance( - param_dict[p_name], datetime - ): + elif param_settings["type"] == "datetime" and isinstance(param_dict[p_name], datetime): # format datetime using driver formatter or default formatter if formatters and Formatters.DATETIME in formatters: param_dict[p_name] = formatters[Formatters.DATETIME](param_dict[p_name]) @@ -301,7 +298,7 @@ def _convert_datetime(self, param_value: Any) -> datetime: if isinstance(param_value, Number): # datetime specified as a number - we # interpret this as an offset from utcnow - return datetime.now(tz=timezone.utc) + timedelta( # type: ignore + return datetime.now(tz=timezone.utc) + timedelta( param_value # type: ignore ) try: @@ -335,16 +332,10 @@ def resolve_param_aliases(self, param_dict: dict[str, Any]) -> dict[str, Any]: def _get_aliased_param(self, alias: str) -> str | None: """Return first parameter with a matching alias.""" aliased_params = { - p_name: p_prop - for p_name, p_prop in self.params.items() - if "aliases" in p_prop + p_name: p_prop for p_name, p_prop in self.params.items() if "aliases" in p_prop } return next( - ( - param - for param, props in aliased_params.items() - if alias in props["aliases"] - ), + (param for param, props in aliased_params.items() if alias in props["aliases"]), None, ) @@ -379,9 +370,7 @@ def _calc_timeoffset(cls, time_offset: str) -> datetime: # unit was specified units = RD_UNIT_MAP.get(round_item or "d", "days") # expand dict to args for relativedelta - result_date = result_date + relativedelta( - **({units: +1}) # type: ignore - ) + result_date = result_date + relativedelta(**({units: +1})) # type: ignore[arg-type] return result_date @staticmethod @@ -393,14 +382,12 @@ def _parse_timedelta(time_range: str = "0") -> timedelta: if not m_time or "value" not in m_time.groupdict(): return timedelta(0) tm_val = int(m_time.groupdict()["sign"] + m_time.groupdict()["value"]) - tm_unit = ( - m_time.groupdict()["unit"].lower() if m_time.groupdict()["unit"] else "d" - ) + tm_unit = m_time.groupdict()["unit"].lower() if m_time.groupdict()["unit"] else "d" # Use relative delta to build the timedelta based on the units # in the time range expression unit_param = RD_UNIT_MAP.get(tm_unit, "days") # expand dict to args for relativedelta - return relativedelta(**({unit_param: tm_val})) # type: ignore + return relativedelta(**({unit_param: tm_val})) # type: ignore[arg-type,return-value] @staticmethod def _parse_param_list(param_value: str | list) -> list[Any]: @@ -460,8 +447,8 @@ def create_doc_string(self) -> str: def_value = None param_block.extend( ( - f'{p_name}: {p_props.get("type", "Any")}{optional}', - f' {p_props.get("description", "no description")}', + f"{p_name}: {p_props.get('type', 'Any')}{optional}", + f" 
{p_props.get('description', 'no description')}", ) ) @@ -501,10 +488,7 @@ def validate(self) -> tuple[bool, list[str]]: ) valid_failures.append(msg) if not self._query: - msg = ( - f'Source {self.name} does not have "query" property ' - + "in args element." - ) + msg = f'Source {self.name} does not have "query" property ' + "in args element." valid_failures.append(msg) # Now get the query and the parameter definitions from the source and diff --git a/msticpy/data/core/query_store.py b/msticpy/data/core/query_store.py index b5ff6e834..f3d4cb7e5 100644 --- a/msticpy/data/core/query_store.py +++ b/msticpy/data/core/query_store.py @@ -4,13 +4,15 @@ # license information. # -------------------------------------------------------------------------- """QueryStore class - holds a collection of QuerySources.""" + from __future__ import annotations import logging from collections import defaultdict +from collections.abc import Callable, Iterable from functools import cached_property from os import path -from typing import Any, Callable, Iterable +from typing import Any from typing_extensions import Self @@ -201,9 +203,7 @@ def add_query( src_dict = {"args": {"query": query}, "description": description or name} md_dict = {"data_families": query_paths} - query_source = QuerySource( - name=name, source=src_dict, defaults={}, metadata=md_dict - ) + query_source = QuerySource(name=name, source=src_dict, defaults={}, metadata=md_dict) self.add_data_source(query_source) def import_file(self: Self, query_file: str) -> None: @@ -225,18 +225,14 @@ def import_file(self: Self, query_file: str) -> None: try: sources, defaults, metadata = read_query_def_file(query_file) except ValueError: - logger.warning( - "%sis not a valid query definition file - skipping.", query_file - ) + logger.warning("%sis not a valid query definition file - skipping.", query_file) return for source_name, source in sources.items(): new_source = QuerySource(source_name, source, defaults, metadata) self.add_data_source(new_source) - def apply_query_filter( - self: Self, query_filter: Callable[[QuerySource], bool] - ) -> None: + def apply_query_filter(self: Self, query_filter: Callable[[QuerySource], bool]) -> None: """ Apply a filter to the query sources. @@ -251,13 +247,13 @@ def apply_query_filter( source.show = query_filter(source) # pylint: disable=too-many-locals - @classmethod # noqa: MC0001 - def import_files( # noqa: MC0001 + @classmethod + def import_files( cls, source_path: list, recursive: bool = True, driver_query_filter: dict[str, set[str]] | None = None, - ) -> dict[str, "QueryStore"]: + ) -> dict[str, QueryStore]: """ Import multiple query definition files from directory path. @@ -294,9 +290,7 @@ def import_files( # noqa: MC0001 try: sources, defaults, metadata = read_query_def_file(str(file_path)) except ValueError: - print( - f"{file_path} is not a valid query definition file - skipping." 
- ) + print(f"{file_path} is not a valid query definition file - skipping.") continue for env_value in metadata.get("data_environments", []): @@ -312,9 +306,7 @@ def import_files( # noqa: MC0001 if environment_name not in env_stores: env_stores[environment_name] = cls(environment=environment_name) for source_name, source in sources.items(): - new_source = QuerySource( - source_name, source, defaults, metadata - ) + new_source = QuerySource(source_name, source, defaults, metadata) if not driver_query_filter or ( driver_query_filter and _matches_driver_filter(new_source, driver_query_filter) @@ -326,7 +318,7 @@ def get_query( self: Self, query_name: str, query_path: str | DataFamily | None = None, - ) -> "QuerySource": + ) -> QuerySource: """ Return query with name `data_family` and `query_name`. @@ -352,9 +344,7 @@ def get_query( if query_container in self.data_families: query_path = query_container elif query_path: - query_container = ".".join( - [query_path, query_container] # type: ignore - ) + query_container = ".".join([query_path, query_container]) if query_container in self.data_families: query_path = query_container query = self.data_families.get(query_path, {}).get(query_name) # type: ignore diff --git a/msticpy/data/core/query_template.py b/msticpy/data/core/query_template.py index ba6ef327f..609298723 100644 --- a/msticpy/data/core/query_template.py +++ b/msticpy/data/core/query_template.py @@ -4,8 +4,9 @@ # license information. # -------------------------------------------------------------------------- """MSTICPy query template definition.""" + from dataclasses import field -from typing import Any, Dict, List, Optional, Union +from typing import Any from pydantic.dataclasses import dataclass @@ -22,16 +23,16 @@ class QueryMetadata: version: int description: str - data_environments: List[str] - data_families: List[str] - database: Optional[str] = None - cluster: Optional[str] = None - clusters: Optional[List[str]] = None - cluster_groups: Optional[List[str]] = None - tags: List[str] = field(default_factory=list) - data_source: Optional[str] = None - aliases: Optional[Union[str, List[str]]] = None - query_macros: Optional[Dict[str, Any]] = None + data_environments: list[str] + data_families: list[str] + database: str | None = None + cluster: str | None = None + clusters: list[str] | None = None + cluster_groups: list[str] | None = None + tags: list[str] = field(default_factory=list) + data_source: str | None = None + aliases: str | list[str] | None = None + query_macros: dict[str, Any] | None = None @dataclass @@ -41,15 +42,15 @@ class QueryParameter: description: str datatype: str default: Any = None - aliases: Optional[List[str]] = None + aliases: list[str] | None = None @dataclass class QueryDefaults: """Default values for query definitions.""" - metadata: Optional[Dict[str, Any]] = None - parameters: Dict[str, QueryParameter] = field(default_factory=dict) + metadata: dict[str, Any] | None = None + parameters: dict[str, QueryParameter] = field(default_factory=dict) @dataclass @@ -57,7 +58,7 @@ class QueryArgs: """Query arguments.""" query: str = "" - uri: Optional[str] = None + uri: str | None = None @dataclass @@ -66,8 +67,8 @@ class Query: description: str args: QueryArgs = field(default_factory=QueryArgs) - metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # type: ignore - parameters: Optional[Dict[str, QueryParameter]] = field(default_factory=dict) # type: ignore + metadata: dict[str, Any] | None = field(default_factory=dict) + parameters: dict[str, 
QueryParameter] | None = field(default_factory=dict) @dataclass @@ -75,6 +76,6 @@ class QueryCollection: """Query Collection class - a query template.""" metadata: QueryMetadata - defaults: Optional[QueryDefaults] = None - sources: Dict[str, Query] = field(default_factory=dict) - file_name: Optional[str] = None + defaults: QueryDefaults | None = None + sources: dict[str, Query] = field(default_factory=dict) + file_name: str | None = None diff --git a/msticpy/data/data_obfus.py b/msticpy/data/data_obfus.py index 8339ae026..6ff3a1f1f 100644 --- a/msticpy/data/data_obfus.py +++ b/msticpy/data/data_obfus.py @@ -4,23 +4,27 @@ # license information. # -------------------------------------------------------------------------- """Data obfuscation functions.""" + +from __future__ import annotations + import hashlib import pkgutil import re import uuid import warnings +from collections.abc import Callable, Mapping from functools import lru_cache -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any import numpy as np import pandas as pd import yaml -OBFUS_COL_MAP: Dict[str, str] = {} +OBFUS_COL_MAP: dict[str, str] = {} _MAP_FILE = "resources/obfuscation_cols.yaml" _obfus_map_file = pkgutil.get_data("msticpy", _MAP_FILE) if not _obfus_map_file: - warnings.warn(f"Could not find obfuscation column map {_MAP_FILE}") + warnings.warn(f"Could not find obfuscation column map {_MAP_FILE}", stacklevel=2) else: _obfus_dicts = yaml.safe_load(_obfus_map_file) for data_col_map in _obfus_dicts.values(): @@ -90,13 +94,13 @@ def hash_item(input_item: str, delim: str = None) -> str: # Create a random map for shuffling IP address components -ip_map: List[Dict[str, str]] = [] +ip_map: list[dict[str, str]] = [] for _ in range(4): rng = np.random.default_rng() ip_list = [str(n) for n in np.arange(256)] rand_list = ip_list.copy() rng.shuffle(rand_list) - ip_map.append(dict(zip(ip_list, rand_list))) + ip_map.append(dict(zip(ip_list, rand_list, strict=False))) @lru_cache(maxsize=1024) @@ -131,7 +135,7 @@ def _hash_ip_item(ip_addr: str) -> str: return hashlib.sha256(bytes(ip_addr, "utf-8")).hexdigest()[: len(ip_addr)] -_WK_IPV4 = set(["0.0.0.0", "127.0.0.1", "255.255.255.255"]) # nosec +_WK_IPV4 = {"0.0.0.0", "127.0.0.1", "255.255.255.255"} # nosec def _map_ip4_address(ip_addr: str) -> str: @@ -145,28 +149,19 @@ def _map_ip4_address(ip_addr: str) -> str: if ip_bytes[0] == 10: # class A res private ls_bytes = ".".join( - [ - ip_map[idx].get(byte, "1") - for idx, byte in enumerate(ip_addr.split(".")[1:]) - ] + [ip_map[idx].get(byte, "1") for idx, byte in enumerate(ip_addr.split(".")[1:])] ) return f"10.{ls_bytes}" if ip_bytes[0] == 17 and (16 <= ip_bytes[1] <= 31): # class B res private ls_bytes = ".".join( - [ - ip_map[idx].get(byte, "1") - for idx, byte in enumerate(ip_addr.split(".")[2:]) - ] + [ip_map[idx].get(byte, "1") for idx, byte in enumerate(ip_addr.split(".")[2:])] ) return f"{ip_bytes[0]}.{ip_bytes[1]}.{ls_bytes}" if ip_bytes[0] == 192 and ip_bytes[1] == 168: # class C res private ls_bytes = ".".join( - [ - ip_map[idx].get(byte, "1") - for idx, byte in enumerate(ip_addr.split(".")[2:]) - ] + [ip_map[idx].get(byte, "1") for idx, byte in enumerate(ip_addr.split(".")[2:])] ) return f"192.168.{ls_bytes}" # by default, remap all @@ -175,7 +170,7 @@ def _map_ip4_address(ip_addr: str) -> str: ) -def hash_ip(input_item: Union[List[str], str]) -> Union[List[str], str]: +def hash_ip(input_item: list[str] | str) -> list[str] | str: """ Hash IP address or list of IP addresses. 
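The `strict=False` arguments added throughout this diff (here and in the `vtlookupv3.py` lookups earlier) satisfy ruff's `B905` rule, which wants every `zip()` call to state its length-mismatch policy; `strict=False` preserves the historical truncating behavior. A quick sketch of the difference (PEP 618, Python 3.10+):

```python
keys = ["a", "b", "c"]
vals = [1, 2, 3, 4]  # one element longer

# strict=False (what this diff uses): extras are silently dropped.
assert list(zip(keys, vals, strict=False)) == [("a", 1), ("b", 2), ("c", 3)]

# strict=True: the mismatch becomes a loud error instead.
try:
    list(zip(keys, vals, strict=True))
except ValueError as err:
    print(err)  # e.g. "zip() argument 2 is longer than argument 1"
```

In the `ip_map` case above, the shuffled list is a same-length copy, so `strict=True` would also have been a safe, stricter choice.
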
@@ -198,7 +193,7 @@ def hash_ip(input_item: Union[List[str], str]) -> Union[List[str], str]: return _hash_ip_item(input_item) -def hash_list(item_list: List[str]) -> List[Any]: +def hash_list(item_list: list[str]) -> list[Any]: """ Hash list of strings. @@ -213,8 +208,8 @@ def hash_list(item_list: List[str]) -> List[Any]: Hashed list """ - out_list: List[Union[Dict[str, Any], List[Any], str]] = [] - hash_val: Union[str, Dict[str, Any], List[str]] + out_list: list[dict[str, Any] | list[Any] | str] = [] + hash_val: str | dict[str, Any] | list[str] for val in item_list: if isinstance(val, dict): hash_val = hash_dict(val) @@ -226,9 +221,7 @@ def hash_list(item_list: List[str]) -> List[Any]: return out_list -def hash_dict( - item_dict: Dict[str, Union[Dict[str, Any], List[Any], str]] -) -> Dict[str, Any]: +def hash_dict(item_dict: dict[str, dict[str, Any] | list[Any] | str]) -> dict[str, Any]: """ Hash dictionary values. @@ -287,18 +280,16 @@ def hash_sid(sid: str) -> str: return sid -_WK_ACCOUNTS = set( - [ - "administrator", - "guest", - "system", - "local service", - "network service", - "root", - "crontab", - "nt authority", - ] -) +_WK_ACCOUNTS = { + "administrator", + "guest", + "system", + "local service", + "network service", + "root", + "crontab", + "nt authority", +} @lru_cache(maxsize=1024) @@ -351,7 +342,7 @@ def _guid_replacer() -> Callable[[str], str]: replace_guid function """ - guid_map: Dict[str, str] = {} + guid_map: dict[str, str] = {} def _replace_guid(guid: str) -> str: """ @@ -386,7 +377,7 @@ def _replace_guid(guid: str) -> str: # DataFrame obfuscation functions # Map codes to functions -MAP_FUNCS: Dict[str, Union[str, Callable]] = { +MAP_FUNCS: dict[str, str | Callable] = { "uuid": replace_guid, "ip": hash_ip, "str": hash_string, @@ -398,7 +389,7 @@ def _replace_guid(guid: str) -> str: } -def mask_df( # noqa: MC0001 +def mask_df( data: pd.DataFrame, column_map: Mapping[str, Any] = None, use_default: bool = True, @@ -436,7 +427,7 @@ def mask_df( # noqa: MC0001 for col_name in data.columns: if col_name not in col_map: continue - col_type = col_map.get(col_name, "str") # type: ignore + col_type = col_map.get(col_name, "str") if not silent: print(col_name, end=", ") map_func = MAP_FUNCS.get(col_type) @@ -463,7 +454,7 @@ def mask_df( # noqa: MC0001 def check_masking( data: pd.DataFrame, orig_data: pd.DataFrame, index: int = 0, silent=True -) -> Optional[Tuple[List[str], List[str]]]: +) -> tuple[list[str], list[str]] | None: """ Check the obfuscation results for a row. 
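The `@lru_cache` decorators in this module are unremarkable because they sit on module-level functions; contrast the `# noqa: B019` added to `lookup_ioc` in `ti_http_provider.py` earlier in this diff, where caching a bound method keeps a reference to every `self` alive for the lifetime of the cache. A sketch of the distinction (the class and names are hypothetical):

```python
from functools import lru_cache


@lru_cache(maxsize=1024)
def hash_value(value: str) -> int:
    """Free function: the cache holds only strings and ints."""
    return hash(value)


class Provider:
    @lru_cache(maxsize=256)  # noqa: B019 - cache retains a ref to each instance
    def lookup(self, ioc: str) -> str:
        """Bound method: `self` is part of every cache key."""
        return f"cached result for {ioc}"
```
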
@@ -490,12 +481,11 @@ def check_masking( unchanged = [] obfuscated = [] for col in sorted(data.columns): - if data.iloc[index][col] == orig_data.iloc[index][col]: # type: ignore - unchanged.append(f"{col}: {data.iloc[index][col]}") # type: ignore + if data.iloc[index][col] == orig_data.iloc[index][col]: + unchanged.append(f"{col}: {data.iloc[index][col]}") else: obfuscated.append( - f"{col}: {orig_data.iloc[index][col]} " # type: ignore - f"----> {data.iloc[index][col]}" # type: ignore + f"{col}: {orig_data.iloc[index][col]} ----> {data.iloc[index][col]}" ) if not silent: print("===== Start Check ====") @@ -514,42 +504,3 @@ def check_masking( # alertnative names for backward compat obfuscate_df = mask_df check_obfuscation = check_masking - - -@pd.api.extensions.register_dataframe_accessor("mp_mask") -class ObfuscationAccessor: - """Base64 Unpack pandas extension.""" - - def __init__(self, pandas_obj): - """Initialize the extension.""" - self._df = pandas_obj - - def mask( - self, column_map: Mapping[str, Any] = None, use_default: bool = True - ) -> pd.DataFrame: - """ - Obfuscate the data in columns of a pandas dataframe. - - Parameters - ---------- - data : pd.DataFrame - dataframe containing column to obfuscate - column_map : Mapping[str, Any], optional - Custom column mapping, by default None - use_default: bool - If True use the built-in map (adding any custom - mappings to this dictionary) - - Returns - ------- - pd.DataFrame - Obfuscated dataframe - - """ - warn_message = ( - "This accessor method has been deprecated.\n" - "Please use df.mp.mask() method instead." - "This will be removed in MSTICPy v2.2.0" - ) - warnings.warn(warn_message, category=DeprecationWarning) - return mask_df(data=self._df, column_map=column_map, use_default=use_default) diff --git a/msticpy/data/data_providers.py b/msticpy/data/data_providers.py deleted file mode 100644 index cec115b65..000000000 --- a/msticpy/data/data_providers.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module data_providers.py has moved. - -See :py:mod:`msticpy.data.core.data_providers` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from .core.data_providers import * - -WARN_MSSG = ( - "This module has moved to msticpy.data.core.data_providers\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/drivers/__init__.py b/msticpy/data/drivers/__init__.py index 5d7006053..4182ffa8e 100644 --- a/msticpy/data/drivers/__init__.py +++ b/msticpy/data/drivers/__init__.py @@ -4,9 +4,9 @@ # license information. 
# -------------------------------------------------------------------------- """Data provider sub-package.""" + import importlib from functools import singledispatch -from typing import Dict from ..._version import VERSION from ..core.query_defns import DataEnvironment @@ -37,14 +37,14 @@ "local_velociraptor_driver", "VelociraptorLogDriver", ), - DataEnvironment.MSSentinel_Legacy: ("kql_driver", "KqlDriver"), - DataEnvironment.Kusto_Legacy: ("kusto_driver", "KustoDriver"), + DataEnvironment.MSSentinel_Legacy: ("azure_monitor_driver", "AzureMonitorDriver"), + DataEnvironment.Kusto_Legacy: ("azure_kusto_driver", "AzureKustoDriver"), DataEnvironment.M365DGraph: ("mdatp_driver", "MDATPDriver"), DataEnvironment.Prismacloud: ("prismacloud_driver", "PrismaCloudDriver"), DataEnvironment.MSSentinelSearch: ("azure_search_driver", "AzureSearchDriver"), } -CUSTOM_PROVIDERS: Dict[str, type] = {} +CUSTOM_PROVIDERS: dict[str, type] = {} @singledispatch @@ -68,9 +68,7 @@ def _(data_environment: DataEnvironment) -> type: ", ".join(env.name for env in _ENVIRONMENT_DRIVERS), ) - imp_module = importlib.import_module( - f"msticpy.data.drivers.{mod_name}", package="msticpy" - ) + imp_module = importlib.import_module(f"msticpy.data.drivers.{mod_name}", package="msticpy") return getattr(imp_module, cls_name) diff --git a/msticpy/data/drivers/azure_kusto_driver.py b/msticpy/data/drivers/azure_kusto_driver.py index 76de95d44..4684bb98b 100644 --- a/msticpy/data/drivers/azure_kusto_driver.py +++ b/msticpy/data/drivers/azure_kusto_driver.py @@ -5,6 +5,7 @@ # license information. # -------------------------------------------------------------------------- """Kusto Driver subclass.""" + from __future__ import annotations import base64 @@ -64,9 +65,7 @@ __version__: str = VERSION __author__: str = "Ian Hellen" -_HELP_URL: str = ( - "https://msticpy.readthedocs.io/en/latest/DataProviders/DataProv-Kusto.html" -) +_HELP_URL: str = "https://msticpy.readthedocs.io/en/latest/DataProviders/DataProv-Kusto.html" logger: logging.Logger = logging.getLogger(__name__) @@ -651,9 +650,7 @@ def _get_connection_string_for_cluster( ) -> KustoConnectionStringBuilder: """Return full cluster URI and credential for cluster name or URI.""" auth_params: AuthParams = self._get_auth_params_from_config(cluster_config) - connect_auth_types: list[str] = ( - self._az_auth_types or AzureCloudConfig().auth_methods - ) + connect_auth_types: list[str] = self._az_auth_types or AzureCloudConfig().auth_methods if auth_params.method == "clientsecret": logger.info("Client secret specified in config - using client secret authn") if "clientsecret" not in connect_auth_types: @@ -715,13 +712,15 @@ def _create_kql_cert_connection_str( encoding=serialization.Encoding.PEM, ) thumbprint: bytes = certificate.fingerprint(hashes.SHA256()) - return KustoConnectionStringBuilder.with_aad_application_certificate_sni_authentication( - connection_string=self.cluster_uri, - aad_app_id=auth_params.params["client_id"], - private_certificate=private_cert.decode("utf-8"), - public_certificate=public_cert.decode("utf-8"), - thumbprint=thumbprint.hex().upper(), - authority_id=self._az_tenant_id, + return ( + KustoConnectionStringBuilder.with_aad_application_certificate_sni_authentication( + connection_string=self.cluster_uri, + aad_app_id=auth_params.params["client_id"], + private_certificate=private_cert.decode("utf-8"), + public_certificate=public_cert.decode("utf-8"), + thumbprint=thumbprint.hex().upper(), + authority_id=self._az_tenant_id, + ) ) def 
_get_auth_params_from_config( @@ -738,10 +737,7 @@ def _get_auth_params_from_config( logger.info( "Using client secret authentication because client_secret in config", ) - elif ( - KFields.CERTIFICATE in cluster_config - and KFields.CLIENT_ID in cluster_config - ): + elif KFields.CERTIFICATE in cluster_config and KFields.CLIENT_ID in cluster_config: method = "certificate" auth_params_dict["client_id"] = cluster_config.ClientId auth_params_dict["certificate"] = cluster_config.Certificate @@ -945,9 +941,7 @@ def _create_cluster_config( ) -> dict[str, KustoConfig]: """Return a dictionary of Kusto cluster settings from msticpyconfig.yaml.""" return { - config[KFields.ARGS] - .get(KFields.CLUSTER) - .casefold(): KustoConfig( + config[KFields.ARGS].get(KFields.CLUSTER).casefold(): KustoConfig( tenant_id=_setting_or_default( config[KFields.ARGS], KFields.TENANT_ID, @@ -988,8 +982,7 @@ def _section_or_default( ) -> dict[str, Any]: """Return a combined dictionary from the settings dictionary or the default.""" return { - key: settings.get(key, default.get(key)) - for key in (settings.keys() | default.keys()) + key: settings.get(key, default.get(key)) for key in (settings.keys() | default.keys()) } @@ -1029,9 +1022,9 @@ def _parse_query_status(response: KustoResponseDataSet) -> dict[str, Any]: df_status: pd.DataFrame = dataframe_from_result_table( response.tables[query_info_idx], ) - results: list[dict[Hashable, Any]] = df_status[ - ["EventTypeName", "Payload"] - ].to_dict(orient="records") + results: list[dict[Hashable, Any]] = df_status[["EventTypeName", "Payload"]].to_dict( + orient="records" + ) return { row.get("EventTypeName", "Unknown_field"): json.loads( row.get("Payload", "No Payload"), @@ -1087,9 +1080,7 @@ def _raise_not_connected_error() -> NoReturn: def _raise_unknown_query_error(err: Exception) -> NoReturn: """Raise an error if unknown exception raised.""" - err_msg: str = ( - f"Unknown exception when executing query. Exception type: {type(err)}" - ) + err_msg: str = f"Unknown exception when executing query. 
Exception type: {type(err)}" raise MsticpyDataQueryError( err_msg, *err.args, diff --git a/msticpy/data/drivers/azure_monitor_driver.py b/msticpy/data/drivers/azure_monitor_driver.py index 984d3d738..8ad9ddbc1 100644 --- a/msticpy/data/drivers/azure_monitor_driver.py +++ b/msticpy/data/drivers/azure_monitor_driver.py @@ -14,19 +14,21 @@ azure/monitor-query-readme?view=azure-python """ + from __future__ import annotations import contextlib import logging import warnings -from typing import Any, Iterable, cast +from collections.abc import Iterable +from typing import Any, cast import httpx import pandas as pd from azure.core.exceptions import HttpResponseError from azure.core.pipeline.policies import UserAgentPolicy -from packaging.version import Version -from packaging.version import parse as parse_version +from packaging.version import Version # pylint: disable=no-name-in-module +from packaging.version import parse as parse_version # pylint: disable=no-name-in-module from ..._version import VERSION from ...auth.azure_auth import AzureCloudConfig, az_connect @@ -132,13 +134,9 @@ def __init__(self, connection_str: str | None = None, **kwargs): self.add_query_filter( "data_environments", ("MSSentinel", "LogAnalytics", "AzureSentinel") ) - self.set_driver_property( - DriverProps.EFFECTIVE_ENV, DataEnvironment.MSSentinel.name - ) + self.set_driver_property(DriverProps.EFFECTIVE_ENV, DataEnvironment.MSSentinel.name) self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True) - self.set_driver_property( - DriverProps.MAX_PARALLEL, value=kwargs.get("max_threads", 4) - ) + self.set_driver_property(DriverProps.MAX_PARALLEL, value=kwargs.get("max_threads", 4)) self.az_cloud_config = AzureCloudConfig() logger.info( "AzureMonitorDriver loaded. connect_str %s, kwargs: %s", @@ -302,9 +300,7 @@ def query( return data if data is not None else result # pylint: disable=too-many-branches - def query_with_results( - self, query: str, **kwargs - ) -> tuple[pd.DataFrame, dict[str, Any]]: + def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, dict[str, Any]]: """ Execute query string and return DataFrame of results. @@ -336,9 +332,7 @@ def query_with_results( workspace_id = next(iter(self._workspace_ids), None) or self._workspace_id additional_workspaces = self._workspace_ids[1:] if self._workspace_ids else None logger.info("Query to run %s", query) - logger.info( - "Workspaces %s", ",".join(self._workspace_ids) or self._workspace_id - ) + logger.info("Workspaces %s", ",".join(self._workspace_ids) or self._workspace_id) logger.info( "Time span %s - %s", str(time_span_value[0]) if time_span_value else "none", @@ -349,7 +343,7 @@ def query_with_results( result = self._query_client.query_workspace( workspace_id=workspace_id, # type: ignore[arg-type] query=query, - timespan=time_span_value, # type: ignore[arg-type] + timespan=time_span_value, server_timeout=server_timeout, additional_workspaces=additional_workspaces, ) @@ -374,10 +368,11 @@ def query_with_results( warnings.warn( "Partial results returned. 
This may indicate a query timeout.", RuntimeWarning, + stacklevel=2, ) - table = result.partial_data[0] # type: ignore[attr-defined] + table = result.partial_data[0] else: - table = result.tables[0] # type: ignore[attr-defined] + table = result.tables[0] data_frame = pd.DataFrame(table.rows, columns=table.columns) logger.info("Dataframe returned with %d rows", len(data_frame)) return data_frame, status @@ -398,9 +393,7 @@ def _create_query_client(self, connection_str, **kwargs): # check for additional Args in settings but allow kwargs to override connect_args = self._get_workspace_settings_args() connect_args.update(kwargs) - connect_args.update( - {"auth_methods": az_auth_types, "tenant_id": self._az_tenant_id} - ) + connect_args.update({"auth_methods": az_auth_types, "tenant_id": self._az_tenant_id}) credentials = az_connect(**connect_args) logger.info( "Created query client. Auth type: %s, Url: %s, Proxies: %s", @@ -420,10 +413,7 @@ def _get_workspace_settings_args(self) -> dict[str, Any]: return {} args_path = f"{self._ws_config.settings_path}.Args" args_settings = self._ws_config.settings.get("Args", {}) - return { - name: get_protected_setting(args_path, name) - for name in args_settings.keys() - } + return {name: get_protected_setting(args_path, name) for name in args_settings.keys()} def _get_workspaces(self, connection_str: str | None = None, **kwargs): """Get workspace or workspaces to connect to.""" @@ -441,17 +431,13 @@ def _get_workspaces(self, connection_str: str | None = None, **kwargs): ws_config: WorkspaceConfig | None = None connection_str = connection_str or self._def_connection_str if workspace_name or connection_str is None: - ws_config = WorkspaceConfig(workspace=workspace_name) # type: ignore - logger.info( - "WorkspaceConfig created from workspace name %s", workspace_name - ) + ws_config = WorkspaceConfig(workspace=workspace_name) + logger.info("WorkspaceConfig created from workspace name %s", workspace_name) elif isinstance(connection_str, str): self._def_connection_str = connection_str with contextlib.suppress(ValueError): ws_config = WorkspaceConfig.from_connection_string(connection_str) - logger.info( - "WorkspaceConfig created from connection_str %s", connection_str - ) + logger.info("WorkspaceConfig created from connection_str %s", connection_str) elif isinstance(connection_str, WorkspaceConfig): logger.info("WorkspaceConfig as parameter %s", connection_str.workspace_id) ws_config = connection_str @@ -495,9 +481,9 @@ def _get_workspaces_by_id(self, workspace_ids): def _get_workspaces_by_name(self, workspaces): workspace_configs = { - WorkspaceConfig(workspace)[WorkspaceConfig.CONF_WS_ID]: WorkspaceConfig( - workspace - )[WorkspaceConfig.CONF_TENANT_ID] + WorkspaceConfig(workspace)[WorkspaceConfig.CONF_WS_ID]: WorkspaceConfig(workspace)[ + WorkspaceConfig.CONF_TENANT_ID + ] for workspace in workspaces } if len(set(workspace_configs.values())) > 1: @@ -677,7 +663,7 @@ def _raise_unknown_error(exception): def _schema_format_tables( - ws_tables: dict[str, Iterable[dict[str, Any]]] + ws_tables: dict[str, Iterable[dict[str, Any]]], ) -> dict[str, dict[str, str]]: """Return a sorted dictionary of table names and column names/types.""" table_schema = { @@ -689,9 +675,7 @@ def _schema_format_tables( def _schema_format_columns(table_schema: dict[str, Any]) -> dict[str, str]: """Return a sorted dictionary of column names and types.""" - columns = { - col["name"]: col["type"] for col in table_schema.get("standardColumns", {}) - } + columns = {col["name"]: col["type"] for 
col in table_schema.get("standardColumns", {})} for col in table_schema.get("customColumns", []): columns[col["name"]] = col["type"] return dict(sorted(columns.items())) diff --git a/msticpy/data/drivers/azure_search_driver.py b/msticpy/data/drivers/azure_search_driver.py index b66bd28e3..4cff4ff71 100644 --- a/msticpy/data/drivers/azure_search_driver.py +++ b/msticpy/data/drivers/azure_search_driver.py @@ -70,9 +70,7 @@ def _create_query_client(self, connection_str: str | None = None, **kwargs): # check for additional Args in settings but allow kwargs to override connect_args = self._get_workspace_settings_args() connect_args.update(kwargs) - connect_args.update( - {"auth_methods": az_auth_types, "tenant_id": self._az_tenant_id} - ) + connect_args.update({"auth_methods": az_auth_types, "tenant_id": self._az_tenant_id}) credentials = az_connect(**connect_args) # This will still set up workspaces and tenant ID @@ -86,9 +84,7 @@ def _create_query_client(self, connection_str: str | None = None, **kwargs): self._connected = True logger.info("Created HTTP-based query client using /search endpoint.") - def query_with_results( - self, query: str, **kwargs - ) -> tuple[pd.DataFrame, dict[str, Any]]: + def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, dict[str, Any]]: """ Execute the query via the /search endpoint and return a DataFrame + result status. @@ -104,9 +100,7 @@ def query_with_results( """ if not self._connected or not hasattr(self, "_auth_header"): - raise MsticpyKqlConnectionError( - "Not connected. Call connect() before querying." - ) + raise MsticpyKqlConnectionError("Not connected. Call connect() before querying.") time_span_value = self._get_time_span_value(**kwargs) if not time_span_value: raise MsticpyDataQueryError( @@ -154,9 +148,7 @@ def query_with_results( def _query_search_endpoint(self, search_url, query_body, timeout): try: with httpx.Client(timeout=timeout) as client: - response = client.post( - search_url, headers=self._auth_header, json=query_body - ) + response = client.post(search_url, headers=self._auth_header, json=query_body) except httpx.RequestError as req_err: logger.error("HTTP request error: %s", req_err) raise MsticpyKqlConnectionError( diff --git a/msticpy/data/drivers/cybereason_driver.py b/msticpy/data/drivers/cybereason_driver.py index c502dea5d..132aaef39 100644 --- a/msticpy/data/drivers/cybereason_driver.py +++ b/msticpy/data/drivers/cybereason_driver.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Cybereason Driver class.""" + from __future__ import annotations import datetime as dt @@ -36,9 +37,7 @@ logger: logging.Logger = logging.getLogger(__name__) -_HELP_URI = ( - "https://msticpy.readthedocs.io/en/latest/data_acquisition/DataProviders.html" -) +_HELP_URI = "https://msticpy.readthedocs.io/en/latest/data_acquisition/DataProviders.html" # pylint: disable=too-many-instance-attributes @@ -224,13 +223,11 @@ def _exec_paginated_queries( # noqa: PLR0913 """ del kwargs - query_tasks: dict[str, partial[dict[str, Any]]] = ( - self._create_paginated_query_tasks( - body=body, - page_size=page_size, - pagination_token=pagination_token, - total_results=total_results, - ) + query_tasks: dict[str, partial[dict[str, Any]]] = self._create_paginated_query_tasks( + body=body, + page_size=page_size, + pagination_token=pagination_token, + total_results=total_results, ) logger.info("Running %s paginated queries.", len(query_tasks)) @@ -527,7 +524,7 @@ async def __run_threaded_queries( ) else: task_iter = as_completed(thread_tasks.values()) - ids_and_tasks: dict[str, Future] = dict(zip(thread_tasks, task_iter)) + ids_and_tasks: dict[str, Future] = dict(zip(thread_tasks, task_iter, strict=False)) for query_id, thread_task in ids_and_tasks.items(): try: result: dict[str, Any] = await thread_task @@ -556,7 +553,9 @@ async def __run_threaded_queries( exc_info=True, ) # Sort the results by the order of the tasks - results = [result for _, result in sorted(zip(thread_tasks, results))] + results = [ + result for _, result in sorted(zip(thread_tasks, results, strict=False)) + ] return pd.concat(results, ignore_index=True) # pylint: disable=too-many-branches @@ -665,7 +664,7 @@ def _recursive_find_and_replace( param_dict: dict[str, Any], ) -> str | dict[str, Any] | list[str] | list[dict[str, Any]]: """Recursively find and replace parameters from query.""" - if isinstance(parameters, (list, str, dict)): + if isinstance(parameters, list | str | dict): return _recursive_find_and_replace(parameters, param_dict) return parameters @@ -693,9 +692,7 @@ def _( ) if isinstance(updated_param, list): result.extend([param for param in updated_param if isinstance(param, str)]) - dict_result.extend( - [param for param in updated_param if isinstance(param, dict)] - ) + dict_result.extend([param for param in updated_param if isinstance(param, dict)]) elif isinstance(updated_param, dict): dict_result.append(updated_param) else: @@ -709,9 +706,7 @@ def _(parameters: str, param_dict: dict[str, Any]) -> str | list[str]: param_regex: str = r"{([^}]+)}" matches: re.Match[str] | None = re.match(param_regex, parameters) if matches: - result: list[str] = [ - param_dict.get(match, parameters) for match in matches.groups() - ] + result: list[str] = [param_dict.get(match, parameters) for match in matches.groups()] if len(result) == 1: return result[0] return result diff --git a/msticpy/data/drivers/driver_base.py b/msticpy/data/drivers/driver_base.py index 9adde3269..bdc2d3acd 100644 --- a/msticpy/data/drivers/driver_base.py +++ b/msticpy/data/drivers/driver_base.py @@ -4,10 +4,12 @@ # license information. 
# -------------------------------------------------------------------------- """Data driver base class.""" + import abc from abc import ABC from collections import defaultdict -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Any import pandas as pd @@ -35,7 +37,7 @@ class DriverProps: MAX_PARALLEL = "max_parallel" FILTER_ON_CONNECT = "filter_queries_on_connect" - PROPERTY_TYPES: Dict[str, Any] = { + PROPERTY_TYPES: dict[str, Any] = { PUBLIC_ATTRS: dict, FORMATTERS: dict, USE_QUERY_PATHS: bool, @@ -86,8 +88,8 @@ def __init__(self, **kwargs): # self.has_driver_queries = False self._previous_connection = False self.data_environment = kwargs.get("data_environment") - self._query_filter: Dict[str, Set[str]] = defaultdict(set) - self._instance: Optional[str] = None + self._query_filter: dict[str, set[str]] = defaultdict(set) + self._instance: str | None = None self.properties = DriverProps.defaults() self.set_driver_property( @@ -145,7 +147,7 @@ def connected(self) -> bool: return self._connected @property - def instance(self) -> Optional[str]: + def instance(self) -> str | None: """ Return instance name, if one is set. @@ -159,7 +161,7 @@ def instance(self) -> Optional[str]: return self._instance @property - def schema(self) -> Dict[str, Dict]: + def schema(self) -> dict[str, dict]: """ Return current data schema of connection. @@ -172,7 +174,7 @@ def schema(self) -> Dict[str, Dict]: return {} @abc.abstractmethod - def connect(self, connection_str: Optional[str] = None, **kwargs): + def connect(self, connection_str: str | None = None, **kwargs): """ Connect to data source. @@ -185,8 +187,8 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): @abc.abstractmethod def query( - self, query: str, query_source: Optional[QuerySource] = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + self, query: str, query_source: QuerySource | None = None, **kwargs + ) -> pd.DataFrame | Any: """ Execute query string and return DataFrame of results. @@ -212,7 +214,7 @@ def query( """ @abc.abstractmethod - def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: + def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, Any]: """ Execute query string and return DataFrame plus native results. @@ -229,7 +231,7 @@ def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: """ @property - def service_queries(self) -> Tuple[Dict[str, str], str]: + def service_queries(self) -> tuple[dict[str, str], str]: """ Return queries retrieved from the service after connecting. @@ -243,7 +245,7 @@ def service_queries(self) -> Tuple[Dict[str, str], str]: return {}, "" @property - def driver_queries(self) -> Iterable[Dict[str, Any]]: + def driver_queries(self) -> Iterable[dict[str, Any]]: """ Return queries retrieved from the service after connecting. 
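# --- Editor's illustrative sketch (not part of the patch) -----------------
# The driver_base.py hunks above migrate DriverBase to PEP 585/604 hints
# (dict[str, Any], str | None, tuple[...]) available on Python 3.10+. Below
# is a minimal standalone skeleton of a concrete driver implementing the
# abstract trio shown above (connect / query / query_with_results). The
# _DriverSketch and DemoDriver names and the canned result are hypothetical;
# no msticpy import is assumed.
from __future__ import annotations

import abc
from typing import Any

import pandas as pd


class _DriverSketch(abc.ABC):
    """Stand-in mirroring the DriverBase interface shown in the diff."""

    @abc.abstractmethod
    def connect(self, connection_str: str | None = None, **kwargs):
        """Connect to the data source."""

    @abc.abstractmethod
    def query(self, query: str, **kwargs) -> pd.DataFrame | Any:
        """Run a query and return a DataFrame (or raw result on error)."""

    @abc.abstractmethod
    def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, Any]:
        """Run a query and return the DataFrame plus the native result."""


class DemoDriver(_DriverSketch):
    """Toy driver that returns canned rows for any query."""

    def connect(self, connection_str: str | None = None, **kwargs):
        self._connected = True

    def query(self, query: str, **kwargs) -> pd.DataFrame | Any:
        data, _ = self.query_with_results(query, **kwargs)
        return data

    def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, Any]:
        frame = pd.DataFrame({"query": [query], "rows": [0]})
        return frame, {"status": "success"}


if __name__ == "__main__":
    drv = DemoDriver()
    drv.connect()
    print(drv.query("DemoTable | take 1"))
# ---------------------------------------------------------------------------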
@@ -257,11 +259,11 @@ def driver_queries(self) -> Iterable[Dict[str, Any]]: return [{}] @property - def query_attach_spec(self) -> Dict[str, Set[str]]: + def query_attach_spec(self) -> dict[str, set[str]]: """Parameters that determine whether a query is relevant for the driver.""" return self._query_filter - def add_query_filter(self, name: str, query_filter: Union[str, Iterable]): + def add_query_filter(self, name: str, query_filter: str | Iterable): """Add an expression to the query attach filter.""" allowed_names = {"data_environments", "data_families", "data_sources"} if name not in allowed_names: @@ -294,10 +296,10 @@ def query_usable(self, query_source: QuerySource) -> bool: # Read values from configuration @staticmethod - def _get_config_settings(prov_name) -> Dict[Any, Any]: + def _get_config_settings(prov_name) -> dict[Any, Any]: """Get config from msticpyconfig.""" data_provs = get_provider_settings(config_section="DataProviders") - splunk_settings: Optional[ProviderSettings] = data_provs.get(prov_name) + splunk_settings: ProviderSettings | None = data_provs.get(prov_name) return getattr(splunk_settings, "args", {}) @staticmethod diff --git a/msticpy/data/drivers/elastic_driver.py b/msticpy/data/drivers/elastic_driver.py index d0ba72df3..53d8ef71f 100644 --- a/msticpy/data/drivers/elastic_driver.py +++ b/msticpy/data/drivers/elastic_driver.py @@ -4,9 +4,11 @@ # license information. # -------------------------------------------------------------------------- """Elastic Driver class.""" + import json +from collections.abc import Iterable from datetime import datetime -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any import pandas as pd @@ -20,12 +22,12 @@ __author__ = "Neil Desai, Ian Hellen" -ELASTIC_CONNECT_ARGS: Dict[str, str] = { +ELASTIC_CONNECT_ARGS: dict[str, str] = { # TBD - you may not need these - mainly for user # help/error messages (see _get_connect_args) } -_ELASTIC_REQUIRED_ARGS: Dict[str, str] = { +_ELASTIC_REQUIRED_ARGS: dict[str, str] = { # TBD } @@ -78,21 +80,16 @@ def connect(self, connection_str: str = None, **kwargs): self._connected = True print("connected") - def _get_connect_args( - self, connection_str: Optional[str], **kwargs - ) -> Dict[str, Any]: + def _get_connect_args(self, connection_str: str | None, **kwargs) -> dict[str, Any]: """Check and consolidate connection parameters.""" - cs_dict: Dict[str, Any] = {} + cs_dict: dict[str, Any] = {} # Fetch any config settings cs_dict.update(self._get_config_settings("Elastic")) # If a connection string - parse this and add to config if connection_str: cs_items = connection_str.split(";") cs_dict.update( - { - cs_item.split("=")[0].strip(): cs_item.split("=")[1] - for cs_item in cs_items - } + {cs_item.split("=")[0].strip(): cs_item.split("=")[1] for cs_item in cs_items} ) elif kwargs: # if connection args supplied as kwargs @@ -113,7 +110,7 @@ def _get_connect_args( def query( self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + ) -> pd.DataFrame | Any: """ Execute query and retrieve results. @@ -144,7 +141,7 @@ def query( # Run query and return results return pd.DataFrame() - def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: + def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, Any]: """ Execute query string and return DataFrame of results. 
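# --- Editor's illustrative sketch (not part of the patch) -----------------
# The _get_connect_args hunk above folds a "key1=value1;key2=value2"
# connection string into a settings dict with a comprehension. A minimal
# sketch of the same parsing, assuming well-formed "key=value" pairs; the
# sample host/port string is hypothetical.
from typing import Any


def parse_connection_str(connection_str: str) -> dict[str, Any]:
    """Split a ';'-separated connection string into a settings dict."""
    cs_dict: dict[str, Any] = {}
    cs_items = connection_str.split(";")
    # Mirrors the comprehension in the diff: key left of '=', value right.
    cs_dict.update(
        {cs_item.split("=")[0].strip(): cs_item.split("=")[1] for cs_item in cs_items}
    )
    return cs_dict


print(parse_connection_str("host=localhost;port=9200"))
# -> {'host': 'localhost', 'port': '9200'}
# ---------------------------------------------------------------------------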
@@ -177,16 +174,14 @@ def _format_list(param_list: Iterable[Any]) -> str: return ",".join(fmt_list) @staticmethod - def _custom_param_handler(query: str, param_dict: Dict[str, Any]) -> str: + def _custom_param_handler(query: str, param_dict: dict[str, Any]) -> str: """Replace parameters in query template for Elastic JSON queries.""" query_dict = json.loads(query) start = param_dict.pop("start", None) end = param_dict.pop("end", None) if start or end: - time_range = { - "range": {"@timestamp": {"format": "strict_date_optional_time"}} - } + time_range = {"range": {"@timestamp": {"format": "strict_date_optional_time"}}} if start: time_range["range"]["@timestamp"]["gte"] = start if end: diff --git a/msticpy/data/drivers/kql_driver.py b/msticpy/data/drivers/kql_driver.py deleted file mode 100644 index 97aa5988c..000000000 --- a/msticpy/data/drivers/kql_driver.py +++ /dev/null @@ -1,587 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""KQL Driver class.""" - -import contextlib -import json -import logging -import os -import re -import warnings -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd -from azure.core.exceptions import ClientAuthenticationError -from IPython.core.getipython import get_ipython - -from ..._version import VERSION -from ...auth.azure_auth import AzureCloudConfig, az_connect -from ...auth.azure_auth_core import only_interactive_cred -from ...common.exceptions import ( - MsticpyDataQueryError, - MsticpyImportExtraError, - MsticpyKqlConnectionError, - MsticpyNoDataSourceError, - MsticpyNotConnectedError, -) -from ...common.utility import MSTICPY_USER_AGENT, export -from ...common.wsconfig import WorkspaceConfig -from ..core.query_defns import DataEnvironment -from .driver_base import DriverBase, DriverProps, QuerySource - -_KQL_ENV_OPTS = "KQLMAGIC_CONFIGURATION" - - -# Need to set KQL option before importing -def _set_kql_env_option(option, value): - """Set an item in the KqlMagic main config environment variable.""" - kql_config = os.environ.get(_KQL_ENV_OPTS, "") - current_opts = { - opt.split("=")[0].strip(): opt.split("=")[1] - for opt in kql_config.split(";") - if opt.strip() and "=" in opt - } - - current_opts[option] = value - kql_config = ";".join(f"{opt}={val}" for opt, val in current_opts.items()) - os.environ[_KQL_ENV_OPTS] = kql_config - - -_set_kql_env_option("enable_add_items_to_help", False) - -try: - from Kqlmagic import kql as kql_exec - from Kqlmagic.kql_engine import KqlEngineError - from Kqlmagic.kql_proxy import KqlResponse - from Kqlmagic.kql_response import KqlError - from Kqlmagic.my_aad_helper import AuthenticationError -except ImportError as imp_err: - raise MsticpyImportExtraError( - "Cannot use this feature without Kqlmagic installed", - "Install msticpy with the [kql] extra or one of the following:", - "%pip install Kqlmagic # notebook", - "python -m pip install Kqlmagic # python", - title="Error importing Kqlmagic", - extra="kql", - ) from imp_err - -__version__ = VERSION -__author__ = "Ian Hellen" - -_KQL_CLOUD_MAP = {"global": "public", "cn": "china", "usgov": "government"} - -_KQL_OPTIONS = ["timeout"] - -_AZ_CLOUD_MAP = {kql_cloud: az_cloud for az_cloud, kql_cloud in _KQL_CLOUD_MAP.items()} - -# pylint: disable=too-many-instance-attributes 
- - -@export -class KqlDriver(DriverBase): - """KqlDriver class to execute kql queries.""" - - def __init__(self, connection_str: str = None, **kwargs): - """ - Instantiate KqlDriver and optionally connect. - - Parameters - ---------- - connection_str : str, optional - Connection string - - Other Parameters - ---------------- - debug : bool - print out additional diagnostic information. - - """ - self.az_cloud_config = AzureCloudConfig() - self._ip = get_ipython() - self._debug = kwargs.get("debug", False) - super().__init__(**kwargs) - self.workspace_id: Optional[str] = None - self._loaded = self._is_kqlmagic_loaded() - - os.environ["KQLMAGIC_LOAD_MODE"] = "silent" - if not self._loaded: - self._load_kql_magic() - - self._set_kql_option("request_user_agent_tag", MSTICPY_USER_AGENT) - self._set_kql_env_option("enable_add_items_to_help", False) - self._schema: Dict[str, Any] = {} - self.environment = kwargs.pop("data_environment", DataEnvironment.MSSentinel) - self.set_driver_property( - DriverProps.EFFECTIVE_ENV, DataEnvironment.MSSentinel.name - ) - self.kql_cloud, self.az_cloud = self._set_kql_cloud() - for option, value in kwargs.items(): - self._set_kql_option(option, value) - - self.current_connection = "" - self.current_connection_args: Dict[str, Any] = {} - if connection_str: - self.current_connection = connection_str - self.current_connection_args.update(kwargs) - self.connect(connection_str) - - # pylint: disable=too-many-branches - def connect(self, connection_str: Optional[str] = None, **kwargs): # noqa: MC0001 - """ - Connect to data source. - - Parameters - ---------- - connection_str : Union[str, WorkspaceConfig, None] - Connection string or WorkspaceConfig for the Sentinel Workspace. - - Other Parameters - ---------------- - kqlmagic_args : str, optional - Additional string of parameters to be passed to KqlMagic - mp_az_auth : Union[bool, str, list, None], optional - Optional parameter directing KqlMagic to use MSTICPy Azure authentication. - Values can be: - True or "default": use the settings in msticpyconfig.yaml 'Azure' section - str: single auth method name - ('msi', 'cli', 'env', 'vscode', 'powershell', 'cache' or 'interactive') - List[str]: list of acceptable auth methods from - ('msi', 'cli', 'env', 'vscode', 'powershell', 'cache' or 'interactive') - mp_az_tenant_id: str, optional - Optional parameter specifying a Tenant ID for use by MSTICPy Azure - authentication. - workspace : str, optional - Alternative to supplying a WorkspaceConfig object as the connection_str - parameter. Giving a workspace name will fetch the workspace - settings from msticpyconfig.yaml. - - - """ - if not self._previous_connection: - print("Connecting...", end=" ") - - mp_az_auth = kwargs.get("mp_az_auth", "default") - mp_az_tenant_id = kwargs.get("mp_az_tenant_id") - workspace = kwargs.get("workspace") - if workspace or connection_str is None: - connection_str = WorkspaceConfig(workspace=workspace) # type: ignore - - if isinstance(connection_str, WorkspaceConfig): - if not mp_az_tenant_id and "tenant_id" in connection_str: - mp_az_tenant_id = connection_str["tenant_id"] - self._instance = connection_str.workspace_key - connection_str = connection_str.code_connect_str - - if not connection_str: - raise MsticpyKqlConnectionError( - f"A connection string is needed to connect to {self._connect_target}", - title="no connection string", - ) - if "kqlmagic_args" in kwargs: - connection_str = f"{connection_str} {kwargs['kqlmagic_args']}" - - # Default to using Azure Auth if possible. 
- if mp_az_auth and "try_token" not in kwargs: - self._set_az_auth_option(mp_az_auth, mp_az_tenant_id) - - self.current_connection = connection_str - ws_in_connection = re.search( - r"workspace\(['\"]([^'\"]+).*", - self.current_connection, - re.IGNORECASE, - ) - self.workspace_id = ws_in_connection[1] if ws_in_connection else None - self.current_connection_args.update(kwargs) - kql_err_setting = self._get_kql_option("short_errors") - self._connected = False - try: - self._set_kql_option("short_errors", False) - if self._ip is not None: - try: - kql_exec(connection_str) - if not self._previous_connection: - print("connected") - except KqlError as ex: - self._raise_kql_error(ex) - except KqlEngineError as ex: - self._raise_kql_engine_error(ex) - except AuthenticationError as ex: - self._raise_authn_error(ex) - except Exception as ex: # pylint: disable=broad-except - self._raise_adal_error(ex) - self._connected = True - self._previous_connection = True - self._schema = self._get_schema() - else: - print(f"Could not connect to kql query provider for {connection_str}") - return self._connected - finally: - self._set_kql_option("short_errors", kql_err_setting) - - # pylint: disable=too-many-branches - - @property - def schema(self) -> Dict[str, Dict]: - """ - Return current data schema of connection. - - Returns - ------- - Dict[str, Dict] - Data schema of current connection. - - """ - return self._schema - - def query( - self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: - """ - Execute query string and return DataFrame of results. - - Parameters - ---------- - query : str - The query to execute - query_source : QuerySource - The query definition object - - Returns - ------- - Union[pd.DataFrame, results.ResultSet] - A DataFrame (if successful) or - the underlying provider result if an error. - - """ - if query_source: - try: - table = query_source["args.table"] - except KeyError: - table = None - if table: - if " " in table.strip(): - table = table.strip().split(" ")[0] - if table not in self.schema: - raise MsticpyNoDataSourceError( - f"The table {table} for this query is not in your workspace", - " or database schema. Please check your this", - title=f"{table} not found.", - ) - data, result = self.query_with_results(query, **kwargs) - return data if data is not None else result - - # pylint: disable=too-many-branches - def query_with_results( - self, query: str, **kwargs - ) -> Tuple[pd.DataFrame, KqlResponse]: - """ - Execute query string and return DataFrame of results. - - Parameters - ---------- - query : str - The kql query to execute - - Returns - ------- - Tuple[pd.DataFrame, results.ResultSet] - A DataFrame (if successful) and - Kql ResultSet. 
- - """ - debug = kwargs.pop("debug", self._debug) - if debug: - print(query) - - if ( - not self.connected - or self.workspace_id != self._get_kql_current_connection() - ): - self._make_current_connection() - - # save current auto_dataframe setting so that we can set to false - # and restore current setting - auto_dataframe = self._get_kql_option(option="auto_dataframe") - self._set_kql_option(option="auto_dataframe", value=False) - - # run the query (append semicolon to prevent default output) - if not query.strip().endswith(";"): - query = f"{query}\n;" - - # Add any Kqlmagic options from kwargs - kql_opts = { - option: option_val - for option, option_val in kwargs.items() - if option in _KQL_OPTIONS - } - result = kql_exec(query, options=kql_opts) - self._set_kql_option(option="auto_dataframe", value=auto_dataframe) - if result is not None: - if isinstance(result, pd.DataFrame): - return result, None - if hasattr(result, "completion_query_info") and ( - int(result.completion_query_info.get("StatusCode", 1)) == 0 - or result.completion_query_info.get("Text") - == "Query completed successfully" - ): - data_frame = result.to_dataframe() - if result.is_partial_table: - print("Warning - query returned partial results.") - if debug: - print("Query status:\n", "\n".join(self._get_query_status(result))) - return data_frame, result - - return self._raise_query_failure(query, result) - - def _make_current_connection(self): - """Switch to the current connection (self.current_connection).""" - try: - self.connect(self.current_connection, **(self.current_connection_args)) - except MsticpyKqlConnectionError: - self._connected = False - if not self.connected: - raise MsticpyNotConnectedError( - "Please run the connect() method before running a query.", - title=f"not connected to a {self._connect_target}", - help_uri=MsticpyKqlConnectionError.DEF_HELP_URI, - ) - - def _load_kql_magic(self): - """Load KqlMagic if not loaded.""" - # KqlMagic - print("Please wait. 
Loading Kqlmagic extension...", end="") - if self._ip is not None: - with warnings.catch_warnings(): - # Suppress logging exception about PyGObject from msal_extensions - msal_ext_logger = logging.getLogger("msal_extensions.libsecret") - current_level = msal_ext_logger.getEffectiveLevel() - msal_ext_logger.setLevel(logging.CRITICAL) - warnings.simplefilter(action="ignore") - self._ip.run_line_magic("reload_ext", "Kqlmagic") - msal_ext_logger.setLevel(current_level) - self._loaded = True - print("done") - - def _is_kqlmagic_loaded(self) -> bool: - """Return true if kql magic is loaded.""" - if self._ip is not None: - return self._ip.find_magic("kql") is not None - return bool(kql_exec("--version")) - - @property - def _connect_target(self) -> str: - if self.environment == DataEnvironment.MSSentinel: - return "Workspace" - return "Kusto cluster" - - @staticmethod - def _get_query_status(result) -> List[str]: - return [f"{key}: '{value}'" for key, value in result.completion_query_info] - - @staticmethod - def _get_schema() -> Dict[str, Dict]: - return kql_exec("--schema") - - @staticmethod - def _get_kql_option(option): - """Retrieve a current Kqlmagic notebook option.""" - return kql_exec(f"--config {option}").get(option) - - @staticmethod - def _set_kql_option(option, value): - """Set a Kqlmagic notebook option.""" - kql_exec("--config short_errors=False") - result: Any - try: - opt_val = f"'{value}'" if isinstance(value, str) else value - result = kql_exec(f"--config {option}={opt_val}") - except ValueError: - result = None - finally: - kql_exec("--config short_errors=True") - return result - - @staticmethod - def _set_kql_env_option(option, value): - """Set an item in the KqlMagic main config environment variable.""" - kql_config = os.environ.get(_KQL_ENV_OPTS, "") - current_opts = { - opt.split("=")[0].strip(): opt.split("=")[1] - for opt in kql_config.split(";") - } - current_opts[option] = value - kql_config = ";".join(f"{opt}={val}" for opt, val in current_opts.items()) - # print(kql_config) - replace with logger - os.environ[_KQL_ENV_OPTS] = kql_config - - @staticmethod - def _get_kql_current_connection(): - """Get the current connection Workspace ID from KQLMagic.""" - connections = kql_exec("--conn") - current_connection = [conn for conn in connections if conn.startswith(" * ")] - if not current_connection: - return "" - return current_connection[0].strip(" * ").split("@")[0] - - def _set_kql_cloud(self): - """If cloud is set in Azure Settings override default.""" - # Check that there isn't a cloud setting in the KQLMAGIC env var - kql_config = os.environ.get(_KQL_ENV_OPTS, "") - if "cloud" in kql_config: - # Set by user - we don't want to override this - kql_cloud = self._get_kql_option("cloud") - az_cloud = _AZ_CLOUD_MAP.get(kql_cloud, "public") - return kql_cloud, az_cloud - az_cloud = self.az_cloud_config.cloud - kql_cloud = _KQL_CLOUD_MAP.get(az_cloud, "public") - if kql_cloud != self._get_kql_option("cloud"): - self._set_kql_option("cloud", kql_cloud) - return kql_cloud, az_cloud - - @staticmethod - def _raise_query_failure(query, result): - """Raise query failure exception.""" - err_contents = [] - if hasattr(result, "completion_query_info"): - q_info = result.completion_query_info - if "StatusDescription" in q_info: - err_contents = [ - f"StatusDescription {q_info.get('StatusDescription')}", - f"(err_code: {result.completion_query_info.get('StatusCode')})", - ] - elif "Text" in q_info: - err_contents = [f"StatusDescription {q_info.get('Text')}"] - else: - err_contents = 
[f"Unknown error type: {q_info}"] - if not err_contents: - err_contents = ["Unknown query error"] - - err_contents.append(f"Query:\n{query}") - raise MsticpyDataQueryError(*err_contents) - - _WS_RGX = r"workspace\(['\"](?P[^'\"]+)" - _TEN_RGX = r"tenant\(['\"](?P[^'\"]+)" - - def _raise_kql_error(self, ex): - kql_err = json.loads(ex.args[0]).get("error") - if kql_err.get("code") == "WorkspaceNotFoundError": - ex_mssgs = [ - "The workspace ID used to connect to Microsoft Sentinel could not be found.", - "Please check that this is a valid workspace for your subscription", - ] - ws_match = re.search(self._WS_RGX, self.current_connection, re.IGNORECASE) - if ws_match: - ws_name = ws_match.groupdict().get("ws") - ex_mssgs.append(f"The workspace id used was {ws_name}.") - ex_mssgs.append(f"The full connection string was {self.current_connection}") - raise MsticpyKqlConnectionError(*ex_mssgs, title="unknown workspace") - raise MsticpyKqlConnectionError( - "The service returned the following error when connecting", - str(ex), - title="Kql response error", - ) - - @staticmethod - def _raise_kql_engine_error(ex): - ex_mssgs = [ - "An error was returned from Kqlmagic KqlEngine.", - "This can occur if you tried to connect to a second workspace using a" - + " different tenant ID - only a single tenant ID is supported in" - + " one notebook.", - "Other causes of this error could be an invalid format of your" - + " connection string", - *(ex.args), - ] - raise MsticpyKqlConnectionError(*ex_mssgs, title="kql connection error") - - @staticmethod - def _raise_adal_error(ex): - """Adal error - usually wrong tenant ID.""" - if ex.args and ex.args[0] == "Unexpected polling state code_expired": - raise MsticpyKqlConnectionError( - "Authentication request was not completed.", - title="authentication timed out", - ) - - err_response = getattr(ex, "error_response", None) - if err_response and "error_description" in ex.error_response: - ex_mssgs = ex.error_response["error_description"].split("\r\n") - else: - ex_mssgs = [f"Full error: {ex}"] - raise MsticpyKqlConnectionError( - *ex_mssgs, title="could not authenticate to tenant" - ) - - @staticmethod - def _raise_authn_error(ex): - """Raise an authentication error.""" - ex_mssgs = [ - "The authentication failed.", - "Please check the credentials you are using and permissions on the ", - "workspace or cluster.", - *(ex.args), - ] - raise MsticpyKqlConnectionError(*ex_mssgs, title="authentication failed") - - @staticmethod - def _raise_unknown_error(ex): - """Raise an unknown exception.""" - raise MsticpyKqlConnectionError( - "Another exception was returned by the service", - *ex.args, - f"Full exception:\n{ex}", - title="connection failed", - ) - - def _set_az_auth_option( - self, mp_az_auth: Union[bool, str, list, None], mp_az_tenant_id: str = None - ): - """ - Build connection string with auth elements. - - Parameters - ---------- - mp_az_auth : Union[bool, str, list, None], optional - Optional parameter directing KqlMagic to use MSTICPy Azure authentication. - Values can be: - - True or "default": use the settings in msticpyconfig.yaml 'Azure' section - - auth_method: single auth method name ('msi', 'cli', 'env' or 'interactive') - - auth_methods: list of acceptable auth methods from ('msi', 'cli', - 'env' or 'interactive') - mp_az_tenant_id: str, optional - Optional parameter specifying a Tenant ID for use by MSTICPy Azure - authentication. 
- - """ - # default to default auth methods - auth_types = self.az_cloud_config.auth_methods - # override if user-supplied methods on command line - if isinstance(mp_az_auth, str) and mp_az_auth != "default": - auth_types = [mp_az_auth] - elif isinstance(mp_az_auth, list): - auth_types = mp_az_auth - # get current credentials - creds = az_connect(auth_methods=auth_types, tenant_id=mp_az_tenant_id) - if only_interactive_cred(creds.modern): - print("Check your default browser for interactive sign-in prompt.") - - endpoint_uri = self._get_endpoint_uri() - endpoint_token_uri = f"{endpoint_uri}.default" - # obtain token for the endpoint - with contextlib.suppress(ClientAuthenticationError): - token = creds.modern.get_token( - endpoint_token_uri, tenant_id=mp_az_tenant_id - ) - # set the token values in the namespace - endpoint_token = { - "access_token": token.token, - "token_type": "Bearer", - "resource": endpoint_uri, - } - self._set_kql_option("try_token", endpoint_token) - - def _get_endpoint_uri(self): - return self.az_cloud_config.log_analytics_uri diff --git a/msticpy/data/drivers/kusto_driver.py b/msticpy/data/drivers/kusto_driver.py deleted file mode 100644 index be3dca3bd..000000000 --- a/msticpy/data/drivers/kusto_driver.py +++ /dev/null @@ -1,296 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""Kusto Driver subclass.""" -from typing import Any, Dict, Optional, Union - -import pandas as pd - -from ..._version import VERSION -from ...common.exceptions import MsticpyParameterError, MsticpyUserConfigError -from ...common.provider_settings import ProviderArgs, get_provider_settings -from ...common.utility import export -from ..core.query_defns import DataEnvironment -from .kql_driver import DriverProps, KqlDriver, QuerySource - -__version__ = VERSION -__author__ = "Ian Hellen" - -_KCS_CODE = "code;" -_KCS_APP = "tenant='{tenant_id}';clientid='{client_id}';clientsecret='{clientsecret}';" -_KCS_TEMPLATE = "azure_data-Explorer://{auth}cluster='{cluster}';database='{database}'" - -KustoClusterSettings = Dict[str, Dict[str, Union[str, ProviderArgs]]] - - -@export -class KustoDriver(KqlDriver): - """Kusto Driver class to execute kql queries for Azure Data Explorer.""" - - def __init__(self, connection_str: str = None, **kwargs): - """ - Instantiate KustoDriver. - - Parameters - ---------- - connection_str : str, optional - Connection string - - Other Parameters - ---------------- - debug : bool - print out additional diagnostic information. - - """ - super().__init__(connection_str=connection_str, **kwargs) - self.environment = kwargs.get("data_environment", DataEnvironment.Kusto) - self.set_driver_property(DriverProps.EFFECTIVE_ENV, DataEnvironment.Kusto.name) - self._connected = True - self._kusto_settings: KustoClusterSettings = _get_kusto_settings() - self._cluster_uri = None - - def connect(self, connection_str: Optional[str] = None, **kwargs): - """ - Connect to data source. - - Parameters - ---------- - connection_str : str - Connect to a data source - - Other Parameters - ---------------- - cluster : str, optional - Short name or URI of cluster to connect to. - database : str, optional - Name of database to connect to. 
- kqlmagic_args : str, optional - Additional string of parameters to be passed to KqlMagic - mp_az_auth : Union[bool, str, list, None], optional - Optional parameter directing KqlMagic to use MSTICPy Azure authentication. - Values can be: - True or "default": use the settings in msticpyconfig.yaml 'Azure' section - str: single auth method name - ('msi', 'cli', 'env', 'vscode', 'powershell', 'cache' or 'interactive') - List[str]: list of acceptable auth methods from - ('msi', 'cli', 'env', 'vscode', 'powershell', 'cache' or 'interactive') - mp_az_tenant_id: str, optional - Optional parameter specifying a Tenant ID for use by MSTICPy Azure - authentication. - - """ - self.current_connection = self._get_connection_string( - connection_str=connection_str, **kwargs - ) - - mp_az_auth = kwargs.pop("mp_az_auth", None) - mp_az_tenant_id = kwargs.pop("mp_az_tenant_id", None) - - if ( - self._cluster_uri - ): # This should be set by _get_connection_string called above - cluster_settings = self._kusto_settings.get(self._cluster_uri.casefold()) - if cluster_settings: - if mp_az_auth is None and cluster_settings["integrated_auth"]: - mp_az_auth = "default" - if mp_az_tenant_id is None and cluster_settings["tenant_id"]: - mp_az_tenant_id = cluster_settings["tenant_id"] - - kwargs.pop("cluster", None) - kwargs.pop("database", None) - - super().connect( - connection_str=self.current_connection, - mp_az_auth=mp_az_auth, - mp_az_tenant_id=mp_az_tenant_id, - **kwargs, - ) - - def query( - self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: - """ - Execute query string and return DataFrame of results. - - Parameters - ---------- - query : str - The query to execute - query_source : QuerySource - The query definition object - - Other Parameters - ---------------- - cluster : str, Optional - Supply or override the Kusto cluster name - database : str, Optional - Supply or override the Kusto database name - data_source : str, Optional - alias for `db` - connection_str : str, Optional - - - Returns - ------- - Union[pd.DataFrame, results.ResultSet] - A DataFrame (if successful) or - the underlying provider result if an error. 
- - """ - new_connection = self._get_connection_string( - query_source=query_source, **kwargs - ) - if new_connection: - self.current_connection = new_connection - data, result = self.query_with_results(query) - return data if data is not None else result - - def _get_connection_string(self, query_source: QuerySource = None, **kwargs): - """Create a connection string from arguments and configuration.""" - # If the connection string is supplied as a parameter, use that - cluster = None - new_connection = kwargs.get("connection_str") - database = kwargs.get("database") - if not new_connection: - # try to get cluster and db from kwargs or query_source metadata - cluster = self._lookup_cluster(kwargs.get("cluster", "Kusto")) - if cluster and database: - new_connection = self._create_connection( - cluster=cluster, database=database - ) - self._cluster_uri = cluster - if not new_connection and query_source: - # try to get cluster and db from query_source metadata - cluster = cluster or query_source.metadata.get("cluster") - database = ( - database - or query_source.metadata.get("database") - or self._get_db_from_datafamily(query_source, cluster, database) - ) - new_connection = self._create_connection(cluster=cluster, database=database) - self._cluster_uri = cluster - return new_connection - - def _get_db_from_datafamily(self, query_source, cluster, database): - data_families = query_source.metadata.get("data_families") - if ( - not isinstance(data_families, list) or len(data_families) == 0 - ) and not self.current_connection: - # call create connection so that we throw an informative error - self._create_connection(cluster=cluster, database=database) - if "." in data_families[0]: # type: ignore - _, qry_db = data_families[0].split(".", maxsplit=1) # type: ignore - else: - # Not expected but we can still use a DB value with no dot - qry_db = data_families[0] # type: ignore - return qry_db - - def _create_connection(self, cluster, database): - """Create the connection string, checking parameters.""" - if not cluster or not database: - if cluster: - err_mssg = "database name" - elif database: - err_mssg = "cluster uri" - else: - err_mssg = "cluster uri and database name" - raise MsticpyParameterError( - f"Could not determine the {err_mssg} for the query.", - "Please update the query with the correct values or specify", - "explicitly with the 'cluster' and 'database' parameters to", - "this function.", - "In the query template these values are specified in the metadata:", - "cluster: cluster_uri", - "data_families: [ClusterAlias.database]", - title="Missing cluster or database names.", - parameter=err_mssg, - ) - cluster_key = cluster.casefold() - if cluster_key not in self._kusto_settings: - raise MsticpyUserConfigError( - f"The cluster {cluster} was not found in the configuration.", - "You must have an entry for the cluster in the 'DataProviders section", - "of your msticyconfig.yaml", - "Expected format:", - "Kusto[-instance_name]:", - " Args:", - " Cluster: cluster_uri", - " Integrated: True", - "or", - "Kusto[-instance_name]:", - " Args:", - " Cluster: cluster_uri", - " TenantId: tenant_uuid", - " ClientId: tenant_uuid", - " ClientSecret: (string|KeyVault|EnvironmentVar:)", - title="Unknown cluster.", - ) - return self._format_connection_str(cluster, database) - - def _format_connection_str(self, cluster: str, database: str) -> Optional[str]: - """Return connection string with client secret added.""" - fmt_items = self._kusto_settings.get(cluster.casefold()) - if not fmt_items: - return None - 
fmt_items["database"] = database - if fmt_items.get("integrated_auth"): - auth_string = _KCS_CODE - else: - # Note, we don't add the secret until required at runtime to prevent - # it hanging around in memory as much as possible. - fmt_items["clientsecret"] = fmt_items["args"].get("ClientSecret") # type: ignore - auth_string = _KCS_APP.format(**fmt_items) - return _KCS_TEMPLATE.format(auth=auth_string, **fmt_items) - - def _lookup_cluster(self, cluster: str): - """Return cluster URI from config if cluster name is passed.""" - if cluster.strip().casefold().startswith("https://"): - return cluster - return next( - ( - kusto_config["cluster"] - for cluster_key, kusto_config in self._kusto_settings.items() - if ( - cluster_key.startswith(f"https://{cluster.casefold()}.") - or ( - kusto_config.get("alias", "").casefold() # type: ignore - == cluster.casefold() - ) - ) - ), - None, - ) - - def _get_endpoint_uri(self): - if not self._cluster_uri.endswith("/"): - self._cluster_uri += "/" - return self._cluster_uri - - -def _get_kusto_settings() -> KustoClusterSettings: - kusto_settings: KustoClusterSettings = {} - for prov_name, settings in get_provider_settings("DataProviders").items(): - if not prov_name.startswith("Kusto"): - continue - instance = "Kusto" - if "-" in prov_name: - _, instance = prov_name.split("-", maxsplit=1) - - cluster = settings.args.get("Cluster") - if not cluster: - raise MsticpyUserConfigError( - "Mandatory 'Cluster' setting is missing in msticpyconfig.", - f"the Kusto entry with the missing setting is '{prov_name}'", - title=f"No Cluster value for {prov_name}", - ) - kusto_settings[cluster.casefold()] = { - "tenant_id": settings.args.get("TenantId"), # type: ignore - "integrated_auth": settings.args.get("IntegratedAuth"), # type: ignore - "client_id": settings.args.get("ClientId"), # type: ignore - "args": settings.args, - "cluster": cluster, - "alias": instance, - } - return kusto_settings diff --git a/msticpy/data/drivers/local_data_driver.py b/msticpy/data/drivers/local_data_driver.py index 5c29a766b..458d1b9f1 100644 --- a/msticpy/data/drivers/local_data_driver.py +++ b/msticpy/data/drivers/local_data_driver.py @@ -4,8 +4,9 @@ # license information. 
# -------------------------------------------------------------------------- """Local Data Driver class - for testing and demos.""" + from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any import pandas as pd @@ -38,19 +39,19 @@ def __init__(self, connection_str: str = None, **kwargs): self._debug = kwargs.get("debug", False) super().__init__(**kwargs) - self._paths: List[str] = ["."] + self._paths: list[str] = ["."] if data_paths := kwargs.get("data_paths"): self._paths = [path.strip() for path in data_paths] elif has_config("DataProviders.LocalData"): self._paths = get_config("LocalData.data_paths", self._paths) - self.data_files: Dict[str, str] = self._get_data_paths() - self._schema: Dict[str, Any] = {} + self.data_files: dict[str, str] = self._get_data_paths() + self._schema: dict[str, Any] = {} self._loaded = True self._connected = True self.current_connection = "; ".join(self._paths) - def _get_data_paths(self) -> Dict[str, str]: + def _get_data_paths(self) -> dict[str, str]: """Read files in data paths.""" data_files = {} for path in self._paths: @@ -65,7 +66,7 @@ def _get_data_paths(self) -> Dict[str, str]: ) return data_files - def connect(self, connection_str: Optional[str] = None, **kwargs): + def connect(self, connection_str: str | None = None, **kwargs): """ Connect to data source. @@ -80,7 +81,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): print("Connected.") @property - def schema(self) -> Dict[str, Dict]: + def schema(self) -> dict[str, dict]: """ Return current data schema of connection. @@ -105,7 +106,7 @@ def schema(self) -> Dict[str, Dict]: def query( self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + ) -> pd.DataFrame | Any: """ Execute query string and return DataFrame of results. @@ -127,9 +128,7 @@ def query( query_name = query_source.name if query_source else query file_path = self.data_files.get(query.casefold()) if not file_path: - raise FileNotFoundError( - f"Data file ({query}) for query {query_name} not found." - ) + raise FileNotFoundError(f"Data file ({query}) for query {query_name} not found.") if file_path.endswith("csv"): try: return pd.read_csv(file_path, parse_dates=["TimeGenerated"]) diff --git a/msticpy/data/drivers/local_osquery_driver.py b/msticpy/data/drivers/local_osquery_driver.py index 6805be51c..6ee073445 100644 --- a/msticpy/data/drivers/local_osquery_driver.py +++ b/msticpy/data/drivers/local_osquery_driver.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Local Osquery Data Driver class - osquery.{results,snapshots}.log.""" + import json import logging @@ -13,7 +14,7 @@ import re from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any import pandas as pd from pandas import to_datetime, to_numeric @@ -64,8 +65,8 @@ def __init__(self, connection_str: str = None, **kwargs): del connection_str self._debug = kwargs.get("debug", False) super().__init__() - self._cache_file: Optional[str] = None - self._paths: List[str] = ["."] + self._cache_file: str | None = None + self._paths: list[str] = ["."] # If data paths specified, use these # from kwargs or settings if data_paths := kwargs.get("data_paths"): @@ -81,17 +82,17 @@ def __init__(self, connection_str: str = None, **kwargs): logger.info("data paths read from config %s", str(self._paths)) self._progress = kwargs.pop("progress", True) - self.data_files: Dict[str, str] = self._get_logfile_paths() - self._schema: Dict[str, Any] = {} - self._data_cache: Dict[str, pd.DataFrame] = {} - self._query_map: Dict[str, List[str]] + self.data_files: dict[str, str] = self._get_logfile_paths() + self._schema: dict[str, Any] = {} + self._data_cache: dict[str, pd.DataFrame] = {} + self._query_map: dict[str, list[str]] self._cache_file = kwargs.pop("cache_file", self._cache_file) self._loaded = True self.has_driver_queries = True logger.info("data files to read %s", ",".join(self.data_files.values())) logger.info("cache file %s", self._cache_file) - def _get_logfile_paths(self) -> Dict[str, str]: + def _get_logfile_paths(self) -> dict[str, str]: """Read files in data paths.""" data_files = {} for input_path in (Path(path_str) for path_str in self._paths): @@ -110,7 +111,7 @@ def _get_logfile_paths(self) -> Dict[str, str]: ) return data_files - def connect(self, connection_str: Optional[str] = None, **kwargs): + def connect(self, connection_str: str | None = None, **kwargs): """ Connect to data source. @@ -127,7 +128,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): # pylint: disable=too-many-branches @property - def schema(self) -> Dict[str, Dict]: + def schema(self) -> dict[str, dict]: """ Return current data schema of connection. @@ -148,7 +149,7 @@ def schema(self) -> Dict[str, Dict]: def query( self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + ) -> pd.DataFrame | Any: """ Execute query string and return DataFrame of results. @@ -179,9 +180,7 @@ def query( df_names = self._query_map[query_name] query_df = pd.concat([self._data_cache[df] for df in df_names]) for date_column in self.OS_QUERY_DATEIME_COLS & set(query_df.columns): - query_df[date_column] = to_datetime( - query_df[date_column], unit="s", origin="unix" - ) + query_df[date_column] = to_datetime(query_df[date_column], unit="s", origin="unix") logger.info("Query %s, returned %d rows", query_name, len(query_df)) return query_df @@ -190,7 +189,7 @@ def query_with_results(self, query, **kwargs): return self.query(query, **kwargs), "OK" @property - def driver_queries(self) -> List[Dict[str, Any]]: + def driver_queries(self) -> list[dict[str, Any]]: """ Return dynamic queries available on connection to data. @@ -234,9 +233,7 @@ def _read_data_files(self): # Otherwise read in the data files. 
data_files = ( - tqdm(self.data_files.values()) - if self._progress - else self.data_files.values() + tqdm(self.data_files.values()) if self._progress else self.data_files.values() ) for log_file in data_files: self._read_log_file(log_file) @@ -284,16 +281,14 @@ def _read_log_file(self, log_path: str): # Likely resource intensive and better way to do. # Likely issue, multiple log files can contain same query mostly # because of log rotation - list_lines: List[Dict[str, Any]] = [] + list_lines: list[dict[str, Any]] = [] try: - with open(log_path, mode="r", encoding="utf-8") as logfile: + with open(log_path, encoding="utf-8") as logfile: json_lines = logfile.readlines() list_lines = [json.loads(line) for line in json_lines] - except (IOError, json.JSONDecodeError, ValueError) as exc: - raise MsticpyDataQueryError( - f"Read error on file {log_path}: {exc}." - ) from exc + except (OSError, json.JSONDecodeError, ValueError) as exc: + raise MsticpyDataQueryError(f"Read error on file {log_path}: {exc}.") from exc if not list_lines: raise MsticpyNoDataSourceError( f"No log data retrieved from {log_path}", @@ -302,9 +297,7 @@ def _read_log_file(self, log_path: str): logger.info("log %s read, %d lines read", log_path, len(list_lines)) df_all_queries = pd.json_normalize(list_lines, max_level=3) # Don't want dot in columns name - df_all_queries.columns = df_all_queries.columns.str.replace( - ".", "_", regex=False - ) + df_all_queries.columns = df_all_queries.columns.str.replace(".", "_", regex=False) for event_name in df_all_queries["name"].unique().tolist(): combined_dfs = [] @@ -346,11 +339,7 @@ def _rename_columns(data: pd.DataFrame): for prefix in _PREFIXES: source_cols = data.filter(regex=f"{prefix}.*").columns rename_cols.update( - { - col: col.replace(prefix, "") - for col in source_cols - if isinstance(col, str) - } + {col: col.replace(prefix, "") for col in source_cols if isinstance(col, str)} ) rename_cols = { col: ren_col if ren_col not in df_cols else f"{ren_col}_" diff --git a/msticpy/data/drivers/local_velociraptor_driver.py b/msticpy/data/drivers/local_velociraptor_driver.py index 7d91dc213..4803ec97a 100644 --- a/msticpy/data/drivers/local_velociraptor_driver.py +++ b/msticpy/data/drivers/local_velociraptor_driver.py @@ -4,11 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Local Velociraptor Data Driver class.""" + import logging from collections import defaultdict from functools import lru_cache from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any import pandas as pd from tqdm.auto import tqdm @@ -25,8 +26,7 @@ logger = logging.getLogger(__name__) _VELOCIRATOR_DOC_URL = ( - "https://msticpy.readthedocs.io/en/latest/data_acquisition/" - "DataProv-Velociraptor.html" + "https://msticpy.readthedocs.io/en/latest/data_acquisition/DataProv-Velociraptor.html" ) @@ -35,7 +35,7 @@ class VelociraptorLogDriver(DriverBase): """Velociraptor driver class to ingest log data.""" - def __init__(self, connection_str: Optional[str] = None, **kwargs): + def __init__(self, connection_str: str | None = None, **kwargs): """ Instantiate Velociraptor driver and optionally connect. 
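# --- Editor's illustrative sketch (not part of the patch) -----------------
# The _read_log_file hunks above read an osquery log (one JSON object per
# line), normalize nested fields with pd.json_normalize, and replace dots in
# the resulting column names. A minimal sketch of that flow; the sample
# records below are hypothetical.
import io
import json

import pandas as pd

_SAMPLE_LOG = io.StringIO(
    '{"name": "proc_events", "columns": {"pid": 42, "path": "/bin/ls"}}\n'
    '{"name": "proc_events", "columns": {"pid": 43, "path": "/bin/cat"}}\n'
)

# One JSON document per line, as in the driver's logfile.readlines() loop.
list_lines = [json.loads(line) for line in _SAMPLE_LOG.readlines()]
df_all = pd.json_normalize(list_lines, max_level=3)
# pandas flattens nested keys as "columns.pid"; strip dots as the driver does.
df_all.columns = df_all.columns.str.replace(".", "_", regex=False)
print(df_all[["name", "columns_pid", "columns_path"]])
# ---------------------------------------------------------------------------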
@@ -54,7 +54,7 @@ def __init__(self, connection_str: Optional[str] = None, **kwargs): logger.setLevel(logging.DEBUG) super().__init__() - self._paths: List[str] = ["."] + self._paths: list[str] = ["."] # If data paths specified, use these # from kwargs or settings if data_paths := kwargs.get("data_paths"): @@ -68,15 +68,15 @@ def __init__(self, connection_str: Optional[str] = None, **kwargs): self._paths = prov_settings.args.get("data_paths", []) or self._paths logger.info("data paths read from config %s", str(self._paths)) - self.data_files: Dict[str, List[Path]] = {} - self._schema: Dict[str, Any] = {} - self._query_map: Dict[str, List[str]] + self.data_files: dict[str, list[Path]] = {} + self._schema: dict[str, Any] = {} + self._query_map: dict[str, list[str]] self._progress = kwargs.pop("progress", True) self._loaded = True self.has_driver_queries = True logger.info("data files to read %s", ",".join(self.data_files)) - def connect(self, connection_str: Optional[str] = None, **kwargs): + def connect(self, connection_str: str | None = None, **kwargs): """ Connect to data source. @@ -91,7 +91,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): self._connected = True @property - def schema(self) -> Dict[str, Dict]: + def schema(self) -> dict[str, dict]: """ Return current data schema of connection. @@ -106,9 +106,7 @@ def schema(self) -> Dict[str, Dict]: self.connect() # read the first row of each file to get the schema iter_data_files = ( - tqdm(self.data_files.items()) - if self._progress - else self.data_files.items() + tqdm(self.data_files.items()) if self._progress else self.data_files.items() ) for table, files in iter_data_files: if not files: @@ -123,8 +121,8 @@ def schema(self) -> Dict[str, Dict]: return self._schema def query( - self, query: str, query_source: Optional[QuerySource] = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + self, query: str, query_source: QuerySource | None = None, **kwargs + ) -> pd.DataFrame | Any: """ Execute query string and return DataFrame of results. @@ -154,7 +152,7 @@ def query( ) return self._cached_query(query) - @lru_cache(maxsize=256) + @lru_cache(maxsize=256) # noqa: B019 def _cached_query(self, query: str) -> pd.DataFrame: iter_data_files = ( tqdm(self.data_files[query]) if self._progress else self.data_files[query] @@ -170,7 +168,7 @@ def query_with_results(self, query, **kwargs): return self.query(query, **kwargs), "OK" @property - def driver_queries(self) -> List[Dict[str, Any]]: + def driver_queries(self) -> list[dict[str, Any]]: """ Return dynamic queries available on connection to data. 
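The `# noqa: B019` added above acknowledges a known trade-off: `functools.lru_cache` on a method keys the cache on `self`, so cached entries keep the driver instance alive for the life of the cache. For a long-lived, effectively-singleton driver that is acceptable. A minimal illustration of the pattern:

    from functools import lru_cache


    class Driver:
        """Illustrative stand-in for the driver's cached query method."""

        @lru_cache(maxsize=256)  # noqa: B019 - cache entries pin `self`
        def run_query(self, query: str) -> str:
            # In the real driver this is the expensive part: reading and
            # concatenating the JSON result files for the named query.
            return query.upper()


    drv = Driver()
    first = drv.run_query("processes")
    assert drv.run_query("processes") is first  # second call served from cache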
@@ -200,18 +198,15 @@ def driver_queries(self) -> List[Dict[str, Any]]: ] return [] - def _get_logfile_paths(self) -> Dict[str, List[Path]]: + def _get_logfile_paths(self) -> dict[str, list[Path]]: """Read files in data paths.""" - data_files: Dict[str, List[Path]] = defaultdict(list) + data_files: dict[str, list[Path]] = defaultdict(list) for input_path in (Path(path_str) for path_str in self._paths): - files = { - file.relative_to(input_path): file - for file in input_path.rglob("*.json") - } + files = {file.relative_to(input_path): file for file in input_path.rglob("*.json")} file_names = [valid_pyname(str(file.with_suffix(""))) for file in files] - path_files = dict(zip(file_names, files.values())) + path_files = dict(zip(file_names, files.values(), strict=False)) for file_name, file_path in path_files.items(): data_files[file_name].append(file_path) diff --git a/msticpy/data/drivers/mdatp_driver.py b/msticpy/data/drivers/mdatp_driver.py index 1ad175386..80a6faee7 100644 --- a/msticpy/data/drivers/mdatp_driver.py +++ b/msticpy/data/drivers/mdatp_driver.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """MS Defender/Defender 365 OData Driver class.""" + from __future__ import annotations import logging @@ -120,9 +121,7 @@ def __init__( else: logger.debug("Using cloud from configuration: %s", self.cloud) - logger.info( - "Selecting API configuration for environment: %s", self.data_environment - ) + logger.info("Selecting API configuration for environment: %s", self.data_environment) m365d_params: M365DConfiguration = _select_api( self.data_environment, self.cloud, @@ -296,7 +295,7 @@ def _select_api(data_environment: DataEnvironment, cloud: str) -> M365DConfigura "Please use Microsoft Graph Security Hunting API instead - " "provider name = 'M365DGraph'." ) - warnings.warn(warn_message, DeprecationWarning) + warnings.warn(warn_message, DeprecationWarning, stacklevel=2) # MDE Advanced Queries API logger.info("Using MDE Advanced Queries API (default)") diff --git a/msticpy/data/drivers/mordor_driver.py b/msticpy/data/drivers/mordor_driver.py index 4c02e42c1..de5a5aeb3 100644 --- a/msticpy/data/drivers/mordor_driver.py +++ b/msticpy/data/drivers/mordor_driver.py @@ -4,13 +4,15 @@ # license information. 
# -------------------------------------------------------------------------- """Mordor/OTRF Security datasets driver.""" + import json import pickle # nosec import zipfile from collections import defaultdict +from collections.abc import Generator, Iterable from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union +from typing import Any from zipfile import BadZipFile, ZipFile import attr @@ -42,8 +44,8 @@ _MTR_TECH_CAT_URI = "https://attack.mitre.org/techniques/{cat}/" # pylint: disable=invalid-name -MITRE_TECHNIQUES: Optional[pd.DataFrame] = None -MITRE_TACTICS: Optional[pd.DataFrame] = None +MITRE_TECHNIQUES: pd.DataFrame | None = None +MITRE_TACTICS: pd.DataFrame | None = None _MITRE_TECH_CACHE = "mitre_tech_cache.pkl" _MITRE_TACTICS_CACHE = "mitre_tact_cache.pkl" @@ -63,18 +65,14 @@ def __init__(self, **kwargs): self.has_driver_queries = True self.mitre_techniques: pd.DataFrame self.mitre_tactics: pd.DataFrame - self.mordor_data: Dict[str, MordorEntry] - self.mdr_idx_tech: Dict[str, Set[str]] - self.mdr_idx_tact: Dict[str, Set[str]] - self._driver_queries: List[Dict[str, Any]] = [] + self.mordor_data: dict[str, MordorEntry] + self.mdr_idx_tech: dict[str, set[str]] + self.mdr_idx_tact: dict[str, set[str]] + self._driver_queries: list[dict[str, Any]] = [] mdr_settings = get_config("DataProviders.Mordor", {}) - self.use_cached = kwargs.pop( - "used_cached", mdr_settings.get("used_cached", True) - ) - self.save_folder = kwargs.pop( - "save_folder", mdr_settings.get("save_folder", ".") - ) + self.use_cached = kwargs.pop("used_cached", mdr_settings.get("used_cached", True)) + self.save_folder = kwargs.pop("save_folder", mdr_settings.get("save_folder", ".")) self.save_folder = _resolve_cache_folder(self.save_folder) self.silent = kwargs.pop("silent", False) @@ -82,7 +80,7 @@ def __init__(self, **kwargs): # pylint: disable=global-statement - def connect(self, connection_str: Optional[str] = None, **kwargs): + def connect(self, connection_str: str | None = None, **kwargs): """ Connect to data source. @@ -97,9 +95,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): print("Retrieving Mitre data...") if MITRE_TECHNIQUES is None or MITRE_TACTICS is None: - MITRE_TECHNIQUES, MITRE_TACTICS = _get_mitre_categories( - cache_folder=cache_folder - ) + MITRE_TECHNIQUES, MITRE_TACTICS = _get_mitre_categories(cache_folder=cache_folder) self.mitre_techniques = MITRE_TECHNIQUES self.mitre_tactics = MITRE_TACTICS @@ -123,7 +119,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): def query( self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + ) -> pd.DataFrame | Any: """ Execute query string and return DataFrame of results. @@ -164,7 +160,7 @@ def query( return "Could not convert result to a DataFrame." return result_df - def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: + def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, Any]: """ Execute query string and return DataFrame plus native results. @@ -185,7 +181,7 @@ def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: return pd.DataFrame(), result @property - def driver_queries(self) -> Iterable[Dict[str, Any]]: + def driver_queries(self) -> Iterable[dict[str, Any]]: """ Return generator of Mordor query definitions. 
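As elsewhere in this change, the imports above move `Generator`/`Iterable` to `collections.abc` and replace `typing.Dict/List/Optional/Union/Tuple/Set` with PEP 585 builtin generics and PEP 604 `X | Y` unions. Builtin generics need Python 3.9+ and `|` unions 3.10+ when evaluated at runtime, which is why several modules in this diff also add `from __future__ import annotations`. A before/after sketch:

    from __future__ import annotations  # hints stay as unevaluated strings

    from collections.abc import Iterable


    # Before: from typing import Dict, List, Optional
    #         def load(paths: Optional[List[str]] = None) -> Dict[str, int]: ...
    def load(paths: list[str] | None = None) -> dict[str, int]:
        """PEP 585 builtin generics plus a PEP 604 union."""
        return {path: len(path) for path in paths or ["."]}


    def first_item(items: Iterable[str]) -> str | None:
        """Iterable/Generator now come from collections.abc, not typing."""
        return next(iter(items), None)


    # On 3.10+ unions also work directly in isinstance checks, as the
    # entity hunks later in this diff use: isinstance(value, list | dict).
    assert load() == {".": 1} and first_item([]) is None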
@@ -211,12 +207,10 @@ def _get_driver_queries(self): ) tactics = ", ".join( - f"{tac[0]}: {tac[1]}" - for att in mitre_data - for tac in att.tactics_full + f"{tac[0]}: {tac[1]}" for att in mitre_data for tac in att.tactics_full ) - doc_string: List[str] = [ + doc_string: list[str] = [ f"{mdr_item.title}", "", "Notes", @@ -262,9 +256,7 @@ def search_queries(self, search: str) -> Iterable[str]: matches = [] for mdr_id in search_mdr_data(self.mordor_data, terms=search): for file_path in self.mordor_data[mdr_id].get_file_paths(): - matches.append( - f"{file_path['qry_path']} ({self.mordor_data[mdr_id].title})" - ) + matches.append(f"{file_path['qry_path']} ({self.mordor_data[mdr_id].title})") return matches @@ -287,10 +279,10 @@ class MitreAttack: def __init__( self, - attack: Dict[str, Any] = None, + attack: dict[str, Any] = None, technique: str = None, sub_technique: str = None, - tactics: List[str] = None, + tactics: list[str] = None, ): """ Create instance of MitreAttack. @@ -308,17 +300,15 @@ def __init__( """ if attack is None and (technique is None and tactics is None): - raise TypeError( - "Either 'attack' or 'technique' and 'tactics' must be specified." - ) + raise TypeError("Either 'attack' or 'technique' and 'tactics' must be specified.") self.technique = attack.get("technique") if attack else technique self.sub_technique = attack.get("sub-technique") if attack else sub_technique - self.tactics = attack.get("tactics") if attack else tactics # type: ignore + self.tactics = attack.get("tactics") if attack else tactics self._technique_name = None self._technique_desc = None self._technique_uri = None - self._tactics_full: List[Tuple[str, str, str, str]] = [] + self._tactics_full: list[tuple[str, str, str, str]] = [] def __repr__(self) -> str: """ @@ -339,7 +329,7 @@ def __repr__(self) -> str: ) @property - def technique_name(self) -> Optional[str]: + def technique_name(self) -> str | None: """ Return Mitre Technique full name. @@ -350,8 +340,7 @@ def technique_name(self) -> Optional[str]: """ if ( - not self._technique_name - and self.technique in MITRE_TECHNIQUES.index # type: ignore[union-attr] + not self._technique_name and self.technique in MITRE_TECHNIQUES.index # type: ignore[union-attr] ): self._technique_name = MITRE_TECHNIQUES.loc[ # type: ignore[union-attr] self.technique @@ -359,7 +348,7 @@ def technique_name(self) -> Optional[str]: return self._technique_name @property - def technique_desc(self) -> Optional[str]: + def technique_desc(self) -> str | None: """ Return Mitre technique description. @@ -370,8 +359,7 @@ def technique_desc(self) -> Optional[str]: """ if ( - not self._technique_desc - and self.technique in MITRE_TECHNIQUES.index # type: ignore[union-attr] + not self._technique_desc and self.technique in MITRE_TECHNIQUES.index # type: ignore[union-attr] ): self._technique_desc = MITRE_TECHNIQUES.loc[ # type: ignore self.technique @@ -392,7 +380,7 @@ def technique_uri(self) -> str: return self.MTR_TECH_URI.format(technique_id=self.technique) @property - def tactics_full(self) -> List[Tuple[str, str, str, str]]: + def tactics_full(self) -> list[tuple[str, str, str, str]]: """ Return full listing of Mitre tactics. 
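The reworked `MitreAttack` constructor above accepts either a raw `attack` mapping (the form stored in the dataset metadata) or explicit keyword fields. A usage sketch, assuming msticpy is installed; the technique and tactic IDs are purely illustrative:

    from msticpy.data.drivers.mordor_driver import MitreAttack

    # From a raw metadata mapping, as MordorEntry.get_attacks() does:
    attack = MitreAttack(
        attack={"technique": "T1059", "sub-technique": "001", "tactics": ["TA0002"]}
    )
    # Or from explicit keyword fields:
    same = MitreAttack(technique="T1059", tactics=["TA0002"])
    assert attack.technique == same.technique == "T1059"
    # Passing neither form raises TypeError, per the guard above.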
@@ -410,9 +398,7 @@ def tactics_full(self) -> List[Tuple[str, str, str, str]]: tactic_name = MITRE_TACTICS.loc[tactic].Name # type: ignore[union-attr] tactic_desc = MITRE_TACTICS.loc[tactic].Description # type: ignore[union-attr] tactic_uri = self.MTR_TAC_URI.format(tactic_id=tactic) - self._tactics_full.append( - (tactic, tactic_name, tactic_desc, tactic_uri) - ) + self._tactics_full.append((tactic, tactic_name, tactic_desc, tactic_uri)) return self._tactics_full @@ -452,20 +438,20 @@ class MordorEntry: type: str creation_date: datetime = attr.ib(converter=_to_datetime) modification_date: datetime = attr.ib(converter=_to_datetime) - contributors: List[str] = attr.Factory(list) - author: Optional[str] = None - platform: Optional[str] = None - description: Optional[str] = None - tags: List[str] = attr.Factory(list) - files: List[Dict[str, Any]] = attr.Factory(list) - datasets: List[Dict[str, Any]] = attr.Factory(list) - attack_mappings: List[Dict[str, Any]] = attr.Factory(list) - notebooks: List[Dict[str, str]] = attr.Factory(list) - simulation: Dict[str, Any] = attr.Factory(dict) - references: List[Any] = attr.Factory(list) - _rel_file_paths: List[Dict[str, Any]] = attr.Factory(list) - - def get_notebooks(self) -> List[Tuple[str, str, str]]: + contributors: list[str] = attr.Factory(list) + author: str | None = None + platform: str | None = None + description: str | None = None + tags: list[str] = attr.Factory(list) + files: list[dict[str, Any]] = attr.Factory(list) + datasets: list[dict[str, Any]] = attr.Factory(list) + attack_mappings: list[dict[str, Any]] = attr.Factory(list) + notebooks: list[dict[str, str]] = attr.Factory(list) + simulation: dict[str, Any] = attr.Factory(dict) + references: list[Any] = attr.Factory(list) + _rel_file_paths: list[dict[str, Any]] = attr.Factory(list) + + def get_notebooks(self) -> list[tuple[str, str, str]]: """ Return the list of notebooks for the dataset. @@ -480,7 +466,7 @@ def get_notebooks(self) -> List[Tuple[str, str, str]]: for nbk in self.notebooks ] - def get_attacks(self) -> List[MitreAttack]: + def get_attacks(self) -> list[MitreAttack]: """ Return list of Mitre attack classifications. @@ -492,7 +478,7 @@ def get_attacks(self) -> List[MitreAttack]: """ return [MitreAttack(attack=attack) for attack in self.attack_mappings] - def get_file_paths(self) -> List[Dict[str, str]]: + def get_file_paths(self) -> list[dict[str, str]]: """ Return list of data file links. @@ -584,9 +570,9 @@ def _get_mdr_file(gh_file): def _create_mdr_metadata_cache(): - md_metadata: Dict[str, MordorEntry] = {} + md_metadata: dict[str, MordorEntry] = {} - def _get_mdr_metadata(cache_folder: Optional[str] = None): + def _get_mdr_metadata(cache_folder: str | None = None): nonlocal md_metadata if not md_metadata: md_metadata = _fetch_mdr_metadata(cache_folder=cache_folder) @@ -603,7 +589,7 @@ def _get_mdr_metadata(cache_folder: Optional[str] = None): # pylint: disable=global-statement -def _fetch_mdr_metadata(cache_folder: Optional[str] = None) -> Dict[str, MordorEntry]: +def _fetch_mdr_metadata(cache_folder: str | None = None) -> dict[str, MordorEntry]: """ Return full metadata for Mordor datasets. 
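`MordorEntry` above is an attrs class: `attr.ib(converter=...)` coerces incoming values and `attr.Factory(list)` gives each instance its own fresh default rather than a shared mutable one. A trimmed-down, hypothetical analogue of that pattern:

    from datetime import datetime

    import attr


    def _to_datetime(value):
        """Accept a datetime or an ISO-format string (the converter pattern above)."""
        return value if isinstance(value, datetime) else datetime.fromisoformat(value)


    @attr.s(auto_attribs=True)
    class DataSetEntry:
        """Hypothetical, trimmed-down analogue of MordorEntry."""

        title: str
        creation_date: datetime = attr.ib(converter=_to_datetime)
        tags: list[str] = attr.Factory(list)


    entry = DataSetEntry("demo", "2020-01-01T00:00:00")
    assert isinstance(entry.creation_date, datetime)
    entry.tags.append("lab")
    assert DataSetEntry("other", "2020-01-01T00:00:00").tags == []  # no shared list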
@@ -622,19 +608,15 @@ def _fetch_mdr_metadata(cache_folder: Optional[str] = None) -> Dict[str, MordorE if MITRE_TECHNIQUES is None or MITRE_TACTICS is None: MITRE_TECHNIQUES, MITRE_TACTICS = _get_mitre_categories() - md_metadata: Dict[str, MordorEntry] = {} + md_metadata: dict[str, MordorEntry] = {} md_cached_metadata = _read_mordor_cache(cache_folder) mdr_md_paths = list(get_mdr_data_paths("metadata")) - for filename in tqdm( - mdr_md_paths, unit=" files", desc="Downloading Mordor metadata" - ): + for filename in tqdm(mdr_md_paths, unit=" files", desc="Downloading Mordor metadata"): cache_valid = False if filename in md_cached_metadata: metadata_doc = md_cached_metadata[filename] - last_timestamp = pd.Timestamp( - metadata_doc.get(_LAST_UPDATE_KEY, _DEFAULT_TS) - ) + last_timestamp = pd.Timestamp(metadata_doc.get(_LAST_UPDATE_KEY, _DEFAULT_TS)) cache_valid = (pd.Timestamp.now(tz=timezone.utc) - last_timestamp).days < 30 if not cache_valid: @@ -643,9 +625,7 @@ def _fetch_mdr_metadata(cache_folder: Optional[str] = None) -> Dict[str, MordorE metadata_doc = yaml.safe_load(gh_file_content) except yaml.error.YAMLError: continue - metadata_doc[_LAST_UPDATE_KEY] = pd.Timestamp.now( - tz=timezone.utc - ).isoformat() + metadata_doc[_LAST_UPDATE_KEY] = pd.Timestamp.now(tz=timezone.utc).isoformat() md_cached_metadata[filename] = metadata_doc doc_id = metadata_doc.get("id") mdr_entry = metadata_doc.copy() @@ -659,9 +639,9 @@ def _fetch_mdr_metadata(cache_folder: Optional[str] = None) -> Dict[str, MordorE # pylint: enable=global-statement -def _read_mordor_cache(cache_folder) -> Dict[str, Any]: +def _read_mordor_cache(cache_folder) -> dict[str, Any]: """Return dictionary of cached metadata if cached_folder is a valid path.""" - md_cached_metadata: Dict[str, Any] = {} + md_cached_metadata: dict[str, Any] = {} mordor_cache = Path(cache_folder).joinpath(_MORDOR_CACHE) if _valid_cache(mordor_cache): try: @@ -680,8 +660,8 @@ def _write_mordor_cache(md_cached_metadata, cache_folder): def _build_mdr_indexes( - mdr_metadata: Dict[str, MordorEntry], -) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]: + mdr_metadata: dict[str, MordorEntry], +) -> tuple[dict[str, set[str]], dict[str, set[str]]]: """ Return dictionaries mapping Mitre items to Mordor datasets. @@ -818,8 +798,8 @@ def _extract_zip_file_to_df( def search_mdr_data( - mdr_data: Dict[str, MordorEntry], terms: str = None, subset: Iterable[str] = None -) -> Set[str]: + mdr_data: dict[str, MordorEntry], terms: str = None, subset: Iterable[str] = None +) -> set[str]: """ Return IDs for items matching terms. @@ -850,7 +830,7 @@ def search_mdr_data( logic = "AND" else: search_terms = [terms] - results: Set[str] = set() + results: set[str] = set() for search_idx, term in enumerate(search_terms): item_results = set() for md_id, item in mdr_data.items(): @@ -895,8 +875,8 @@ def _reshape_mitre_df(data): def _get_mitre_categories( - cache_folder: Optional[str] = None, -) -> Tuple[pd.DataFrame, pd.DataFrame]: + cache_folder: str | None = None, +) -> tuple[pd.DataFrame, pd.DataFrame]: """ Download and return Mitre techniques and tactics. diff --git a/msticpy/data/drivers/odata_driver.py b/msticpy/data/drivers/odata_driver.py index d4aa1b6da..89e729c79 100644 --- a/msticpy/data/drivers/odata_driver.py +++ b/msticpy/data/drivers/odata_driver.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """OData Driver class.""" + from __future__ import annotations import abc @@ -178,9 +179,7 @@ def connect( setting for setting in ("tenant_id", "client_id") if setting not in cs_dict ] auth_present: bool = ( - "username" in cs_dict - or "client_secret" in cs_dict - or "certificate" in cs_dict + "username" in cs_dict or "client_secret" in cs_dict or "certificate" in cs_dict ) if missing_settings: logger.error("Missing required connection parameters: %s", missing_settings) @@ -335,9 +334,7 @@ def _get_token_standard_auth( "Token acquisition failed: %s", json_response.get("error_description", "Unknown error"), ) - err_msg = ( - f"Could not obtain access token - {json_response['error_description']}" - ) + err_msg = f"Could not obtain access token - {json_response['error_description']}" raise MsticpyConnectionError(err_msg) logger.info("Successfully obtained access token via client secret") @@ -360,9 +357,7 @@ def _get_token_delegate_auth( ) logger.debug("Authority URL: %s", authority) logger.debug("Scopes: %s", self.scopes) - logger.info( - "Initializing MSAL delegated auth for user: %s", cs_dict["username"] - ) + logger.info("Initializing MSAL delegated auth for user: %s", cs_dict["username"]) self.msal_auth = MSALDelegatedAuth( client_id=cs_dict["client_id"], @@ -466,9 +461,7 @@ def _check_response_errors(response: httpx.Response) -> None: logger.warning("Response error: %s", response.json()["error"]["message"]) if response.status_code == httpx.codes.UNAUTHORIZED: logger.error("Authentication failed - status code 401") - err_msg: str = ( - "Authentication failed - possible timeout. Please re-connect." - ) + err_msg: str = "Authentication failed - possible timeout. Please re-connect." 
raise ConnectionRefusedError(err_msg) # Raise an exception to handle hitting API limits if response.status_code == httpx.codes.TOO_MANY_REQUESTS: @@ -540,10 +533,10 @@ def _map_config_dict_name(config_dict: dict[str, str]) -> dict[str, str]: """Map configuration parameter names to expected values.""" logger.debug("Mapping configuration dictionary names") mapped_dict: dict[str, str] = config_dict.copy() - for provided_name in config_dict: + for provided_name, mapped_name in config_dict.items(): for req_name, alternates in _CONFIG_NAME_MAP.items(): if provided_name.casefold() in alternates: - mapped_dict[req_name] = config_dict[provided_name] + mapped_dict[req_name] = mapped_name logger.debug("Mapped '%s' to '%s'", provided_name, req_name) break return mapped_dict @@ -555,13 +548,9 @@ def _get_driver_settings( instance: str | None = None, ) -> dict[str, str]: """Try to retrieve config settings for OAuth drivers.""" - logger.debug( - "Getting driver settings for: %s (instance: %s)", config_name, instance - ) + logger.debug("Getting driver settings for: %s (instance: %s)", config_name, instance) config_key: str = ( - f"{config_name}-{instance}" - if instance and instance != "Default" - else config_name + f"{config_name}-{instance}" if instance and instance != "Default" else config_name ) drv_config: ProviderSettings | None = get_provider_settings("DataProviders").get( config_key, diff --git a/msticpy/data/drivers/prismacloud_driver.py b/msticpy/data/drivers/prismacloud_driver.py index 1ffcc5fe7..1a9154784 100644 --- a/msticpy/data/drivers/prismacloud_driver.py +++ b/msticpy/data/drivers/prismacloud_driver.py @@ -9,15 +9,18 @@ __author__ = "Rajamani R" +import json import logging from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast -import json + import httpx import pandas as pd + from msticpy.common.exceptions import MsticpyConnectionError, MsticpyUserError -from .driver_base import DriverBase -from ..core.query_store import QuerySource, QueryStore + from ...common.provider_settings import get_provider_settings +from ..core.query_store import QuerySource, QueryStore +from .driver_base import DriverBase if TYPE_CHECKING: from collections.abc import Callable @@ -157,9 +160,7 @@ class PrismaCloudDriver(DriverBase): # pylint: disable=R0902 CONFIG_NAME: ClassVar[str] = "Prismacloud" - def __init__( - self, **kwargs: DriverConfig - ) -> None: # pylint: disable=too-many-locals + def __init__(self, **kwargs: DriverConfig) -> None: # pylint: disable=too-many-locals """ Initialize the Prisma Cloud Driver and set up the HTTP client. @@ -202,9 +203,7 @@ def __init__( # preference 1 as argument , preference 2 from config file , third default value if not kwargs.get("base_url"): - self.base_url = ( - cast(str, self.config.get("base_url")) or BASE_URL_API - ) # type: ignore[assignment] + self.base_url = cast(str, self.config.get("base_url")) or BASE_URL_API else: self.base_url = kwargs.get("base_url", BASE_URL_API) # type: ignore[assignment] self.debug: bool = bool(kwargs.get("debug", False)) @@ -244,9 +243,7 @@ def __init__( self.queries_loaded: bool = False @staticmethod - def _get_driver_settings( - config_name: str, instance: str | None = None - ) -> dict[str, str]: + def _get_driver_settings(config_name: str, instance: str | None = None) -> dict[str, str]: """ Retrieve Prisma Cloud settings from MSTICPY configuration. 
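The `base_url` branches above implement a simple precedence chain: explicit argument, then configuration file, then the built-in default. The same logic as a single expression (the default URL here is a stand-in for the real `BASE_URL_API` constant):

    DEFAULT_BASE_URL = "https://api.example.io"  # stand-in for BASE_URL_API


    def resolve_base_url(kwargs: dict, config: dict) -> str:
        """Explicit argument beats config file beats built-in default."""
        return kwargs.get("base_url") or config.get("base_url") or DEFAULT_BASE_URL


    assert resolve_base_url({"base_url": "https://a"}, {"base_url": "https://b"}) == "https://a"
    assert resolve_base_url({}, {"base_url": "https://b"}) == "https://b"
    assert resolve_base_url({}, {}) == DEFAULT_BASE_URL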
@@ -339,9 +336,7 @@ def connect( # type: ignore[override] if connection_str: username = username or connection_str.split(":")[0] password = ( - password or connection_str.split(":")[1] - if ":" in connection_str - else None + password or connection_str.split(":")[1] if ":" in connection_str else None ) if not username or not password: username = self.config.get("username") @@ -368,9 +363,7 @@ def connect( # type: ignore[override] self._loaded = True logger.info("Prisma Cloud connection successful") if "X-Redlock-Auth" not in self.client.headers: - logger.debug( - "X-Redlock-Auth not in self.client.headers did not match" - ) + logger.debug("X-Redlock-Auth not in self.client.headers did not match") return self logger.error("Login failed: %s", result.get("message", "Unknown error")) msg = f"Login failed: {result.get('message', 'Unknown error')}" diff --git a/msticpy/data/drivers/resource_graph_driver.py b/msticpy/data/drivers/resource_graph_driver.py index dfb6c3ad5..82abdb051 100644 --- a/msticpy/data/drivers/resource_graph_driver.py +++ b/msticpy/data/drivers/resource_graph_driver.py @@ -4,8 +4,11 @@ # license information. # -------------------------------------------------------------------------- """Azure Resource Graph Driver class.""" + +from __future__ import annotations + import warnings -from typing import Any, Tuple, Union +from typing import Any import pandas as pd @@ -103,7 +106,7 @@ def connect(self, connection_str: str = None, **kwargs): def query( self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + ) -> pd.DataFrame | Any: """ Execute Resource Graph query and retrieve results. @@ -133,7 +136,7 @@ def query( return result - def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: + def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, Any]: """ Execute query string and return DataFrame of results. @@ -162,7 +165,7 @@ def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: request_options = QueryRequestOptions( top=top, - result_format=ResultFormat.OBJECT_ARRAY, # type: ignore + result_format=ResultFormat.OBJECT_ARRAY, ) request = QueryRequest( @@ -185,6 +188,7 @@ def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: "Some resources may be missing from the results. " "To rewrite the query and enable paging, " "see the docs for an example: https://aka.ms/arg-results-truncated", + stacklevel=2, ) return pd.json_normalize(response.data), response diff --git a/msticpy/data/drivers/security_graph_driver.py b/msticpy/data/drivers/security_graph_driver.py index 07c0196b9..0eaf94eee 100644 --- a/msticpy/data/drivers/security_graph_driver.py +++ b/msticpy/data/drivers/security_graph_driver.py @@ -4,7 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Security Graph OData Driver class.""" -from typing import Any, Optional, Union + +from typing import Any import pandas as pd @@ -24,7 +25,7 @@ class SecurityGraphDriver(OData): CONFIG_NAME = "MicrosoftGraph" _ALT_CONFIG_NAMES = ["SecurityGraphApp"] - def __init__(self, connection_str: Optional[str] = None, **kwargs): + def __init__(self, connection_str: str | None = None, **kwargs): """ Instantiate MSGraph driver and optionally connect. 
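Several hunks in this change add `stacklevel=2` to `warnings.warn` calls, including the truncated-results warning in the Resource Graph driver above. The argument attributes the warning to the caller's line instead of the helper that raised it:

    import warnings


    def fetch_page(truncated: bool) -> None:
        if truncated:
            # stacklevel=2 reports the caller's line, not this helper's,
            # so users see which of *their* calls triggered the warning.
            warnings.warn("Results truncated; some items may be missing.", stacklevel=2)


    fetch_page(truncated=True)  # the warning points at this line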
@@ -53,8 +54,8 @@ def __init__(self, connection_str: Optional[str] = None, **kwargs): self.connect(connection_str) def query( - self, query: str, query_source: Optional[QuerySource] = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + self, query: str, query_source: QuerySource | None = None, **kwargs + ) -> pd.DataFrame | Any: """ Execute query string and return DataFrame of results. diff --git a/msticpy/data/drivers/sentinel_query_reader.py b/msticpy/data/drivers/sentinel_query_reader.py index e14a42945..7f5ab0d67 100644 --- a/msticpy/data/drivers/sentinel_query_reader.py +++ b/msticpy/data/drivers/sentinel_query_reader.py @@ -5,6 +5,8 @@ # -------------------------------------------------------------------------- """Github Sentinel Query repo import class and helpers.""" +from __future__ import annotations + import logging import os import re @@ -12,7 +14,6 @@ import zipfile from datetime import datetime from pathlib import Path -from typing import Optional import attr import httpx @@ -92,10 +93,8 @@ class SentinelQuery: def get_sentinel_queries_from_github( - git_url: Optional[ - str - ] = "https://github.com/Azure/Azure-Sentinel/archive/master.zip", - outputdir: Optional[str] = None, + git_url: str | None = "https://github.com/Azure/Azure-Sentinel/archive/master.zip", + outputdir: str | None = None, ) -> bool: r""" Download Microsoft Sentinel Github archive and extract detection and hunting queries. @@ -112,9 +111,7 @@ def get_sentinel_queries_from_github( """ if outputdir is None: - outputdir = str( - Path.joinpath(Path("~").expanduser(), ".msticpy", "Azure-Sentinel") - ) + outputdir = str(Path.joinpath(Path("~").expanduser(), ".msticpy", "Azure-Sentinel")) try: with httpx.stream("GET", git_url, follow_redirects=True) as response: # type: ignore @@ -124,7 +121,7 @@ def get_sentinel_queries_from_github( unit="iB", unit_scale=True, ) - repo_zip = Path.joinpath(Path(outputdir), "Azure-Sentinel.zip") # type: ignore + repo_zip = Path.joinpath(Path(outputdir), "Azure-Sentinel.zip") with open(repo_zip, "wb") as file: for data in response.iter_bytes(chunk_size=10000): progress_bar.update(len(data)) @@ -146,7 +143,10 @@ def get_sentinel_queries_from_github( return True except httpx.HTTPError as http_err: - warnings.warn(f"HTTP error occurred trying to download from Github: {http_err}") + warnings.warn( + f"HTTP error occurred trying to download from Github: {http_err}", + stacklevel=2, + ) return False @@ -203,9 +203,7 @@ def import_sentinel_queries(yaml_files: dict, query_type: str) -> list: ] -def _import_sentinel_query( - yaml_path: str, yaml_text: str, query_type: str -) -> SentinelQuery: +def _import_sentinel_query(yaml_path: str, yaml_text: str, query_type: str) -> SentinelQuery: """ Create a SentinelQuery attr object for a given yaml query. @@ -296,7 +294,7 @@ def _organize_query_list_by_folder(query_list: list) -> dict: queries_by_folder = {} for query in query_list: if query.folder_name == "": - warnings.warn(f"query {query} has no folder_name") + warnings.warn(f"query {query} has no folder_name", stacklevel=2) if query.folder_name not in queries_by_folder: queries_by_folder[query.folder_name] = [query] else: @@ -305,7 +303,7 @@ def _organize_query_list_by_folder(query_list: list) -> dict: return queries_by_folder -def _create_queryfile_metadata(folder_name: str) -> dict: # type: ignore +def _create_queryfile_metadata(folder_name: str) -> dict: """ Generate metadata section of the YAML file for the given folder_name. 
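`get_sentinel_queries_from_github` above streams the archive with `httpx` and reports progress with `tqdm`. Condensed to its essentials, the download pattern looks roughly like this (error handling elided):

    from pathlib import Path

    import httpx
    from tqdm.auto import tqdm


    def download_archive(url: str, dest: Path) -> None:
        """Stream a large file to disk with a progress bar."""
        with httpx.stream("GET", url, follow_redirects=True) as response:
            response.raise_for_status()
            total = int(response.headers.get("Content-Length", 0))
            with tqdm(total=total, unit="iB", unit_scale=True) as bar, open(dest, "wb") as out:
                for chunk in response.iter_bytes(chunk_size=10_000):
                    bar.update(len(chunk))
                    out.write(chunk)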
@@ -406,9 +404,7 @@ def write_to_yaml(query_list: list, query_type: str, output_folder: str) -> bool print(err) try: - query_text = yaml.safe_dump( - dict_to_write, encoding="utf-8", sort_keys=False - ) + query_text = yaml.safe_dump(dict_to_write, encoding="utf-8", sort_keys=False) except yaml.YAMLError as error: print(error) return False @@ -434,7 +430,7 @@ def write_to_yaml(query_list: list, query_type: str, output_folder: str) -> bool def download_and_write_sentinel_queries( - query_type: str, yaml_output_folder: str, github_outputdir: Optional[str] = None + query_type: str, yaml_output_folder: str, github_outputdir: str | None = None ): """ Download queries from GitHub and write out YAML files for the given query type. diff --git a/msticpy/data/drivers/splunk_driver.py b/msticpy/data/drivers/splunk_driver.py index d686dbb09..21487df08 100644 --- a/msticpy/data/drivers/splunk_driver.py +++ b/msticpy/data/drivers/splunk_driver.py @@ -4,10 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Splunk Driver class.""" + import logging +from collections.abc import Iterable from datetime import datetime, timedelta, timezone from time import sleep -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any import jwt import pandas as pd @@ -87,7 +89,7 @@ class SplunkDriver(DriverBase): """Driver to connect and query from Splunk.""" _SPLUNK_REQD_ARGS = ["host"] - _CONNECT_DEFAULTS: Dict[str, Any] = {"port": "8089"} + _CONNECT_DEFAULTS: dict[str, Any] = {"port": "8089"} _TIME_FORMAT = '"%Y-%m-%d %H:%M:%S.%6N"' def __init__(self, **kwargs): @@ -116,7 +118,7 @@ def __init__(self, **kwargs): }, ) - def connect(self, connection_str: Optional[str] = None, **kwargs): + def connect(self, connection_str: str | None = None, **kwargs): """ Connect to Splunk via splunk-sdk. 
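One detail worth knowing about the `write_to_yaml` hunk above: passing `encoding=` to `yaml.safe_dump` changes its return type from `str` to `bytes`, so downstream writes must account for that. A quick demonstration:

    import yaml

    doc = {"metadata": {"version": 1}, "sources": {"query1": {"args": {}}}}

    # With encoding= set, safe_dump returns bytes; without it, a str.
    as_bytes = yaml.safe_dump(doc, encoding="utf-8", sort_keys=False)
    as_text = yaml.safe_dump(doc, sort_keys=False)
    assert isinstance(as_bytes, bytes) and isinstance(as_text, str)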
@@ -138,9 +140,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): """ cs_dict = self._get_connect_args(connection_str, **kwargs) - arg_dict = { - key: val for key, val in cs_dict.items() if key in SPLUNK_CONNECT_ARGS - } + arg_dict = {key: val for key, val in cs_dict.items() if key in SPLUNK_CONNECT_ARGS} # Replace to Splunk python sdk's parameter name of sp_client.connect() if arg_dict.get("bearer_token"): @@ -169,11 +169,9 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): self._connected = True print("Connected.") - def _get_connect_args( - self, connection_str: Optional[str], **kwargs - ) -> Dict[str, Any]: + def _get_connect_args(self, connection_str: str | None, **kwargs) -> dict[str, Any]: """Check and consolidate connection parameters.""" - cs_dict: Dict[str, Any] = self._CONNECT_DEFAULTS + cs_dict: dict[str, Any] = self._CONNECT_DEFAULTS # Fetch any config settings cs_dict.update(self._get_config_settings("Splunk")) # If a connection string - parse this and add to config @@ -181,10 +179,7 @@ def _get_connect_args( print("Credential loading from connection_str tied with ';'.") cs_items = connection_str.split(";") cs_dict.update( - { - cs_item.split("=")[0].strip(): cs_item.split("=")[1] - for cs_item in cs_items - } + {cs_item.split("=")[0].strip(): cs_item.split("=")[1] for cs_item in cs_items} ) elif kwargs: print("Credential loading from connect() method's args.") @@ -246,8 +241,8 @@ def _get_connect_args( return cs_dict def query( - self, query: str, query_source: Optional[QuerySource] = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + self, query: str, query_source: QuerySource | None = None, **kwargs + ) -> pd.DataFrame | Any: """ Execute splunk query and retrieve results via OneShot or async search mode. @@ -316,19 +311,15 @@ def query( else: # Set mode and initialize async job kwargs_normalsearch = {"exec_mode": "normal"} - query_job = self.service.jobs.create( - query, count=count, **kwargs_normalsearch - ) - resp_rows, reader = self._exec_async_search( - query_job, page_size, timeout=timeout - ) + query_job = self.service.jobs.create(query, count=count, **kwargs_normalsearch) + resp_rows, reader = self._exec_async_search(query_job, page_size, timeout=timeout) if len(resp_rows) == 0 or not resp_rows: print("Warning - query did not return any results.") return [row for row in reader if isinstance(row, sp_results.Message)] return pd.DataFrame(resp_rows) - def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: + def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, Any]: """ Execute query string and return DataFrame of results. @@ -347,7 +338,7 @@ def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: raise NotImplementedError(f"Not supported for {self.__class__.__name__}") @property - def service_queries(self) -> Tuple[Dict[str, str], str]: + def service_queries(self) -> tuple[dict[str, str], str]: """ Return dynamic queries available on connection to service. @@ -369,7 +360,7 @@ def service_queries(self) -> Tuple[Dict[str, str], str]: return {}, "SavedSearches" @property - def driver_queries(self) -> Iterable[Dict[str, Any]]: + def driver_queries(self) -> Iterable[dict[str, Any]]: """ Return dynamic queries available on connection to service. 
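The async branch of `query()` above uses splunk-sdk's "normal" execution mode: create a job, poll until done, then read results. Stripped of the driver's progress bar and paging, the flow is roughly the following sketch; the host and credentials are placeholders, and `JSONResultsReader` assumes a recent splunk-sdk release:

    from time import sleep

    import splunklib.client as sp_client
    import splunklib.results as sp_results

    service = sp_client.connect(
        host="splunk.example.com", port=8089, username="analyst", password="..."
    )
    job = service.jobs.create('search index="main" | head 100', exec_mode="normal")
    while not job.is_done():  # the driver wraps this poll in a progress bar
        sleep(1)
    reader = sp_results.JSONResultsReader(job.results(output_mode="json", count=0))
    # The reader yields dict rows plus diagnostic Message objects.
    rows = [row for row in reader if not isinstance(row, sp_results.Message)]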
@@ -438,9 +429,9 @@ def _retrieve_job_status(query_job, progress_bar, prev_progress): "result_count": int(query_job["resultCount"]), } status = ( - "\r%(done_progress)03.1f%% %(scan_count)d scanned " - "%(event_count)d matched %(result_count)d results" - ) % stats + f"\r{stats['done_progress']:03.1f}% {stats['scan_count']:d} scanned " + f"{stats['event_count']:d} matched {stats['result_count']:d} results" + ) if prev_progress == 0: progress = stats["done_progress"] else: @@ -458,9 +449,7 @@ def _retrieve_job_status(query_job, progress_bar, prev_progress): def _retrieve_results(query_job, offset, page_size): """Retrieve the results of a job, decode and return them.""" # Retrieving all the results by paginate - result_count = int( - query_job["resultCount"] - ) # Number of results this job returned + result_count = int(query_job["resultCount"]) # Number of results this job returned resp_rows = [] progress_bar_paginate = tqdm( @@ -488,7 +477,7 @@ def _retrieve_results(query_job, offset, page_size): return resp_rows, reader @property - def _saved_searches(self) -> Union[pd.DataFrame, Any]: + def _saved_searches(self) -> pd.DataFrame | Any: """ Return list of saved searches in dataframe. @@ -500,7 +489,7 @@ def _saved_searches(self) -> Union[pd.DataFrame, Any]: """ return self._get_saved_searches() if self.connected else None - def _get_saved_searches(self) -> Union[pd.DataFrame, Any]: + def _get_saved_searches(self) -> pd.DataFrame | Any: # sourcery skip: class-extract-method """ Return list of saved searches in dataframe. @@ -528,7 +517,7 @@ def _get_saved_searches(self) -> Union[pd.DataFrame, Any]: return out_df @property - def _fired_alerts(self) -> Union[pd.DataFrame, Any]: + def _fired_alerts(self) -> pd.DataFrame | Any: """ Return list of fired alerts in dataframe. @@ -540,7 +529,7 @@ def _fired_alerts(self) -> Union[pd.DataFrame, Any]: """ return self._get_fired_alerts() if self.connected else None - def _get_fired_alerts(self) -> Union[pd.DataFrame, Any]: + def _get_fired_alerts(self) -> pd.DataFrame | Any: """ Return list of fired alerts in dataframe. diff --git a/msticpy/data/drivers/sumologic_driver.py b/msticpy/data/drivers/sumologic_driver.py index 915417103..761117f12 100644 --- a/msticpy/data/drivers/sumologic_driver.py +++ b/msticpy/data/drivers/sumologic_driver.py @@ -4,11 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Sumologic Driver class.""" + import re import time from datetime import datetime, timedelta from timeit import default_timer as timer -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any import httpx import pandas as pd @@ -30,8 +31,7 @@ SUMOLOGIC_CONNECT_ARGS = { "connection_str": ( - "(string) The url endpoint (the default is" - + " 'https://api.us2.sumologic.com/api')." + "(string) The url endpoint (the default is" + " 'https://api.us2.sumologic.com/api')." 
), "accessid": ( "(string) The Sumologic accessid, which is used to " @@ -52,9 +52,7 @@ class SumologicDriver(DriverBase): """Driver to connect and query from Sumologic.""" _SUMOLOGIC_REQD_ARGS = ["connection_str", "accessid", "accesskey"] - _CONNECT_DEFAULTS: Dict[str, Any] = { - "connection_str": "https://api.us2.sumologic.com/api" - } + _CONNECT_DEFAULTS: dict[str, Any] = {"connection_str": "https://api.us2.sumologic.com/api"} _TIME_FORMAT = '"%Y-%m-%d %H:%M:%S.%6N"' _DEF_CHECKINTERVAL = 3 _DEF_TIMEOUT = 300 @@ -67,9 +65,7 @@ def __init__(self, **kwargs): self._connected = False self._debug = kwargs.get("debug", False) self.set_driver_property(DriverProps.PUBLIC_ATTRS, {"client": self.service}) - self.set_driver_property( - DriverProps.FORMATTERS, {"datetime": self._format_datetime} - ) + self.set_driver_property(DriverProps.FORMATTERS, {"datetime": self._format_datetime}) self.checkinterval = self._DEF_CHECKINTERVAL self.timeout = self._DEF_TIMEOUT @@ -95,9 +91,7 @@ def connect(self, connection_str: str = None, **kwargs): """ cs_dict = self._get_connect_args(connection_str, **kwargs) - arg_dict = { - key: val for key, val in cs_dict.items() if key in SUMOLOGIC_CONNECT_ARGS - } + arg_dict = {key: val for key, val in cs_dict.items() if key in SUMOLOGIC_CONNECT_ARGS} try: # https://github.com/SumoLogic/sumologic-python-sdk/blob/master/scripts/search-job.py self.service = SumoLogic( @@ -129,11 +123,9 @@ def connect(self, connection_str: str = None, **kwargs): self._connected = True print(f"connected with accessid {arg_dict['accessid']}") - def _get_connect_args( - self, connection_str: Optional[str], **kwargs - ) -> Dict[str, Any]: + def _get_connect_args(self, connection_str: str | None, **kwargs) -> dict[str, Any]: """Check and consolidate connection parameters.""" - cs_dict: Dict[str, Any] = self._CONNECT_DEFAULTS + cs_dict: dict[str, Any] = self._CONNECT_DEFAULTS # Fetch any config settings settings, cs_is_instance_name = self._get_sumologic_settings(connection_str) cs_dict.update(settings) @@ -162,7 +154,7 @@ def _get_connect_args( # pylint: disable=broad-except def _query( self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + ) -> pd.DataFrame | Any: """ Execute Sumologic query and retrieve results. 
@@ -308,9 +300,7 @@ def _get_job_results_messages(self, searchjob, status, limit): except Exception as err: self._raise_qry_except(err, "search_job_messages", "to get job messages") - def _get_job_results_records( # noqa: MC0001 - self, searchjob, status, limit, verbosity - ): + def _get_job_results_records(self, searchjob, status, limit, verbosity): # Aggregated results, limit count = status["recordCount"] limit2 = None @@ -326,9 +316,7 @@ def _get_job_results_records( # noqa: MC0001 result = self.service.search_job_records(searchjob, limit=limit2) return result["records"] except Exception as err: - self._raise_qry_except( - err, "search_job_records", "to get search records" - ) + self._raise_qry_except(err, "search_job_records", "to get search records") else: # paging results # https://help.sumologic.com/APIs/Search-Job-API/About-the-Search-Job-API#query-parameters-2 @@ -344,9 +332,7 @@ def _get_job_results_records( # noqa: MC0001 else: job_limit2 = job_limit if verbosity >= 2: - print( - f"DEBUG: Paging {i * job_limit} / {count}, limit {job_limit2}" - ) + print(f"DEBUG: Paging {i * job_limit} / {count}, limit {job_limit2}") result = self.service.search_job_records( searchjob, offset=(i * job_limit), limit=job_limit2 ) @@ -376,7 +362,7 @@ def _get_job_results( # pylint: enable=inconsistent-return-statements @staticmethod - def _raise_qry_except(err: Exception, mssg: str, action: Optional[str] = None): + def _raise_qry_except(err: Exception, mssg: str, action: str | None = None): if isinstance(err, httpx.HTTPError): raise MsticpyConnectionError( f"Communication error connecting to Sumologic: {err}", @@ -412,9 +398,9 @@ def _get_time_params(self, **kwargs): return self._format_datetime(start), self._format_datetime(end) # pylint: disable=too-many-branches - def query( # noqa: MC0001 + def query( self, query: str, query_source: QuerySource = None, **kwargs - ) -> Union[pd.DataFrame, Any]: + ) -> pd.DataFrame | Any: """ Execute Sumologic query and retrieve results. @@ -485,7 +471,7 @@ def query( # noqa: MC0001 if verbosity >= 3: print("DEBUG: {results}") if normalize: - dataframe_res = pd.json_normalize(results) # type: ignore + dataframe_res = pd.json_normalize(results) else: dataframe_res = pd.DataFrame(results) @@ -527,7 +513,7 @@ def query( # noqa: MC0001 return dataframe_res.copy() - def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: + def query_with_results(self, query: str, **kwargs) -> tuple[pd.DataFrame, Any]: """ Execute query string and return DataFrame of results. 
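The record retrieval above pages through results with repeated `offset`/`limit` calls to `search_job_records`. The loop generalizes to any offset-paged API; `fetch` below stands in for the Sumologic call:

    def page_records(fetch, total: int, page_size: int) -> list:
        """Offset/limit paging; `fetch` stands in for search_job_records."""
        records: list = []
        for offset in range(0, total, page_size):
            limit = min(page_size, total - offset)  # final page may be short
            records.extend(fetch(offset=offset, limit=limit))
        return records


    data = list(range(25))
    paged = page_records(lambda offset, limit: data[offset : offset + limit], 25, 10)
    assert paged == data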
@@ -555,7 +541,7 @@ def _format_datetime(date_time: datetime) -> str: @staticmethod def _get_sumologic_settings( instance_name: str = None, - ) -> Tuple[Dict[str, Any], bool]: + ) -> tuple[dict[str, Any], bool]: """Get config from msticpyconfig.""" data_provs = get_provider_settings(config_section="DataProviders") sl_settings = { @@ -563,7 +549,7 @@ def _get_sumologic_settings( for name, settings in data_provs.items() if name.startswith("Sumologic") } - sumologic_settings: Optional[ProviderSettings] + sumologic_settings: ProviderSettings | None # Check if the connection string is an instance name sumologic_settings = sl_settings.get(f"Sumologic-{instance_name}") if sumologic_settings: diff --git a/msticpy/data/query_container.py b/msticpy/data/query_container.py deleted file mode 100644 index 7c51471c9..000000000 --- a/msticpy/data/query_container.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module query_container.py has moved. - -See :py:mod:`msticpy.data.core.query_container` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from .core.query_container import * - -WARN_MSSG = ( - "This module has moved to msticpy.data.core.query_container\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/query_defns.py b/msticpy/data/query_defns.py deleted file mode 100644 index 90afa144c..000000000 --- a/msticpy/data/query_defns.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module query_defns.py has moved. - -See :py:mod:`msticpy.data.core.query_defns` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from .core.query_defns import * - -WARN_MSSG = ( - "This module has moved to msticpy.data.core.query_defns\n" - "Please change your import to reflect this new location." 
- "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/data/sql_to_kql.py b/msticpy/data/sql_to_kql.py index 72eb490b7..719819c3c 100644 --- a/msticpy/data/sql_to_kql.py +++ b/msticpy/data/sql_to_kql.py @@ -194,8 +194,8 @@ def sql_to_kql(sql: str, target_tables: dict[str, str] | None = None) -> str: # replace table names if target_tables: - for table, target in target_tables.items(): - sql = sql.replace(table, target) + for table, target_name in target_tables.items(): + sql = sql.replace(table, target_name) parsed_sql = parse(sql) query_lines = _parse_query(parsed_sql) @@ -230,17 +230,13 @@ def _parse_query(parsed_sql: dict[str, Any]) -> list[str]: distinct_select: list[dict[str, Any]] = [] if SELECT_DISTINCT in parsed_sql: distinct_select.extend(parsed_sql[SELECT_DISTINCT]) - _process_select( - parsed_sql[SELECT_DISTINCT], parsed_sql[SELECT_DISTINCT], query_lines - ) + _process_select(parsed_sql[SELECT_DISTINCT], parsed_sql[SELECT_DISTINCT], query_lines) if SELECT in parsed_sql: _process_select(parsed_sql[SELECT], parsed_sql[SELECT], query_lines) if ORDER_BY in parsed_sql: query_lines.append(f"| order by {_create_order_by(parsed_sql[ORDER_BY])}") if distinct_select: - query_lines.append( - f"| distinct {', '.join(_create_distinct_list(distinct_select))}" - ) + query_lines.append(f"| distinct {', '.join(_create_distinct_list(distinct_select))}") if LIMIT in parsed_sql: query_lines.append(f"| limit {parsed_sql[LIMIT]}") if UNION in parsed_sql: @@ -355,9 +351,7 @@ def _get_expr_value(expr_val: Any) -> Any: def _process_group_by(parsed_sql: dict[str, Any], query_lines: list[str]) -> None: """Process GROUP BY clause.""" group_by_expr = parsed_sql[GROUP_BY] - group_by_expr = ( - group_by_expr if isinstance(group_by_expr, list) else [group_by_expr] - ) + group_by_expr = group_by_expr if isinstance(group_by_expr, list) else [group_by_expr] by_clause = ", ".join(val["value"] for val in group_by_expr if val.get("value")) expr_list = parsed_sql.get(SELECT, parsed_sql.get(SELECT_DISTINCT, [])) @@ -393,13 +387,9 @@ def _parse_expression(expression: Any) -> str: # noqa: PLR0911 return f"dcount({func_arg})" if AND in expression: - return "\n and ".join( - [f"({_parse_expression(expr)})" for expr in expression[AND]] - ) + return "\n and ".join([f"({_parse_expression(expr)})" for expr in expression[AND]]) if OR in expression: - return "\n or ".join( - [f"({_parse_expression(expr)})" for expr in expression[OR]] - ) + return "\n or ".join([f"({_parse_expression(expr)})" for expr in expression[OR]]) if NOT in expression: return f" not ({_parse_expression(expression[NOT])})" if BETWEEN in expression: diff --git a/msticpy/data/storage/azure_blob_storage.py b/msticpy/data/storage/azure_blob_storage.py index f1e2f6df1..9d163f7f5 100644 --- a/msticpy/data/storage/azure_blob_storage.py +++ b/msticpy/data/storage/azure_blob_storage.py @@ -4,8 +4,9 @@ # license information. 
# -------------------------------------------------------------------------- """Uses the Azure Python SDK to interact with Azure Blob Storage.""" + import datetime -from typing import Any, List, Optional +from typing import Any import pandas as pd from azure.common.exceptions import CloudError @@ -37,14 +38,14 @@ def __init__( self.connected = False self.abs_site = f"{abs_name}.blob.core.windows.net" self.connection_string = abs_connection_string - self.credentials: Optional[AzCredentials] = None - self.abs_client: Optional[BlobServiceClient] = None + self.credentials: AzCredentials | None = None + self.abs_client: BlobServiceClient | None = None if connect: self.connect() def connect( self, - auth_methods: List = None, + auth_methods: list = None, silent: bool = False, ): """Authenticate with the SDK.""" @@ -54,9 +55,7 @@ def connect( if not self.connection_string: self.abs_client = BlobServiceClient(self.abs_site, self.credentials.modern) else: - self.abs_client = BlobServiceClient.from_connection_string( - self.connection_string - ) + self.abs_client = BlobServiceClient.from_connection_string(self.connection_string) if not self.abs_client: raise CloudError("Could not create a Blob Storage client.") self.connected = True @@ -70,9 +69,7 @@ def containers(self) -> pd.DataFrame: "Unable to connect check the Azure Blob Store account name" ) from err return ( - _parse_returned_items( # type:ignore - container_list, remove_list=["lease", "encryption_scope"] - ) + _parse_returned_items(container_list, remove_list=["lease", "encryption_scope"]) if container_list else None ) @@ -96,13 +93,13 @@ def create_container(self, container_name: str, **kwargs) -> pd.DataFrame: try: new_container = self.abs_client.create_container( # type: ignore container_name, **kwargs - ) # type:ignore + ) except ResourceExistsError as err: raise CloudError(f"Container {container_name} already exists.") from err properties = new_container.get_container_properties() return _parse_returned_items([properties], ["encryption_scope", "lease"]) - def blobs(self, container_name: str) -> Optional[pd.DataFrame]: + def blobs(self, container_name: str) -> pd.DataFrame | None: """ Get a list of blobs in a container. @@ -119,7 +116,7 @@ def blobs(self, container_name: str) -> Optional[pd.DataFrame]: """ container_client = self.abs_client.get_container_client( # type: ignore[union-attr] container_name - ) # type: ignore + ) blobs = list(container_client.list_blobs()) return _parse_returned_items(blobs) if blobs else None @@ -153,9 +150,7 @@ def upload_to_blob( if not upload["error_code"]: print("Upload complete") else: - raise CloudError( - f"There was a problem uploading the blob: {upload['error_code']}" - ) + raise CloudError(f"There was a problem uploading the blob: {upload['error_code']}") return True def get_blob(self, container_name: str, blob_name: str) -> bytes: diff --git a/msticpy/data/uploaders/loganalytics_uploader.py b/msticpy/data/uploaders/loganalytics_uploader.py index f392c531c..a0f79ff44 100644 --- a/msticpy/data/uploaders/loganalytics_uploader.py +++ b/msticpy/data/uploaders/loganalytics_uploader.py @@ -4,6 +4,7 @@ # license information. 
 # --------------------------------------------------------------------------
 """LogAnalytics Uploader class."""
+
 import base64
 import datetime
 import hashlib
@@ -112,13 +113,7 @@ def _post_data(self, body: str, table_name: str):
         signature = self._build_signature(
             rfc1123date, content_length, "POST", content_type, resource
         )
-        uri = (
-            "https://"
-            + self.workspace
-            + self.ops_loc
-            + resource
-            + "?api-version=2016-04-01"
-        )
+        uri = "https://" + self.workspace + self.ops_loc + resource + "?api-version=2016-04-01"
         headers = {
             "content-type": content_type,
             "Authorization": signature,
@@ -171,9 +166,7 @@ def upload_df(self, data: pd.DataFrame, table_name: Any, **kwargs):
         if self._debug:
             print(f"Upload to {table_name} complete")

-    def upload_file(
-        self, file_path: str, table_name: str = None, delim: str = ",", **kwargs
-    ):
+    def upload_file(self, file_path: str, table_name: str = None, delim: str = ",", **kwargs):
         """
         Upload a separated value file to Log Analytics.

diff --git a/msticpy/data/uploaders/splunk_uploader.py b/msticpy/data/uploaders/splunk_uploader.py
index f0f59a844..5187531c4 100644
--- a/msticpy/data/uploaders/splunk_uploader.py
+++ b/msticpy/data/uploaders/splunk_uploader.py
@@ -4,9 +4,10 @@
 # license information.
 # --------------------------------------------------------------------------
 """Splunk Uploader class."""
+
 import logging
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any

 import pandas as pd
 from pandas.errors import ParserError
@@ -105,15 +106,15 @@ def _post_data(
         source_types = []
         for row in data.iterrows():
             if source_type == "json":
-                data = row[1].to_json()  # type: ignore
+                data = row[1].to_json()
             elif source_type == "csv":
-                data = row[1].to_csv()  # type: ignore
+                data = row[1].to_csv()
             else:
-                data = row[1].to_string()  # type: ignore
+                data = row[1].to_string()
             try:
-                data.encode(encoding="latin-1")  # type: ignore
+                data.encode(encoding="latin-1")
             except UnicodeEncodeError:
-                data = data.encode(encoding="utf-8")  # type: ignore
+                data = data.encode(encoding="utf-8")
             index.submit(data, sourcetype=source_type, host=host)
             source_types.append(source_type)
             progress.update(1)
@@ -122,13 +123,13 @@ def _post_data(
         print(f"Upload complete: Splunk sourcetype = {source_types}")

     # pylint: disable=arguments-differ
-    def upload_df(  # type: ignore
+    def upload_df(
         self,
         data: pd.DataFrame,
-        table_name: Optional[str] = None,
-        index_name: Optional[str] = None,
+        table_name: str | None = None,
+        index_name: str | None = None,
         create_index: bool = False,
-        source_type: Optional[str] = None,
+        source_type: str | None = None,
         **kwargs,
     ):
         """
@@ -175,14 +176,14 @@ def upload_df(  # type: ignore
             create_index=create_index,
         )

-    def upload_file(  # type: ignore
+    def upload_file(
         self,
         file_path: str,
-        table_name: Optional[str] = None,
+        table_name: str | None = None,
         delim: str = ",",
-        index_name: Optional[str] = None,
+        index_name: str | None = None,
         create_index: bool = False,
-        source_type: Optional[str] = None,
+        source_type: str | None = None,
         **kwargs,
     ):
         """
@@ -236,14 +237,14 @@ def upload_file(  # type: ignore
             create_index=create_index,
         )

-    def upload_folder(  # type: ignore
+    def upload_folder(
         self,
         folder_path: str,
-        table_name: Optional[str] = None,
+        table_name: str | None = None,
         delim: str = ",",
-        index_name: Optional[str] = None,
+        index_name: str | None = None,
         create_index=False,
-        source_type: Optional[str] = None,
+        source_type: str | None = None,
         **kwargs,
     ):
         """
diff --git a/msticpy/data/uploaders/uploader_base.py
b/msticpy/data/uploaders/uploader_base.py index 75ed05616..414b7f800 100644 --- a/msticpy/data/uploaders/uploader_base.py +++ b/msticpy/data/uploaders/uploader_base.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Data uploader base class.""" + import abc from abc import ABC diff --git a/msticpy/datamodel/entities/__init__.py b/msticpy/datamodel/entities/__init__.py index 079988265..15a65e001 100644 --- a/msticpy/datamodel/entities/__init__.py +++ b/msticpy/datamodel/entities/__init__.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Entity sub-package.""" + import difflib -from typing import List from .account import Account from .alert import Alert @@ -116,9 +116,7 @@ class Machine(Host): def find_entity(entity): """Find entity name.""" entity_cf = entity.casefold() - entity_cls_dict = { - cls.__name__.casefold(): cls for cls in Entity.ENTITY_NAME_MAP.values() - } + entity_cls_dict = {cls.__name__.casefold(): cls for cls in Entity.ENTITY_NAME_MAP.values()} if entity_cf in Entity.ENTITY_NAME_MAP: print(f"Match found '{Entity.ENTITY_NAME_MAP[entity].__name__}'") return Entity.ENTITY_NAME_MAP[entity] @@ -144,11 +142,11 @@ def find_entity(entity): return None -def list_entities() -> List[str]: +def list_entities() -> list[str]: """List entities.""" return sorted([cls.__name__ for cls in set(Entity.ENTITY_NAME_MAP.values())]) -def entity_classes() -> List[type]: +def entity_classes() -> list[type]: """Return a list of all entity classes.""" return list(set(Entity.ENTITY_NAME_MAP.values())) diff --git a/msticpy/datamodel/entities/account.py b/msticpy/datamodel/entities/account.py index c98f25b8a..77868c456 100644 --- a/msticpy/datamodel/entities/account.py +++ b/msticpy/datamodel/entities/account.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Account Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.data_types import SplitProperty @@ -90,18 +92,18 @@ def __init__( kw arguments. 
""" - self.Name: Optional[str] = None - self.NTDomain: Optional[str] = None - self.UPNSuffix: Optional[str] = None - self.Host: Optional[Host] = None - self.LogonId: Optional[str] = None - self.Sid: Optional[str] = None - self.AadTenantId: Optional[str] = None - self._AadUserId: Optional[str] = None - self.PUID: Optional[str] = None + self.Name: str | None = None + self.NTDomain: str | None = None + self.UPNSuffix: str | None = None + self.Host: Host | None = None + self.LogonId: str | None = None + self.Sid: str | None = None + self.AadTenantId: str | None = None + self._AadUserId: str | None = None + self.PUID: str | None = None self.IsDomainJoined: bool = False - self.DisplayName: Optional[str] = None - self.ObjectGuid: Optional[str] = None + self.DisplayName: str | None = None + self.ObjectGuid: str | None = None if "Upn" in kwargs: self.Upn = kwargs.pop("Upn") if "AadUserId" in kwargs: @@ -121,7 +123,7 @@ def name_str(self) -> str: return self.Name or self.DisplayName or "Unknown Account" @property - def AadUserId(self) -> Optional[str]: # noqa: N802 + def AadUserId(self) -> str | None: # noqa: N802 """Return the Azure AD user ID or the ObjectGuid.""" return self._AadUserId or self.ObjectGuid @@ -150,33 +152,21 @@ def _create_from_event(self, src_event, role): if role == "subject" and "SubjectUserName" in src_event: self.Name = src_event["SubjectUserName"] self.NTDomain = ( - src_event["SubjectUserDomain"] - if "SubjectUserDomain" in src_event - else None - ) - self.Sid = ( - src_event["SubjectUserSid"] if "SubjectUserSid" in src_event else None + src_event["SubjectUserDomain"] if "SubjectUserDomain" in src_event else None ) + self.Sid = src_event["SubjectUserSid"] if "SubjectUserSid" in src_event else None self.LogonId = ( src_event["SubjectLogonId"] if "SubjectLogonId" in src_event else None ) if role == "target" and "TargetUserName" in src_event: self.Name = src_event["TargetUserName"] self.NTDomain = ( - src_event["TargetUserDomain"] - if "TargetUserDomain" in src_event - else None - ) - self.Sid = ( - src_event["TargetUserSid"] if "TargetUserSid" in src_event else None - ) - self.LogonId = ( - src_event["TargetLogonId"] if "TargetLogonId" in src_event else None + src_event["TargetUserDomain"] if "TargetUserDomain" in src_event else None ) + self.Sid = src_event["TargetUserSid"] if "TargetUserSid" in src_event else None + self.LogonId = src_event["TargetLogonId"] if "TargetLogonId" in src_event else None - self.AadTenantId = ( - src_event["AadTenantId"] if "AadTenantId" in src_event else None - ) + self.AadTenantId = src_event["AadTenantId"] if "AadTenantId" in src_event else None self.Sid = src_event["Sid"] if "Sid" in src_event else None self.NTDomain = src_event["NtDomain"] if "NtDomain" in src_event else None self.AadUserId = src_event["AadUserId"] if "AadUserId" in src_event else None diff --git a/msticpy/datamodel/entities/alert.py b/msticpy/datamodel/entities/alert.py index 62752a5d4..6b19b8001 100644 --- a/msticpy/datamodel/entities/alert.py +++ b/msticpy/datamodel/entities/alert.py @@ -4,9 +4,11 @@ # license information. # -------------------------------------------------------------------------- """Alert Entity class.""" + import json +from collections.abc import Mapping from datetime import datetime -from typing import Any, Dict, List, Mapping, Optional, Tuple +from typing import Any import pandas as pd @@ -80,18 +82,18 @@ def __init__( kw arguments. 
""" - self.DisplayName: Optional[str] = None - self.CompromisedEntity: Optional[str] = None + self.DisplayName: str | None = None + self.CompromisedEntity: str | None = None self.Count: Any = None - self.StartTimeUtc: Optional[datetime] = None - self.EndTimeUtc: Optional[datetime] = None + self.StartTimeUtc: datetime | None = None + self.EndTimeUtc: datetime | None = None self.Severity: Any = None - self.SystemAlertId: Optional[str] = None - self.SystemAlertIds: List[str] = [] - self.AlertType: Optional[str] = None - self.VendorName: Optional[str] = None - self.ProviderName: Optional[str] = None - self.Entities: Optional[List] = None + self.SystemAlertId: str | None = None + self.SystemAlertIds: list[str] = [] + self.AlertType: str | None = None + self.VendorName: str | None = None + self.ProviderName: str | None = None + self.Entities: list | None = None self.Version = "3.0" super().__init__(src_entity=src_entity, **kwargs) if src_entity is not None: @@ -100,7 +102,7 @@ def __init__( if isinstance(src_event, pd.Series) and not src_event.empty: self._create_from_event(src_event) - def _create_from_ent(self, src_entity): # noqa: MC0001 + def _create_from_ent(self, src_entity): if "StartTime" in src_entity: self.TimeGeneratedUtc = src_entity["StartTime"] if "TimeGenerated" in src_entity: @@ -143,11 +145,9 @@ def name_str(self) -> str: return f"Alert: {alert_name}" or self.__class__.__name__ @property - def AlertId(self) -> Optional[str]: # noqa: N802 + def AlertId(self) -> str | None: # noqa: N802 """Return the system alert ID.""" - return self.SystemAlertId or ( - self.SystemAlertIds[0] if self.SystemAlertIds else None - ) + return self.SystemAlertId or (self.SystemAlertIds[0] if self.SystemAlertIds else None) @AlertId.setter def AlertId(self, value: str): # noqa: N802 @@ -291,10 +291,10 @@ def _create_entities(self, entities): """Create alert entities from returned dicts.""" new_ents = [] for ent in entities: - if isinstance(ent, Tuple): + if isinstance(ent, tuple): ent_details = ent[1] ent_type = ent[0] - elif isinstance(ent, Dict): + elif isinstance(ent, dict): ent_details = ent ent_type = ent.get("Type", "Unknown") else: @@ -341,14 +341,12 @@ def _extract_entities(ents: list): out_ents.append(_find_original_entity(entity, base_ents)) else: for k, val in entity.items(): - if isinstance(val, (list, dict)): + if isinstance(val, list | dict): if isinstance(val, list): nested_ents = [] for item in val: if isinstance(item, dict) and "$ref" in item: - nested_ents.append( - _find_original_entity(item, base_ents) - ) + nested_ents.append(_find_original_entity(item, base_ents)) entity[k] = nested_ents elif isinstance(val, dict) and "$ref" in val: entity[k] = _find_original_entity(val, base_ents) @@ -365,7 +363,7 @@ def _find_original_entity(ent, base_ents): return ent -def _generate_base_ents(ents: list) -> list: # noqa: MC0001 +def _generate_base_ents(ents: list) -> list: """Generate a list of all enties form a set of nested entities.""" base_ents = [] for ent in ents: @@ -380,10 +378,7 @@ def _generate_base_ents(ents: list) -> list: # noqa: MC0001 for p in prop[val]: if isinstance(p, dict) and "$id" in p.keys(): base_ents.append(p) - elif ( - isinstance(prop[val], dict) - and "$id" in prop[val].keys() - ): + elif isinstance(prop[val], dict) and "$id" in prop[val].keys(): base_ents.append(val) elif isinstance(item, dict) and "$id" in item.keys(): base_ents.append(item) diff --git a/msticpy/datamodel/entities/azure_resource.py b/msticpy/datamodel/entities/azure_resource.py index 
8cccc7999..e40bdd453 100644 --- a/msticpy/datamodel/entities/azure_resource.py +++ b/msticpy/datamodel/entities/azure_resource.py @@ -4,9 +4,11 @@ # license information. # -------------------------------------------------------------------------- """AzureResource Entity class.""" + import re +from collections.abc import Mapping from itertools import islice -from typing import Any, Dict, Mapping, Optional +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -52,9 +54,9 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.ResourceId: Optional[str] = None - self.ResourceIdParts: Dict[str, str] = {} - self.Url: Optional[str] = None + self.ResourceId: str | None = None + self.ResourceIdParts: dict[str, str] = {} + self.Url: str | None = None super().__init__(src_entity=src_entity, **kwargs) if self.ResourceId and not self.ResourceIdParts: self._extract_resource_parts() @@ -102,4 +104,4 @@ def _extract_resource_parts(self): res_elems = res_match.groupdict().get("res_path", "").split("/") keys = islice(res_elems, 0, len(res_elems), 2) vals = islice(res_elems, 1, len(res_elems), 2) - self.ResourceIdParts = dict(zip(keys, vals)) + self.ResourceIdParts = dict(zip(keys, vals, strict=False)) diff --git a/msticpy/datamodel/entities/cloud_application.py b/msticpy/datamodel/entities/cloud_application.py index d1aa5fbc6..2bcbde75d 100644 --- a/msticpy/datamodel/entities/cloud_application.py +++ b/msticpy/datamodel/entities/cloud_application.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """CloudApplication Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -53,9 +55,9 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.Name: Optional[str] = None - self.AppId: Optional[str] = None - self.InstanceName: Optional[str] = None + self.Name: str | None = None + self.AppId: str | None = None + self.InstanceName: str | None = None super().__init__(src_entity=src_entity, **kwargs) @property diff --git a/msticpy/datamodel/entities/cloud_logon_session.py b/msticpy/datamodel/entities/cloud_logon_session.py index 5e238f12b..5d43a1c55 100644 --- a/msticpy/datamodel/entities/cloud_logon_session.py +++ b/msticpy/datamodel/entities/cloud_logon_session.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """CloudApplication Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -62,10 +64,10 @@ def __init__( kw arguments. """ - self.SessionId: Optional[str] = None - self.Account: Optional[str] = None - self.UserAgent: Optional[str] = None - self.StartTime: Optional[str] = None + self.SessionId: str | None = None + self.Account: str | None = None + self.UserAgent: str | None = None + self.StartTime: str | None = None super().__init__(src_entity=src_entity, **kwargs) if src_event: self._create_from_event(src_event) diff --git a/msticpy/datamodel/entities/dns.py b/msticpy/datamodel/entities/dns.py index 20766599c..5df08615a 100644 --- a/msticpy/datamodel/entities/dns.py +++ b/msticpy/datamodel/entities/dns.py @@ -4,7 +4,9 @@ # license information. 
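`_extract_resource_parts` above now calls `zip(keys, vals, strict=False)`. The `strict` flag was added in Python 3.10 (PEP 618); spelling out `strict=False` documents that the old silent-truncation behavior is intentional when a resource path has an unpaired trailing segment. An illustrative example with a sample path, not msticpy's parsing code:

```python
res_elems = ["resourceGroups", "myRG", "providers", "Microsoft.Compute", "virtualMachines"]
keys = res_elems[0::2]  # ['resourceGroups', 'providers', 'virtualMachines']
vals = res_elems[1::2]  # ['myRG', 'Microsoft.Compute']

# strict=False: the unpaired trailing key is silently dropped.
print(dict(zip(keys, vals, strict=False)))
# {'resourceGroups': 'myRG', 'providers': 'Microsoft.Compute'}

# strict=True would surface the mismatch instead:
try:
    dict(zip(keys, vals, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```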
# -------------------------------------------------------------------------- """Dns Entity class.""" -from typing import Any, List, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -56,10 +58,10 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.DomainName: Optional[str] = None - self.IpAddresses: List[IpAddress] = [] - self.DnsServerIp: Optional[IpAddress] = None - self.HostIpAddress: Optional[IpAddress] = None + self.DomainName: str | None = None + self.IpAddresses: list[IpAddress] = [] + self.DnsServerIp: IpAddress | None = None + self.HostIpAddress: IpAddress | None = None super().__init__(src_entity=src_entity, **kwargs) @property @@ -77,7 +79,7 @@ def name_str(self) -> str: "DomainName": None, # IpAddresses (type System.Collections.Generic.List`1 # [Microsoft.Azure.Security.Detection.AlertContracts.V3.Entities.IP]) - "IpAddresses": (List, "IpAddress"), + "IpAddresses": (list, "IpAddress"), # DnsServerIp (type Microsoft.Azure.Security.Detection # .AlertContracts.V3.Entities.IP) "DnsServerIp": "IpAddress", diff --git a/msticpy/datamodel/entities/entity.py b/msticpy/datamodel/entities/entity.py index 355ef0dfd..11d736bdd 100644 --- a/msticpy/datamodel/entities/entity.py +++ b/msticpy/datamodel/entities/entity.py @@ -4,15 +4,17 @@ # license information. # -------------------------------------------------------------------------- """Entity Entity class.""" + from __future__ import annotations import json import pprint import typing from abc import ABC +from collections.abc import Mapping from copy import deepcopy from datetime import datetime, timezone -from typing import Any, Dict, List, Mapping, Optional, Type, Union +from typing import Any import networkx as nx @@ -60,9 +62,9 @@ class Entity(ABC, Node): Implements common methods for Entity classes """ - ENTITY_NAME_MAP: Dict[str, type] = {} - _entity_schema: Dict[str, Any] = {} - ID_PROPERTIES: List[str] = [] + ENTITY_NAME_MAP: dict[str, type] = {} + _entity_schema: dict[str, Any] = {} + ID_PROPERTIES: list[str] = [] JSONEncoder = _EntityJSONEncoder def __init__( @@ -117,7 +119,7 @@ def create( cls, src_entity: Mapping[str, Any] | None = None, **kwargs, - ) -> "Entity": + ) -> Entity: """ Create an entity from a mapping type (e.g. pd.Series) or dict or kwargs. @@ -169,7 +171,7 @@ def _extract_src_entity(self, src_entity: Mapping[str, Any]): if val in ENTITY_ENUMS.values(): self[attr] = val[src_entity[attr]] elif val in ENTITY_ENUMS: - self[attr] = ENTITY_ENUMS[val][src_entity[attr]] + self[attr] = ENTITY_ENUMS[val][src_entity[attr]] # type: ignore[index] continue except KeyError: # Catch key errors from invalid enum values @@ -244,9 +246,7 @@ def __str__(self) -> str: def __repr__(self) -> str: """Return repr of entity.""" - params = ", ".join( - f"{name}={val}" for name, val in self.properties.items() if val - ) + params = ", ".join(f"{name}={val}" for name, val in self.properties.items() if val) if self.edges: params = f"{params}, edges={'. '.join(str(edge) for edge in self.edges)}" @@ -359,7 +359,7 @@ def is_equivalent(self, other: Any) -> bool: if prop not in ("edges", "TimeGenerated") and not prop.startswith("_") ) - def merge(self, other: Any) -> "Entity": + def merge(self, other: Any) -> Entity: """ Merge with other entity to create new entity. @@ -472,8 +472,8 @@ class name string. 
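Since `entity.py` already has `from __future__ import annotations` (PEP 563), annotations are stored as strings and never evaluated at class-definition time, so quoted forward references such as `-> "Entity"` can become bare `-> Entity` as in the hunks above. A toy demonstration of the idea, not msticpy's `Entity`:

```python
from __future__ import annotations


class Entity:
    """Minimal stand-in showing an unquoted self-referencing annotation."""

    def merge(self, other: Entity) -> Entity:  # was: -> "Entity"
        merged = Entity()
        merged.__dict__.update({**self.__dict__, **other.__dict__})
        return merged
```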
@classmethod def instantiate_entity( - cls, raw_entity: Mapping[str, Any], entity_type: Optional[Type] = None - ) -> Union["Entity", Mapping[str, Any]]: + cls, raw_entity: Mapping[str, Any], entity_type: type | None = None + ) -> Entity | Mapping[str, Any]: """ Class factory to return entity from raw dictionary representation. @@ -507,7 +507,7 @@ def instantiate_entity( raise TypeError(f"Could not find a suitable type for {entity_type}") @classmethod - def _get_entity_type_name(cls, entity_type: Type) -> str: + def _get_entity_type_name(cls, entity_type: type) -> str: """ Get V3 entity name for an entity. @@ -524,20 +524,14 @@ def _get_entity_type_name(cls, entity_type: Type) -> str: """ try: name = next( - iter( - ( - key - for key, val in cls.ENTITY_NAME_MAP.items() - if val == entity_type - ) - ) + iter((key for key, val in cls.ENTITY_NAME_MAP.items() if val == entity_type)) ) except StopIteration: name = "unknown" return name @property - def node_properties(self) -> Dict[str, Any]: + def node_properties(self) -> dict[str, Any]: """ Return all public properties that are not entities. @@ -550,7 +544,7 @@ def node_properties(self) -> Dict[str, Any]: props = { name: str(value) for name, value in self.properties.items() - if not isinstance(value, (Entity, list)) and name != "edges" + if not isinstance(value, Entity | list) and name != "edges" } props["Description"] = self.description_str props["Name"] = self.name_str @@ -577,9 +571,7 @@ def to_networkx(self, graph: nx.Graph = None) -> nx.Graph: if not graph.has_node(self): graph.add_node(self.name_str, **self.node_properties) for edge in self.edges: - if not isinstance(edge.source, Entity) or not isinstance( - edge.target, Entity - ): + if not isinstance(edge.source, Entity) or not isinstance(edge.target, Entity): continue if graph.has_edge(edge.source.name_str, edge.target.name_str): continue @@ -600,7 +592,7 @@ def to_networkx(self, graph: nx.Graph = None) -> nx.Graph: return graph @classmethod - def get_pivot_list(cls, search_str: Optional[str] = None) -> List[str]: + def get_pivot_list(cls, search_str: str | None = None) -> list[str]: """ Return list of current pivot functions. @@ -664,9 +656,9 @@ def make_pivot_shortcut(cls, func_name: str, target: str, overwrite: bool = Fals """ func_path = func_name.split(".") if "." in func_name else [func_name] - curr_attr: Optional[Any] = cls + curr_attr: Any | None = cls for path in func_path: - curr_attr = getattr(curr_attr, path, None) # type: ignore + curr_attr = getattr(curr_attr, path, None) if not curr_attr: raise AttributeError(f"No function found for {func_name}") if not hasattr(curr_attr, "pivot_properties"): @@ -720,6 +712,6 @@ def del_pivot_shortcut(cls, func_name: str): delattr(cls, func_name) -def camelcase_property_names(input_ent: Dict[str, Any]) -> Dict[str, Any]: +def camelcase_property_names(input_ent: dict[str, Any]) -> dict[str, Any]: """Change initial letter Microsoft Sentinel API entity properties to upper case.""" return {key[0].upper() + key[1:]: input_ent[key] for key in input_ent} diff --git a/msticpy/datamodel/entities/entity_enums.py b/msticpy/datamodel/entities/entity_enums.py index e625ea991..4567fcc7f 100644 --- a/msticpy/datamodel/entities/entity_enums.py +++ b/msticpy/datamodel/entities/entity_enums.py @@ -4,8 +4,8 @@ # license information. 
# -------------------------------------------------------------------------- """Entity enumerations.""" + from enum import Enum -from typing import Dict, Type from ..._version import VERSION from ...common.utility import export @@ -14,7 +14,7 @@ __author__ = "Ian Hellen" # pylint: disable=invalid-name -ENTITY_ENUMS: Dict[str, Type] = {} +ENTITY_ENUMS: dict[str, type] = {} # pylint: disable=invalid-name diff --git a/msticpy/datamodel/entities/entity_graph.py b/msticpy/datamodel/entities/entity_graph.py index 19481e70d..c0ec08ab2 100644 --- a/msticpy/datamodel/entities/entity_graph.py +++ b/msticpy/datamodel/entities/entity_graph.py @@ -4,7 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Entity Graph classes.""" -from typing import Any, Dict, Optional, Set + +from typing import Any from ..._version import VERSION @@ -17,9 +18,9 @@ class Node: def __init__(self): """Initialize the node.""" - self.edges: Set["Edge"] = set() + self.edges: set[Edge] = set() - def add_edge(self, target: "Node", edge_attrs: Optional[Dict[str, Any]] = None): + def add_edge(self, target: "Node", edge_attrs: dict[str, Any] | None = None): """ Add an edge between self and target. @@ -45,7 +46,7 @@ def has_edge(self, other): class Edge: """Entity edge class.""" - def __init__(self, source: Node, target: Node, attrs: Dict[str, Any] = None): + def __init__(self, source: Node, target: Node, attrs: dict[str, Any] = None): """ Create a new edge between `source` and `target`. @@ -62,7 +63,7 @@ def __init__(self, source: Node, target: Node, attrs: Dict[str, Any] = None): self.source: Node = source self.target: Node = target - self.attrs: Dict[str, Any] = attrs or {} + self.attrs: dict[str, Any] = attrs or {} def add_attr(self, name: str, value: Any): """Add an edge attribute.""" @@ -74,9 +75,7 @@ def __str__(self): def __repr__(self): """Return full repr of edge.""" - other_attrs = [ - f"{name}='{val}'" for name, val in self.attrs.items() if name != "name" - ] + other_attrs = [f"{name}='{val}'" for name, val in self.attrs.items() if name != "name"] if not other_attrs: return f"Edge(name={str(self)})" return f"Edge(name={str(self)}, {', '.join(other_attrs)})" diff --git a/msticpy/datamodel/entities/file.py b/msticpy/datamodel/entities/file.py index 3acf6eae6..a3b7acaa3 100644 --- a/msticpy/datamodel/entities/file.py +++ b/msticpy/datamodel/entities/file.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """File Entity class.""" -from typing import Any, List, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.data_types import SharedProperty @@ -83,16 +85,16 @@ def __init__( kw arguments. 
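For context on the `entity_graph.py` hunks above, here is a toy version of the `Node`/`Edge` pattern; the behavior is inferred from the signatures shown in this diff rather than copied from msticpy:

```python
from typing import Any


class Edge:
    """Directed edge with free-form attributes."""

    def __init__(self, source: "Node", target: "Node", attrs: dict[str, Any] | None = None):
        self.source, self.target, self.attrs = source, target, attrs or {}


class Node:
    """Graph node holding the set of edges it participates in."""

    def __init__(self, name: str):
        self.name = name
        self.edges: set[Edge] = set()

    def add_edge(self, target: "Node", edge_attrs: dict[str, Any] | None = None):
        edge = Edge(self, target, edge_attrs)
        self.edges.add(edge)
        target.edges.add(edge)


account, host = Node("account"), Node("host")
account.add_edge(host, {"name": "logged_on_to"})
print(len(account.edges), len(host.edges))  # 1 1
```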
""" - self.FullPath: Optional[str] = None - self.Directory: Optional[str] = None - self.Name: Optional[str] = None - self.Md5: Optional[str] = None - self.Host: Optional[Host] = None - self.Sha1: Optional[str] = None - self.Sha256: Optional[str] = None - self.Sha256Ac: Optional[str] = None - self.FileHashes: List[FileHash] = [] - self.PathSeparator: Optional[str] = "\\" + self.FullPath: str | None = None + self.Directory: str | None = None + self.Name: str | None = None + self.Md5: str | None = None + self.Host: Host | None = None + self.Sha1: str | None = None + self.Sha256: str | None = None + self.Sha256Ac: str | None = None + self.FileHashes: list[FileHash] = [] + self.PathSeparator: str | None = "\\" self.OSFamily = OSFamily.Windows super().__init__(src_entity=src_entity, **kwargs) if src_event is not None: @@ -107,9 +109,7 @@ def __init__( @property def path_separator(self): """Return the path separator used by the file.""" - if ( - self.Directory and "/" in self.Directory - ) or self.OSFamily != OSFamily.Windows: + if (self.Directory and "/" in self.Directory) or self.OSFamily != OSFamily.Windows: return "/" return "\\" @@ -173,7 +173,7 @@ def _add_paths(self, full_path, file_name=None): self.Directory = full_path.split(self.PathSeparator)[:-1] @property - def file_hash(self) -> Optional[str]: + def file_hash(self) -> str | None: """ Return the first defined file hash. diff --git a/msticpy/datamodel/entities/file_hash.py b/msticpy/datamodel/entities/file_hash.py index ac3d63c66..470e1d9c3 100644 --- a/msticpy/datamodel/entities/file_hash.py +++ b/msticpy/datamodel/entities/file_hash.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """FileHash Entity class.""" -from typing import Any, Mapping + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export diff --git a/msticpy/datamodel/entities/geo_location.py b/msticpy/datamodel/entities/geo_location.py index a380de7aa..33b319f7b 100644 --- a/msticpy/datamodel/entities/geo_location.py +++ b/msticpy/datamodel/entities/geo_location.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """GeoLocation Entity class.""" -from typing import Any, Mapping, Optional, Tuple + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -61,13 +63,13 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. 
""" - self.CountryCode: Optional[str] = None - self.CountryOrRegionName: Optional[str] = None - self.State: Optional[str] = None - self.City: Optional[str] = None - self.Longitude: Optional[float] = None - self.Latitude: Optional[float] = None - self.Asn: Optional[str] = None + self.CountryCode: str | None = None + self.CountryOrRegionName: str | None = None + self.State: str | None = None + self.City: str | None = None + self.Longitude: float | None = None + self.Latitude: float | None = None + self.Asn: str | None = None super().__init__(src_entity=src_entity, **kwargs) @property @@ -81,7 +83,7 @@ def name_str(self) -> str: return self.CountryCode or self.__class__.__name__ @property - def CountryName(self) -> Optional[str]: # noqa: N802 + def CountryName(self) -> str | None: # noqa: N802 """Return CountryName.""" return self.CountryOrRegionName @@ -91,7 +93,7 @@ def CountryName(self, value: str): # noqa: N802 self.CountryOrRegionName = value @property - def coordinates(self) -> Tuple[float, float]: + def coordinates(self) -> tuple[float, float]: """Return Latitude/Longitude as a tuple of floats.""" if self.Latitude and self.Longitude: return self.Latitude, self.Longitude diff --git a/msticpy/datamodel/entities/graph_property.py b/msticpy/datamodel/entities/graph_property.py index 854305388..b287321f5 100644 --- a/msticpy/datamodel/entities/graph_property.py +++ b/msticpy/datamodel/entities/graph_property.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- """Entity graph property.""" -from typing import Union from ..._version import VERSION @@ -16,9 +15,7 @@ # Future - will replace entity graph creation with property factory -def graph_property( - name: str, prop_type: Union[type, str], edge_name: str = None -) -> property: +def graph_property(name: str, prop_type: type | str, edge_name: str = None) -> property: """Property factory for graph_property.""" storage_name = f"_{name}" edge_attrs = {"name": edge_name or name} diff --git a/msticpy/datamodel/entities/host.py b/msticpy/datamodel/entities/host.py index 5f1f85bef..cf119e013 100644 --- a/msticpy/datamodel/entities/host.py +++ b/msticpy/datamodel/entities/host.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Host Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.data_types import SplitProperty @@ -76,16 +78,16 @@ def __init__( kw arguments. 
""" - self.DnsDomain: Optional[str] = None - self.NTDomain: Optional[str] = None - self.HostName: Optional[str] = None - self.NetBiosName: Optional[str] = None - self.AzureID: Optional[str] = None - self.OMSAgentID: Optional[str] = None + self.DnsDomain: str | None = None + self.NTDomain: str | None = None + self.HostName: str | None = None + self.NetBiosName: str | None = None + self.AzureID: str | None = None + self.OMSAgentID: str | None = None self.OSFamily: OSFamily = OSFamily.Windows - self.OSVersion: Optional[str] = None + self.OSVersion: str | None = None self.IsDomainJoined: bool = False - self.DeviceId: Optional[str] = None + self.DeviceId: str | None = None super().__init__(src_entity=src_entity, **kwargs) self._computer = None @@ -93,19 +95,19 @@ def __init__( self._create_from_event(src_event) @property - def computer(self) -> Optional[str]: + def computer(self) -> str | None: """Return computer from source event.""" return self._computer if self._computer is not None else self.fqdn @property - def fqdn(self) -> Optional[str]: + def fqdn(self) -> str | None: """Construct FQDN from host + dns.""" if self.DnsDomain: return f"{self.HostName}.{self.DnsDomain}" return self.HostName @property - def FullName(self) -> Optional[str]: # noqa: N802 + def FullName(self) -> str | None: # noqa: N802 """Return the full name of the host - either FQDN or Netbiosname.""" # noqa: N802 if self.DnsDomain: return f"{self.HostName or self.NetBiosName}.{self.DnsDomain}" diff --git a/msticpy/datamodel/entities/host_logon_session.py b/msticpy/datamodel/entities/host_logon_session.py index 512b82788..0c1f06353 100644 --- a/msticpy/datamodel/entities/host_logon_session.py +++ b/msticpy/datamodel/entities/host_logon_session.py @@ -4,10 +4,12 @@ # license information. # -------------------------------------------------------------------------- """HostLogonSession Entity class.""" + from __future__ import annotations +from collections.abc import Mapping from datetime import datetime -from typing import Any, Mapping, Optional +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -71,10 +73,10 @@ def __init__( """ - self.Account: Optional[Account] = None + self.Account: Account | None = None self.StartTimeUtc: datetime = datetime.min self.EndTimeUtc: datetime = datetime.min - self.Host: Optional[Host] = None + self.Host: Host | None = None self.SessionId: str | None = None super().__init__(src_entity=src_entity, **kwargs) diff --git a/msticpy/datamodel/entities/iot_device.py b/msticpy/datamodel/entities/iot_device.py index b6ecd253f..1efe16d65 100644 --- a/msticpy/datamodel/entities/iot_device.py +++ b/msticpy/datamodel/entities/iot_device.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """IoTDevice Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -75,20 +77,20 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. 
""" - self.IoTHub: Optional[str] = None - self.DeviceId: Optional[str] = None - self.DeviceName: Optional[str] = None - self.IoTSecurityAgentId: Optional[str] = None - self.DeviceType: Optional[str] = None - self.Source: Optional[str] = None - self.SourceRef: Optional[str] = None - self.Manufacturer: Optional[str] = None - self.Model: Optional[str] = None - self.OperatingSystem: Optional[str] = None - self.IpAddress: Optional[str] = None - self.MacAddress: Optional[str] = None - self.Protocols: Optional[str] = None - self.SerialNumber: Optional[str] = None + self.IoTHub: str | None = None + self.DeviceId: str | None = None + self.DeviceName: str | None = None + self.IoTSecurityAgentId: str | None = None + self.DeviceType: str | None = None + self.Source: str | None = None + self.SourceRef: str | None = None + self.Manufacturer: str | None = None + self.Model: str | None = None + self.OperatingSystem: str | None = None + self.IpAddress: str | None = None + self.MacAddress: str | None = None + self.Protocols: str | None = None + self.SerialNumber: str | None = None super().__init__(src_entity=src_entity, **kwargs) diff --git a/msticpy/datamodel/entities/ip_address.py b/msticpy/datamodel/entities/ip_address.py index 44468f7b1..be10368a3 100644 --- a/msticpy/datamodel/entities/ip_address.py +++ b/msticpy/datamodel/entities/ip_address.py @@ -4,10 +4,12 @@ # license information. # -------------------------------------------------------------------------- """IpAddress Entity class.""" + from __future__ import annotations +from collections.abc import Mapping from ipaddress import IPv4Address, IPv6Address, ip_address -from typing import Any, Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -101,9 +103,7 @@ def ip_address(self) -> IPv4Address | IPv6Address | None: def description_str(self) -> str: """Return Entity Description.""" return ( - f"{self.Address} - {self.Location.CountryCode}" - if self.Location - else self.Address + f"{self.Address} - {self.Location.CountryCode}" if self.Location else self.Address ) @property diff --git a/msticpy/datamodel/entities/mail_cluster.py b/msticpy/datamodel/entities/mail_cluster.py index df99f719b..be6c33b2c 100644 --- a/msticpy/datamodel/entities/mail_cluster.py +++ b/msticpy/datamodel/entities/mail_cluster.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """MailCluster Entity class.""" -from typing import Any, Dict, List, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -85,21 +87,21 @@ def __init__( kw arguments. 
""" - self.NetworkMessageIds: List[str] = [] - self.CountByDeliveryStatus: Dict[str, int] = {} - self.CountByThreatType: Dict[str, int] = {} - self.CountByProtectionStatus: Dict[str, int] = {} - self.Threats: List[str] = [] - self.Query: Optional[str] = None + self.NetworkMessageIds: list[str] = [] + self.CountByDeliveryStatus: dict[str, int] = {} + self.CountByThreatType: dict[str, int] = {} + self.CountByProtectionStatus: dict[str, int] = {} + self.Threats: list[str] = [] + self.Query: str | None = None self.QueryTime: Any = None self.MailCount: int = 0 self.IsVolumeAnomaly: bool = False - self.Source: Optional[str] = None - self.ClusterSourceIdentifier: Optional[str] = None - self.ClusterSourceType: Optional[str] = None + self.Source: str | None = None + self.ClusterSourceIdentifier: str | None = None + self.ClusterSourceType: str | None = None self.ClusterQueryStartTime: Any = None self.ClusterQueryEndTime: Any = None - self.ClusterGroup: Optional[str] = None + self.ClusterGroup: str | None = None super().__init__(src_entity=src_entity, **kwargs) if src_event is not None: diff --git a/msticpy/datamodel/entities/mail_message.py b/msticpy/datamodel/entities/mail_message.py index f9a4a190e..9642e0d3e 100644 --- a/msticpy/datamodel/entities/mail_message.py +++ b/msticpy/datamodel/entities/mail_message.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """MailMessage Entity class.""" -from typing import Any, List, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -107,32 +109,32 @@ def __init__( kw arguments. """ - self.Recipient: Optional[str] = None - self.Files: List[Entity] = [] - self.Urls: List[str] = [] - self.Threats: List[str] = [] - self.Sender: Optional[str] = None - self.P1Sender: Optional[str] = None - self.P1SenderDisplayName: Optional[str] = None - self.P1SenderDomain: Optional[str] = None - self.SenderIP: Optional[str] = None - self.P2Sender: Optional[str] = None - self.P2SenderDisplayName: Optional[str] = None - self.P2SenderDomain: Optional[str] = None + self.Recipient: str | None = None + self.Files: list[Entity] = [] + self.Urls: list[str] = [] + self.Threats: list[str] = [] + self.Sender: str | None = None + self.P1Sender: str | None = None + self.P1SenderDisplayName: str | None = None + self.P1SenderDomain: str | None = None + self.SenderIP: str | None = None + self.P2Sender: str | None = None + self.P2SenderDisplayName: str | None = None + self.P2SenderDomain: str | None = None self.ReceivedDate: Any = None - self.NetworkMessageId: Optional[str] = None - self.InternetMessageId: Optional[str] = None - self.Subject: Optional[str] = None - self.BodyFingerprintBin1: Optional[str] = None - self.BodyFingerprintBin2: Optional[str] = None - self.BodyFingerprintBin3: Optional[str] = None - self.BodyFingerprintBin4: Optional[str] = None - self.BodyFingerprintBin5: Optional[str] = None - self.AntispamDirection: Optional[str] = None - self.DeliveryAction: Optional[str] = None - self.DeliveryLocation: Optional[str] = None - self.Language: Optional[str] = None - self.ThreatDetectionMethods: Optional[str] = None + self.NetworkMessageId: str | None = None + self.InternetMessageId: str | None = None + self.Subject: str | None = None + self.BodyFingerprintBin1: str | None = None + self.BodyFingerprintBin2: str | None = None + self.BodyFingerprintBin3: str | None = None + self.BodyFingerprintBin4: str | None = None + 
self.BodyFingerprintBin5: str | None = None + self.AntispamDirection: str | None = None + self.DeliveryAction: str | None = None + self.DeliveryLocation: str | None = None + self.Language: str | None = None + self.ThreatDetectionMethods: str | None = None super().__init__(src_entity=src_entity, **kwargs) if src_event: @@ -166,11 +168,7 @@ def description_str(self): @property def name_str(self) -> str: """Return Entity Name.""" - return ( - self.Subject - or f"MailMessage to: {self.Recipient}" - or self.__class__.__name__ - ) + return self.Subject or f"MailMessage to: {self.Recipient}" or self.__class__.__name__ _entity_schema = { "Recipient": None, diff --git a/msticpy/datamodel/entities/mailbox.py b/msticpy/datamodel/entities/mailbox.py index 582504073..33d9d6c1a 100644 --- a/msticpy/datamodel/entities/mailbox.py +++ b/msticpy/datamodel/entities/mailbox.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Mailbox Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -65,11 +67,11 @@ def __init__( kw arguments. """ - self.MailboxPrimaryAddress: Optional[str] = None - self.DisplayName: Optional[str] = None - self.Upn: Optional[str] = None - self.ExternalDirectoryObjectId: Optional[str] = None - self.RiskLevel: Optional[str] = None + self.MailboxPrimaryAddress: str | None = None + self.DisplayName: str | None = None + self.Upn: str | None = None + self.ExternalDirectoryObjectId: str | None = None + self.RiskLevel: str | None = None super().__init__(src_entity=src_entity, **kwargs) if src_event: @@ -85,10 +87,7 @@ def _create_from_event(self, src_event): @property def description_str(self): """Return Entity Description.""" - return ( - f"{self.MailboxPrimaryAddress} - {self.RiskLevel}" - or self.__class__.__name__ - ) + return f"{self.MailboxPrimaryAddress} - {self.RiskLevel}" or self.__class__.__name__ @property def name_str(self) -> str: diff --git a/msticpy/datamodel/entities/mailbox_configuration.py b/msticpy/datamodel/entities/mailbox_configuration.py index de71e3006..953151a54 100644 --- a/msticpy/datamodel/entities/mailbox_configuration.py +++ b/msticpy/datamodel/entities/mailbox_configuration.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """MailboxConfiguration Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -70,12 +72,12 @@ def __init__( kw arguments. 
""" - self.MailboxPrimaryAddress: Optional[str] = None - self.DisplayName: Optional[str] = None - self.Upn: Optional[str] = None - self.ExternalDirectoryObjectId: Optional[str] = None - self.ConfigType: Optional[str] = None - self.ConfigId: Optional[str] = None + self.MailboxPrimaryAddress: str | None = None + self.DisplayName: str | None = None + self.Upn: str | None = None + self.ExternalDirectoryObjectId: str | None = None + self.ConfigType: str | None = None + self.ConfigId: str | None = None super().__init__(src_entity=src_entity, **kwargs) if src_event: @@ -92,9 +94,7 @@ def _create_from_event(self, src_event): @property def description_str(self): """Return Entity Description.""" - return ( - f"{self.MailboxPrimaryAddress} - {self.ConfigId}" or self.__class__.__name__ - ) + return f"{self.MailboxPrimaryAddress} - {self.ConfigId}" or self.__class__.__name__ @property def name_str(self) -> str: diff --git a/msticpy/datamodel/entities/malware.py b/msticpy/datamodel/entities/malware.py index 1defe7750..2f6d5bb35 100644 --- a/msticpy/datamodel/entities/malware.py +++ b/msticpy/datamodel/entities/malware.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Malware Entity class.""" -from typing import Any, List, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -61,9 +63,9 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): """ self.Name: str = "" self.Category: str = "" - self.File: Optional[File] = None - self.Files: List[File] = [] - self.Processes: List[Process] = [] + self.File: File | None = None + self.Files: list[File] = [] + self.Processes: list[Process] = [] super().__init__(src_entity=src_entity, **kwargs) @property diff --git a/msticpy/datamodel/entities/network_connection.py b/msticpy/datamodel/entities/network_connection.py index 723160f5c..8f14ba035 100644 --- a/msticpy/datamodel/entities/network_connection.py +++ b/msticpy/datamodel/entities/network_connection.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """NetworkConnection Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -64,11 +66,11 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.SourceAddress: Optional[IpAddress] = None - self.SourcePort: Optional[int] = None - self.DestinationAddress: Optional[IpAddress] = None - self.DestinationPort: Optional[int] = None - self.Protocol: Optional[str] = None + self.SourceAddress: IpAddress | None = None + self.SourcePort: int | None = None + self.DestinationAddress: IpAddress | None = None + self.DestinationPort: int | None = None + self.Protocol: str | None = None super().__init__(src_entity=src_entity, **kwargs) @property diff --git a/msticpy/datamodel/entities/oauth_application.py b/msticpy/datamodel/entities/oauth_application.py index 28d9bd8c3..a86063461 100644 --- a/msticpy/datamodel/entities/oauth_application.py +++ b/msticpy/datamodel/entities/oauth_application.py @@ -4,7 +4,9 @@ # license information. 
# -------------------------------------------------------------------------- """OAuthApplication Entity class.""" -from typing import Any, List, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -65,14 +67,14 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.OAuthAppId: Optional[str] = None - self.OAuthObjectId: Optional[str] = None - self.Name: Optional[str] = None - self.TenantId: Optional[str] = None - self.PublisherName: Optional[str] = None - self.Risk: Optional[str] = None - self.Permissions: List[str] = [] - self.RedirectURLs: List[str] = [] + self.OAuthAppId: str | None = None + self.OAuthObjectId: str | None = None + self.Name: str | None = None + self.TenantId: str | None = None + self.PublisherName: str | None = None + self.Risk: str | None = None + self.Permissions: list[str] = [] + self.RedirectURLs: list[str] = [] self.AuthorizedBy: int = 0 super().__init__(src_entity=src_entity, **kwargs) diff --git a/msticpy/datamodel/entities/process.py b/msticpy/datamodel/entities/process.py index baf177e1c..68f7f6149 100644 --- a/msticpy/datamodel/entities/process.py +++ b/msticpy/datamodel/entities/process.py @@ -4,8 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Process Entity class.""" + +from collections.abc import Mapping from datetime import datetime -from typing import Any, Mapping, Optional +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -84,15 +86,15 @@ def __init__( kw arguments. """ - self.ProcessId: Optional[str] = None - self.CommandLine: Optional[str] = None - self.ElevationToken: Optional[ElevationToken] = None + self.ProcessId: str | None = None + self.CommandLine: str | None = None + self.ElevationToken: ElevationToken | None = None self.CreationTimeUtc: datetime = datetime.min - self.ImageFile: Optional[File] = None - self.Account: Optional[Account] = None - self.ParentProcess: Optional[Process] = None - self.Host: Optional[Host] = None - self.LogonSession: Optional[HostLogonSession] = None + self.ImageFile: File | None = None + self.Account: Account | None = None + self.ParentProcess: Process | None = None + self.Host: Host | None = None + self.LogonSession: HostLogonSession | None = None super().__init__(src_entity=src_entity, **kwargs) if src_event is not None: @@ -130,13 +132,13 @@ def _create_from_event(self, src_event, role): self.ImageFile = File(src_event=src_event, role="parent") @property - def ProcessName(self) -> Optional[str]: # noqa: N802 + def ProcessName(self) -> str | None: # noqa: N802 """Return the name of the process file.""" # noqa: N802 file = self["ImageFile"] return file.Name if file else None @property - def ProcessFilePath(self) -> Optional[str]: # noqa: N802 + def ProcessFilePath(self) -> str | None: # noqa: N802 """Return the name of the process file path.""" # noqa: N802 file = self.ImageFile return file.FullPath if file else None diff --git a/msticpy/datamodel/entities/registry_key.py b/msticpy/datamodel/entities/registry_key.py index 35abe1397..3c687c01a 100644 --- a/msticpy/datamodel/entities/registry_key.py +++ b/msticpy/datamodel/entities/registry_key.py @@ -4,7 +4,9 @@ # license information. 
# -------------------------------------------------------------------------- """RegistryValue Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -52,8 +54,8 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.Hive: Optional[RegistryHive] = None - self.Key: Optional[str] = None + self.Hive: RegistryHive | None = None + self.Key: str | None = None super().__init__(src_entity=src_entity, **kwargs) @property diff --git a/msticpy/datamodel/entities/registry_value.py b/msticpy/datamodel/entities/registry_value.py index b6fbd5812..04536df8b 100644 --- a/msticpy/datamodel/entities/registry_value.py +++ b/msticpy/datamodel/entities/registry_value.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """RegistryValue Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -56,10 +58,10 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.Key: Optional[RegistryKey] = None - self.Name: Optional[str] = None - self.Value: Optional[str] = None - self.ValueType: Optional[str] = None + self.Key: RegistryKey | None = None + self.Name: str | None = None + self.Value: str | None = None + self.ValueType: str | None = None super().__init__(src_entity=src_entity, **kwargs) @property diff --git a/msticpy/datamodel/entities/security_group.py b/msticpy/datamodel/entities/security_group.py index bd6e85736..4be784680 100644 --- a/msticpy/datamodel/entities/security_group.py +++ b/msticpy/datamodel/entities/security_group.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """SecurityGroup Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -53,9 +55,9 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.DistinguishedName: Optional[str] = None - self.SID: Optional[str] = None - self.ObjectGuid: Optional[str] = None + self.DistinguishedName: str | None = None + self.SID: str | None = None + self.ObjectGuid: str | None = None super().__init__(src_entity=src_entity, **kwargs) @property diff --git a/msticpy/datamodel/entities/service_principal.py b/msticpy/datamodel/entities/service_principal.py index ad10fac3f..0a6c18372 100644 --- a/msticpy/datamodel/entities/service_principal.py +++ b/msticpy/datamodel/entities/service_principal.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """ServicePrincipal Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -61,12 +63,12 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. 
""" - self.ServicePrincipalName: Optional[str] = None - self.ServicePrincipalObjectId: Optional[str] = None - self.AppId: Optional[str] = None - self.AppOwnerTenantId: Optional[str] = None - self.TenantId: Optional[str] = None - self.ServicePrincipalType: Optional[str] = None + self.ServicePrincipalName: str | None = None + self.ServicePrincipalObjectId: str | None = None + self.AppId: str | None = None + self.AppOwnerTenantId: str | None = None + self.TenantId: str | None = None + self.ServicePrincipalType: str | None = None super().__init__(src_entity=src_entity, **kwargs) diff --git a/msticpy/datamodel/entities/submission_mail.py b/msticpy/datamodel/entities/submission_mail.py index dee0739be..abf6a7c26 100644 --- a/msticpy/datamodel/entities/submission_mail.py +++ b/msticpy/datamodel/entities/submission_mail.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Submission mail Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -67,16 +69,16 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): kw arguments. """ - self.SubmissionId: Optional[str] = None + self.SubmissionId: str | None = None self.SubmissionDate: Any = None - self.Submitter: Optional[str] = None - self.NetworkMessageId: Optional[str] = None + self.Submitter: str | None = None + self.NetworkMessageId: str | None = None self.Timestamp: Any = None - self.Recipient: Optional[str] = None - self.Sender: Optional[str] = None - self.SenderIp: Optional[str] = None - self.Subject: Optional[str] = None - self.ReportType: Optional[str] = None + self.Recipient: str | None = None + self.Sender: str | None = None + self.SenderIp: str | None = None + self.Subject: str | None = None + self.ReportType: str | None = None super().__init__(src_entity=src_entity, **kwargs) diff --git a/msticpy/datamodel/entities/threat_intelligence.py b/msticpy/datamodel/entities/threat_intelligence.py index 95dc03314..ee2bed777 100644 --- a/msticpy/datamodel/entities/threat_intelligence.py +++ b/msticpy/datamodel/entities/threat_intelligence.py @@ -4,7 +4,9 @@ # license information. 
# -------------------------------------------------------------------------- """Threatintelligence Entity class.""" -from typing import Any, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -48,12 +50,12 @@ def __init__(self, src_entity: Mapping[str, Any] = None, **kwargs): :param src_entity: instantiate entity using properties of src entity :param kwargs: key-value pair representation of entity """ - self.ProviderName: Optional[str] = None - self.ThreatType: Optional[str] = None - self.ThreatName: Optional[str] = None - self.Confidence: Optional[str] = None - self.ReportLink: Optional[str] = None - self.ThreatDescription: Optional[str] = None + self.ProviderName: str | None = None + self.ThreatType: str | None = None + self.ThreatName: str | None = None + self.Confidence: str | None = None + self.ReportLink: str | None = None + self.ThreatDescription: str | None = None super().__init__(src_entity=src_entity, **kwargs) @property diff --git a/msticpy/datamodel/entities/unknown_entity.py b/msticpy/datamodel/entities/unknown_entity.py index 29cc12d02..2c7cda8fa 100644 --- a/msticpy/datamodel/entities/unknown_entity.py +++ b/msticpy/datamodel/entities/unknown_entity.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Threatintelligence Entity class.""" -from typing import Any, Dict, Mapping + +from collections.abc import Mapping +from typing import Any from ..._version import VERSION from ...common.utility import export @@ -42,7 +44,7 @@ def name_str(self) -> str: """Return Entity Name.""" return self.__class__.__name__ - _entity_schema: Dict[str, Any] = { + _entity_schema: dict[str, Any] = { "TimeGenerated": None, "StartTime": None, "EndTime": None, diff --git a/msticpy/datamodel/entities/url.py b/msticpy/datamodel/entities/url.py index 030ead0db..ecc597080 100644 --- a/msticpy/datamodel/entities/url.py +++ b/msticpy/datamodel/entities/url.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Url Entity class.""" -from typing import Any, Dict, Mapping, Optional + +from collections.abc import Mapping +from typing import Any from urllib3.exceptions import LocationParseError from urllib3.util import parse_url @@ -62,8 +64,8 @@ def __init__( kw arguments. """ - self.Url: Optional[str] = None - self.DetonationVerdict: Optional[str] = None + self.Url: str | None = None + self.DetonationVerdict: str | None = None super().__init__(src_entity=src_entity, **kwargs) if src_event: self._create_from_event(src_event) @@ -99,7 +101,7 @@ def __getattr__(self, name: str): return val return super().__getattr__(name) - _entity_schema: Dict[str, Any] = { + _entity_schema: dict[str, Any] = { # Url (type System.String) "Url": None, "DetonationVerdict": None, @@ -109,7 +111,7 @@ def __getattr__(self, name: str): } -def _url_components(url: str) -> Dict[str, str]: +def _url_components(url: str) -> dict[str, str]: """Return parsed Url components as dict.""" try: return parse_url(url)._asdict() diff --git a/msticpy/datamodel/pivot.py b/msticpy/datamodel/pivot.py deleted file mode 100644 index 718ff4311..000000000 --- a/msticpy/datamodel/pivot.py +++ /dev/null @@ -1,23 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module pivot.py has moved. - -See :py:mod:`msticpy.init.pivot` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - -WARN_MSSG = ( - "This module has moved to msticpy.init.pivot\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/datamodel/soc/incident.py b/msticpy/datamodel/soc/incident.py index fe43b3a59..5ffb74413 100644 --- a/msticpy/datamodel/soc/incident.py +++ b/msticpy/datamodel/soc/incident.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Incident Entity class.""" -from typing import Any, Dict, List, Mapping, Optional + +from collections.abc import Mapping +from typing import Any import pandas as pd @@ -53,15 +55,15 @@ def __init__( kw arguments. """ - self.DisplayName: Optional[str] = None - self.IncidentID: Optional[str] = None - self.Severity: Optional[str] = None - self.Status: Optional[str] = None - self.Owner: Optional[Dict] = None - self.Classification: Optional[str] = None - self.Labels: Optional[List] = None - self.Alerts: Optional[List] = None - self.Entities: Optional[List] = None + self.DisplayName: str | None = None + self.IncidentID: str | None = None + self.Severity: str | None = None + self.Status: str | None = None + self.Owner: dict | None = None + self.Classification: str | None = None + self.Labels: list | None = None + self.Alerts: list | None = None + self.Entities: list | None = None super().__init__(src_entity=src_entity, **kwargs) diff --git a/msticpy/datamodel/soc/sentinel_alert.py b/msticpy/datamodel/soc/sentinel_alert.py index 64f8fa019..345b98155 100644 --- a/msticpy/datamodel/soc/sentinel_alert.py +++ b/msticpy/datamodel/soc/sentinel_alert.py @@ -4,8 +4,9 @@ # license information. 
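The deleted `msticpy/datamodel/pivot.py` above was a deprecation shim left behind when the module moved to `msticpy.init.pivot`. For reference, a minimal version of that shim pattern with placeholder module names:

```python
"""Deprecated - module old_location has moved; see package.new_location."""
import warnings

WARN_MSSG = (
    "This module has moved to package.new_location\n"
    "Please change your import to reflect this new location."
)
warnings.warn(WARN_MSSG, category=DeprecationWarning, stacklevel=2)

# Re-export the public names so old imports keep working during the grace period:
# from .new_location import *  # noqa: F401,F403
```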
# -------------------------------------------------------------------------- """Sentinel Alert class.""" + import json -from typing import Any, Dict, List +from typing import Any import pandas as pd @@ -26,7 +27,7 @@ "VendorOriginalId", ] -_ID_PROPERTIES: List[str] = [ +_ID_PROPERTIES: list[str] = [ "AzSubscriptionId", "AzResourceId", "WorkspaceId", @@ -41,7 +42,7 @@ "ResourceIdentifiers", ] -_QUERY_PROPERTIES: List[str] = [ +_QUERY_PROPERTIES: list[str] = [ "Query Period", "Trigger Operator", "Trigger Threshold", @@ -79,7 +80,7 @@ def __init__( (the default is None) """ - self._custom_query_params: Dict[str, Any] = {} + self._custom_query_params: dict[str, Any] = {} super().__init__(src_entity, src_event, **kwargs) if ( isinstance( @@ -90,7 +91,7 @@ def __init__( ): self._add_sentinel_items(src_event) self._add_extended_sent_props() - self._ids: Dict[str, str] = {} + self._ids: dict[str, str] = {} if self.__dict__ is not None: for id_property in _ID_PROPERTIES: if id_property in self.properties: @@ -115,7 +116,7 @@ def _add_sentinel_items(self, src_event): self.__dict__.update({feature: src_event.get(feature, "")}) @property - def ids(self) -> Dict[str, str]: + def ids(self) -> dict[str, str]: """Return a collection of Identity properties for the alert.""" return self._ids diff --git a/msticpy/init/azure_ml_tools.py b/msticpy/init/azure_ml_tools.py index 8961a8ea9..2e531321d 100644 --- a/msticpy/init/azure_ml_tools.py +++ b/msticpy/init/azure_ml_tools.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Checker functions for Azure ML notebooks.""" + from __future__ import annotations import logging @@ -14,7 +15,6 @@ from typing import TYPE_CHECKING, Any import yaml -from IPython import get_ipython from IPython.display import HTML, display try: @@ -35,7 +35,7 @@ except ImportError: # pylint: disable=invalid-name get_version = None # type: ignore[assignment] - PackageNotFoundError = Exception # type: ignore[assignment,misc,no-redef] + PackageNotFoundError = Exception # type: ignore[assignment,misc] from .._version import VERSION from ..common.pkg_config import _HOME_PATH, get_config, refresh_config @@ -76,8 +76,7 @@ run without error. """ _AZ_CLI_WIKI_URI = ( - "https://github.com/Azure/Azure-Sentinel-Notebooks/wiki/" - "Caching-credentials-with-Azure-CLI" + "https://github.com/Azure/Azure-Sentinel-Notebooks/wiki/Caching-credentials-with-Azure-CLI" ) _CLI_WIKI_MSSG_GEN = ( f"For more information see " @@ -147,13 +146,10 @@ def check_aml_settings( _disp_html("
<h4>Starting AML notebook pre-checks...</h4>
") _check_pyspark() if isinstance(min_py_ver, str): - min_py_ver = _get_pkg_version(min_py_ver).release # type: ignore + min_py_ver = _get_pkg_version(min_py_ver).release check_python_ver(min_py_ver=min_py_ver) _check_mp_install(min_mp_ver, mp_release, extras) - if _kql_magic_installed(): - _check_kql_prereqs() - _set_kql_env_vars(extras) _run_user_settings() _set_mpconfig_var() _check_azure_cli_status() @@ -173,16 +169,6 @@ def _check_pyspark() -> None: _disp_html(_PYSPARK_KERNEL_NOT_SUPPORTED.format(nb_uri=AZ_GET_STARTED)) -def _kql_magic_installed() -> bool: - try: - # pylint: disable=import-outside-toplevel, unused-import - from Kqlmagic import kql # noqa: F401 - - return True - except ImportError: - return False - - def check_python_ver(min_py_ver: str | tuple = MIN_PYTHON_VER_DEF) -> None: """ Check the current version of the Python kernel. @@ -291,18 +277,16 @@ def populate_config_to_mp_config(mp_path: str | None) -> str | None: return None # if we found one, use it to populate msticpyconfig.yaml - mp_path = mp_path or str( - (get_aml_user_folder() or Path()).joinpath("msticpyconfig.yaml") - ) + mp_path = mp_path or str((get_aml_user_folder() or Path()).joinpath("msticpyconfig.yaml")) mp_config_convert = MpConfigFile(file=config_json) azs_settings = mp_config_convert.map_json_to_mp_ws() def_azs_settings = next( iter(azs_settings.get("AzureSentinel", {}).get("Workspaces", {}).values()), ) if def_azs_settings: - mp_config_convert.settings["AzureSentinel"]["Workspaces"][ - "Default" - ] = def_azs_settings.copy() + mp_config_convert.settings["AzureSentinel"]["Workspaces"]["Default"] = ( + def_azs_settings.copy() + ) if Path(mp_path).exists(): # If there is an existing file read it in @@ -341,17 +325,6 @@ def _check_mp_install( check_mp_ver(min_msticpy_ver=mp_install_version, extras=extras) -def _set_kql_env_vars(extras: list[str] | None) -> None: - """Set environment variables for Kqlmagic based on MP extras.""" - jp_extended = ("azsentinel", "azuresentinel", "kql") - if extras and any(extra for extra in extras if extra in jp_extended): - os.environ["KQLMAGIC_EXTRAS_REQUIRE"] = "jupyter-extended" - else: - os.environ["KQLMAGIC_EXTRAS_REQUIRE"] = "jupyter-basic" - if is_in_aml(): - os.environ["KQLMAGIC_AZUREML_COMPUTE"] = _get_vm_fqdn() - - def _get_pkg_version(version: str | tuple) -> Version: """ Return comparable package version. @@ -456,10 +429,7 @@ def _get_vm_metadata() -> Mapping[str, Any]: lines = content.strip().split("\n") return { - item[0]: item[1] - for line in lines - for item in [line.split("=", 1)] - if len(item) == 2 + item[0]: item[1] for line in lines for item in [line.split("=", 1)] if len(item) == 2 } @@ -467,83 +437,18 @@ def _get_vm_fqdn() -> str: """Get the FQDN of the host.""" vm_metadata = _get_vm_metadata() if vm_metadata and "instance" in vm_metadata: - return ( - f"https://{vm_metadata.get('instance')}.{vm_metadata.get('domainsuffix')}" - ) + return f"https://{vm_metadata.get('instance')}.{vm_metadata.get('domainsuffix')}" return "" -def _check_kql_prereqs() -> None: - """ - Check and install packages for Kqlmagic/msal_extensions. - - Notes - ----- - Kqlmagic may trigger warnings about a missing PyGObject package - and some system library dependencies. To fix this do the - following:
- From a notebook run: - - %pip uninstall enum34 - !sudo apt-get --yes install libgirepository1.0-dev - !sudo apt-get --yes install gir1.2-secret-1 - %pip install pygobject - - You can also do this from a terminal - but ensure that you've - activated the environment corresponding to the kernel you are - using prior to running the pip commands. - - # Install the libgi dependency - sudo apt install libgirepository1.0-dev - sudo apt install gir1.2-secret-1 - - # activate the environment - # conda activate azureml_py38 - # source ./env_path/scripts/activate - - # Uninstall enum34 - python -m pip uninstall enum34 - # Install pygobject - python -m install pygobject - - """ - if not is_in_aml(): - return - try: - # If this successfully imports, we are ok - # pylint: disable=import-outside-toplevel - import gi - - # pylint: enable=import-outside-toplevel - del gi - except ImportError: - # Check for system packages - ip_shell = get_ipython() - if not ip_shell: - return - apt_list = ip_shell.run_line_magic("sx", "apt list") - apt_list = [apt.split("/", maxsplit=1)[0] for apt in apt_list] - missing_lx_pkg = [ - apt_pkg - for apt_pkg in ("libgirepository1.0-dev", "gir1.2-secret-1") - if apt_pkg not in apt_list - ] - if missing_lx_pkg: - _disp_html( - "Kqlmagic/msal-extensions pre-requisite PyGObject not installed.", - ) - _disp_html( - "To prevent warnings when loading the Kqlmagic data provider," - " Please run the following command:
" - "!conda install --yes -c conda-forge pygobject
", - ) - - def _check_azure_cli_status() -> None: """Check for Azure CLI credentials.""" # import these only if we need them at runtime # pylint: disable=import-outside-toplevel - from ..auth.azure_auth_core import AzureCliStatus, check_cli_credentials + from ..auth.azure_auth_core import ( # noqa: PLC0415 + AzureCliStatus, + check_cli_credentials, + ) if unit_testing(): return diff --git a/msticpy/init/azure_synapse_tools.py b/msticpy/init/azure_synapse_tools.py index 4280f39e8..bdef45119 100644 --- a/msticpy/init/azure_synapse_tools.py +++ b/msticpy/init/azure_synapse_tools.py @@ -4,11 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Help functions for Synapse pipelines notebooks.""" + import logging import os import re from pathlib import Path -from typing import Dict, List, Literal, Optional, Union +from typing import Literal import httpx import jwt @@ -62,8 +63,8 @@ def is_in_synapse(): def init_synapse( identity_type: IdentityType = "service_principal", - storage_svc_name: Optional[str] = None, - tenant_id: Optional[str] = None, + storage_svc_name: str | None = None, + tenant_id: str | None = None, cloud: str = "global", ): """ @@ -152,7 +153,7 @@ def init_synapse( logger.info("Synapse initialization successful") -def current_mounts() -> Dict[str, str]: +def current_mounts() -> dict[str, str]: """Return dictionary of current Synapse mount points.""" return {mnt.mountPoint: mnt.source for mnt in mssparkutils.fs.mounts()} @@ -211,9 +212,7 @@ def mount_container( print(f"path '{mount_path}' already mounted to {existing_mounts[mount_path]}") return existing_mounts[mount_path] == storage_url - return mssparkutils.fs.mount( - storage_url, mount_path, {"linkedService": linked_service} - ) + return mssparkutils.fs.mount(storage_url, mount_path, {"linkedService": linked_service}) # pylint: disable=too-few-public-methods @@ -237,9 +236,7 @@ def __init__(self, **kwargs): self.name: str = kwargs.get("name") self.entry_type: str = kwargs.get("type") self.etag: str = kwargs.get("etag") - self.properties: Dict[str, Union[str, Dict[str, Union[Dict, str]]]] = ( - kwargs.get("properties", {}) - ) + self.properties: dict[str, str | dict[str, dict | str]] = kwargs.get("properties", {}) @property def svc_type(self) -> str: @@ -275,9 +272,7 @@ def azure_name(self): class SynapseName: """Name mapping to default values.""" - storage_account_prefix = ( - "adlsforsentinel" # + last 7 digit of Sentinel workspace id; - ) + storage_account_prefix = "adlsforsentinel" # + last 7 digit of Sentinel workspace id; key_vault_name_prefix = "kvforsentinel" # + last 7 digit of ws I’d; kv_linked_service = "Akvlink" sp_client_id_name = "clientid" @@ -306,13 +301,10 @@ class MPSparkUtils: _DEF_MOUNT_POINT = "msticpy" - def __init__( - self, mount_point: Optional[str] = None, container: Optional[str] = None - ): + def __init__(self, mount_point: str | None = None, container: str | None = None): """Initialize MPSparkUtils class.""" self.linked_services = [ - LinkedService(**props) - for props in _fetch_linked_services(self.workspace_name) + LinkedService(**props) for props in _fetch_linked_services(self.workspace_name) ] self._get_workspace_ids() self.mount_point = mount_point or self._DEF_MOUNT_POINT @@ -324,7 +316,7 @@ def workspace_name(self) -> str: return mssparkutils.env.getWorkspaceName() @property - def fs_mounts(self) -> Dict[str, str]: + def fs_mounts(self) -> dict[str, str]: """Return a dictionary of mount points and targets.""" return {mnt.mountPoint: 
mnt.source for mnt in mssparkutils.fs.mounts()} @@ -338,7 +330,7 @@ def config_path(self): """Return mount path for MSTICPy config.""" return Path(f"/synfs/{self.job_id}/{self.mount_point}") - def get_service_of_type(self, svc_type: str) -> Optional[LinkedService]: + def get_service_of_type(self, svc_type: str) -> LinkedService | None: """ Return the first linked service of specific `svc_type`. @@ -355,15 +347,11 @@ def get_service_of_type(self, svc_type: str) -> Optional[LinkedService]: """ return next( - iter( - lnk_svc - for lnk_svc in self.linked_services - if lnk_svc.svc_type == svc_type - ), + iter(lnk_svc for lnk_svc in self.linked_services if lnk_svc.svc_type == svc_type), None, ) - def get_all_services_of_type(self, svc_type) -> List[LinkedService]: + def get_all_services_of_type(self, svc_type) -> list[LinkedService]: """ Return list of Linked services of `svc_type`. @@ -379,11 +367,9 @@ def get_all_services_of_type(self, svc_type) -> List[LinkedService]: List of LinkedService instances of type `svc_type`. """ - return [ - lnk_svc for lnk_svc in self.linked_services if lnk_svc.svc_type == svc_type - ] + return [lnk_svc for lnk_svc in self.linked_services if lnk_svc.svc_type == svc_type] - def get_ws_default_storage(self) -> Optional[LinkedService]: + def get_ws_default_storage(self) -> LinkedService | None: """ Return default storage linked service. @@ -405,7 +391,7 @@ def get_ws_default_storage(self) -> Optional[LinkedService]: except StopIteration: return None - def get_service(self, svc_name: str) -> Optional[LinkedService]: + def get_service(self, svc_name: str) -> LinkedService | None: """ Return named linked service. @@ -417,20 +403,14 @@ def get_service(self, svc_name: str) -> Optional[LinkedService]: """ try: return next( - iter( - lnk_svc - for lnk_svc in self.linked_services - if lnk_svc.name == svc_name - ) + iter(lnk_svc for lnk_svc in self.linked_services if lnk_svc.name == svc_name) ) except StopIteration: return None - def get_storage_service( - self, linked_svc_name: Optional[str] = None - ) -> LinkedService: + def get_storage_service(self, linked_svc_name: str | None = None) -> LinkedService: """Return linked storage service (named) or default storage.""" - storage_svc: Optional[LinkedService] = None + storage_svc: LinkedService | None = None if linked_svc_name: storage_svc = self.get_service(svc_name=linked_svc_name) if not storage_svc: @@ -487,7 +467,7 @@ def _fetch_linked_services(ws_name: str): return resp.json().get("value") -def _set_azure_env_creds(mp_spark: MPSparkUtils, tenant_id: Optional[str] = None): +def _set_azure_env_creds(mp_spark: MPSparkUtils, tenant_id: str | None = None): """Publish Service Principal credentials to environment variables.""" os.environ[AzureCredEnvNames.AZURE_TENANT_ID] = tenant_id or mp_spark.tenant_id client_id = mp_spark.get_kv_secret(SynapseName.sp_client_id_name) @@ -515,7 +495,7 @@ def _set_azure_env_creds(mp_spark: MPSparkUtils, tenant_id: Optional[str] = None ) -def _set_msi_client_id(mp_spark: MPSparkUtils, tenant_id: Optional[str] = None): +def _set_msi_client_id(mp_spark: MPSparkUtils, tenant_id: str | None = None): """Publish Service Principal credentials to environment variables.""" os.environ[AzureCredEnvNames.AZURE_TENANT_ID] = tenant_id or mp_spark.tenant_id os.environ[AzureCredEnvNames.AZURE_CLIENT_ID] = mp_spark.application_id diff --git a/msticpy/init/logging.py b/msticpy/init/logging.py index e3b5e88e4..ed48d5b12 100644 --- a/msticpy/init/logging.py +++ b/msticpy/init/logging.py @@ -4,10 +4,11 @@ # license 
information. # -------------------------------------------------------------------------- """Logging global config.""" + import logging import os import sys -from typing import NamedTuple, Optional, Union +from typing import NamedTuple from .._version import VERSION from ..common.pkg_config import get_config @@ -22,11 +23,11 @@ class LoggingConfig(NamedTuple): """Logging configuration tuple.""" - log_file: Optional[str] = None + log_file: str | None = None log_level: int = logging.WARNING -def set_logging_level(log_level: Union[int, str]): +def set_logging_level(log_level: int | str): """ Set global logging level. diff --git a/msticpy/init/mp_pandas_accessors.py b/msticpy/init/mp_pandas_accessors.py index e50037cbe..1f77ec764 100644 --- a/msticpy/init/mp_pandas_accessors.py +++ b/msticpy/init/mp_pandas_accessors.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """MSTICPy core pandas accessor methods.""" -from typing import Any, Dict, List, Mapping, Union + +from collections.abc import Mapping +from typing import Any import pandas as pd @@ -84,7 +86,7 @@ def b64extract(self, column: str, **kwargs) -> pd.DataFrame: """ return unpack_df(data=self._df, column=column, **kwargs) - def ioc_extract(self, columns: List[str], **kwargs) -> pd.DataFrame: + def ioc_extract(self, columns: list[str], **kwargs) -> pd.DataFrame: """ Extract IoCs from either a pandas DataFrame. @@ -131,7 +133,7 @@ def ioc_extract(self, columns: List[str], **kwargs) -> pd.DataFrame: def build_process_tree( self, - schema: Union[ProcSchema, Dict[str, Any]] = None, + schema: ProcSchema | dict[str, Any] | None = None, show_summary: bool = False, debug: bool = False, ) -> pd.DataFrame: @@ -289,7 +291,7 @@ def view(self, **kwargs): if self._data_viewer_class is None: try: # pylint: disable=import-outside-toplevel - from ..vis.data_viewer_panel import DataViewer + from ..vis.data_viewer_panel import DataViewer # noqa: PLC0415 except ImportError: print("This component needs the panel package.") return self._df diff --git a/msticpy/init/mp_plugins.py b/msticpy/init/mp_plugins.py index a478b3365..a1083b1e2 100644 --- a/msticpy/init/mp_plugins.py +++ b/msticpy/init/mp_plugins.py @@ -22,14 +22,14 @@ """ - import contextlib import sys +from collections.abc import Iterable from importlib import import_module from inspect import getmembers, isabstract, isclass from pathlib import Path from types import ModuleType -from typing import Iterable, NamedTuple, Optional, Union +from typing import NamedTuple from warnings import warn from .._version import VERSION @@ -50,8 +50,8 @@ class PluginReg(NamedTuple): """Plugin registration tuple.""" - reg_dest: Union[type, ModuleType] # class or module containing CUSTOM_PROVIDERS - name_property: Optional[str] # Custom name(s) for provider + reg_dest: type | ModuleType # class or module containing CUSTOM_PROVIDERS + name_property: str | None # Custom name(s) for provider # This dictionary maps the class of the plugin to @@ -67,7 +67,7 @@ class PluginReg(NamedTuple): } -def read_plugins(plugin_paths: Union[str, Iterable[str]]): +def read_plugins(plugin_paths: str | Iterable[str]): """Load plugins from folders specified in msticpyconfig.yaml.""" plugin_config = [plugin_paths] if isinstance(plugin_paths, str) else plugin_paths if not plugin_config: @@ -79,14 +79,17 @@ def read_plugins(plugin_paths: Union[str, Iterable[str]]): load_plugins_from_path(plugin_path=plugin_path) -def load_plugins_from_path(plugin_path: Union[str, Path]): +def 
load_plugins_from_path(plugin_path: str | Path): """Load all compatible plugins found in plugin_path.""" sys.path.append(str(plugin_path)) for module_file in Path(plugin_path).glob("*.py"): try: module = import_module(module_file.stem) except ImportError: - warn(f"Unable to import plugin {module_file} from {plugin_path}") + warn( + f"Unable to import plugin {module_file} from {plugin_path}", + stacklevel=2, + ) for name, obj in getmembers(module, isclass): if not isinstance(obj, type): continue @@ -97,7 +100,7 @@ def load_plugins_from_path(plugin_path: Union[str, Path]): # if no specified registration, use the root class reg_dest = reg_object.reg_dest or plugin_type plugin_names = getattr(obj, reg_object.name_property, name) - if not isinstance(plugin_names, (list, tuple)): + if not isinstance(plugin_names, list | tuple): plugin_names = (plugin_names,) for plugin_name in plugin_names: reg_dest.CUSTOM_PROVIDERS[plugin_name] = obj diff --git a/msticpy/init/mp_user_session.py b/msticpy/init/mp_user_session.py index 511a2e217..c5caf2c4d 100644 --- a/msticpy/init/mp_user_session.py +++ b/msticpy/init/mp_user_session.py @@ -134,15 +134,9 @@ def _load_query_providers(user_config, namespace): The namespace to load the component instances into. """ - logger.info( - "Loading %d query providers", len(user_config.get("QueryProviders", {})) - ) - for qry_prov_name, qry_prov_settings in user_config.get( - "QueryProviders", {} - ).items(): - qry_prov = _initialize_component( - qry_prov_name, qry_prov_settings, QueryProvider - ) + logger.info("Loading %d query providers", len(user_config.get("QueryProviders", {}))) + for qry_prov_name, qry_prov_settings in user_config.get("QueryProviders", {}).items(): + qry_prov = _initialize_component(qry_prov_name, qry_prov_settings, QueryProvider) if qry_prov: namespace[qry_prov_name] = qry_prov diff --git a/msticpy/init/nbinit.py b/msticpy/init/nbinit.py index 1b4259878..fd091c534 100644 --- a/msticpy/init/nbinit.py +++ b/msticpy/init/nbinit.py @@ -50,6 +50,7 @@ https://github.com/Azure/Azure-Sentinel-Notebooks/blob/master/ConfiguringNotebookEnvironment.ipynb """ + from __future__ import annotations import importlib @@ -197,51 +198,40 @@ def _verbose(verbosity: int | None = None) -> int: # pylint: disable=use-dict-literal _NB_IMPORTS = [ - dict(pkg="pandas", alias="pd"), - dict(pkg="IPython", tgt="get_ipython"), - dict(pkg="IPython.display", tgt="display"), - dict(pkg="IPython.display", tgt="HTML"), - dict(pkg="IPython.display", tgt="Markdown"), + {"pkg": "pandas", "alias": "pd"}, + {"pkg": "IPython", "tgt": "get_ipython"}, + {"pkg": "IPython.display", "tgt": "display"}, + {"pkg": "IPython.display", "tgt": "HTML"}, + {"pkg": "IPython.display", "tgt": "Markdown"}, # dict(pkg="ipywidgets", alias="widgets"), - dict(pkg="pathlib", tgt="Path"), - dict(pkg="numpy", alias="np"), + {"pkg": "pathlib", "tgt": "Path"}, + {"pkg": "numpy", "alias": "np"}, ] if sns is not None: - _NB_IMPORTS.append(dict(pkg="seaborn", alias="sns")) + _NB_IMPORTS.append({"pkg": "seaborn", "alias": "sns"}) _MP_IMPORTS = [ - dict(pkg="msticpy"), - dict(pkg="msticpy.data", tgt="QueryProvider"), - # dict(pkg="msticpy.vis.foliummap", tgt="FoliumMap"), - # dict(pkg="msticpy.context", tgt="TILookup"), - # dict(pkg="msticpy.context", tgt="GeoLiteLookup"), - # dict(pkg="msticpy.context", tgt="IPStackLookup"), - # dict(pkg="msticpy.transform", tgt="IoCExtract"), - dict(pkg="msticpy.common.utility", tgt="md"), - dict(pkg="msticpy.common.utility", tgt="md_warn"), - dict(pkg="msticpy.common.wsconfig", 
tgt="WorkspaceConfig"), - dict(pkg="msticpy.init.pivot", tgt="Pivot"), - dict(pkg="msticpy.datamodel", tgt="entities"), - dict(pkg="msticpy.init", tgt="nbmagics"), - # dict(pkg="msticpy.nbtools", tgt="SecurityAlert"), - dict(pkg="msticpy.vis", tgt="mp_pandas_plot"), - # dict(pkg="msticpy.vis", tgt="nbdisplay"), - dict(pkg="msticpy.init", tgt="mp_pandas_accessors"), - # dict(pkg="msticpy", tgt="nbwidgets"), + {"pkg": "msticpy"}, + {"pkg": "msticpy.data", "tgt": "QueryProvider"}, + {"pkg": "msticpy.common.utility", "tgt": "md"}, + {"pkg": "msticpy.common.utility", "tgt": "md_warn"}, + {"pkg": "msticpy.common.wsconfig", "tgt": "WorkspaceConfig"}, + {"pkg": "msticpy.init.pivot", "tgt": "Pivot"}, + {"pkg": "msticpy.datamodel", "tgt": "entities"}, + {"pkg": "msticpy.init", "tgt": "nbmagics"}, + {"pkg": "msticpy.vis", "tgt": "mp_pandas_plot"}, + {"pkg": "msticpy.init", "tgt": "mp_pandas_accessors"}, ] _MP_IMPORT_ALL: list[dict[str, str]] = [ - dict(module_name="msticpy.datamodel.entities"), + {"module_name": "msticpy.datamodel.entities"}, ] # pylint: enable=use-dict-literal -_CONF_URI = ( - "https://msticpy.readthedocs.io/en/latest/getting_started/msticpyconfig.html" -) +_CONF_URI = "https://msticpy.readthedocs.io/en/latest/getting_started/msticpyconfig.html" _AZNB_GUIDE = ( - "Please run the Getting Started Guide for Azure Sentinel " - "ML Notebooks notebook." + "Please run the Getting Started Guide for Azure Sentinel ML Notebooks notebook." ) current_providers: dict[str, Any] = {} # pylint: disable=invalid-name @@ -410,9 +400,7 @@ def init_notebook( _check_msticpy_version() if _detect_env("synapse", **kwargs) and is_in_synapse(): - synapse_params = { - key: val for key, val in kwargs.items() if key in _SYNAPSE_KWARGS - } + synapse_params = {key: val for key, val in kwargs.items() if key in _SYNAPSE_KWARGS} try: init_synapse(**synapse_params) except Exception as err: # pylint: disable=broad-except @@ -492,7 +480,7 @@ def _err_output(*args): def _load_user_defaults(namespace): """Load user defaults, if defined.""" - global current_providers # pylint: disable=global-statement, invalid-name + global current_providers # pylint: disable=global-statement, invalid-name # noqa: PLW0603 stdout_cap = io.StringIO() with redirect_stdout(stdout_cap): _pr_output("Loading user defaults.") @@ -567,8 +555,7 @@ def _show_init_warnings(imp_ok, conf_ok): if not conf_ok: md("One or more configuration items were missing or set incorrectly.") md( - _AZNB_GUIDE - + f" and the
msticpy configuration guide.", + _AZNB_GUIDE + f" and the msticpy configuration guide.", ) md("This notebook may still run but with reduced functionality.") return False @@ -793,13 +780,6 @@ def _set_nb_options(namespace): pd.set_option("display.max_columns", 50) pd.set_option("display.max_colwidth", 100) - os.environ["KQLMAGIC_LOAD_MODE"] = "silent" - # Kqlmagic config will use AZ CLI login if available - kql_config = os.environ.get("KQLMAGIC_CONFIGURATION", "") - if "try_azcli_login" not in kql_config: - kql_config = ";".join([kql_config, "try_azcli_login=True"]) - os.environ["KQLMAGIC_CONFIGURATION"] = kql_config - def _load_pivots(namespace): """Load pivot functions.""" @@ -809,7 +789,7 @@ def _load_pivots(namespace): pivot.reload_pivots() namespace["pivot"] = pivot # pylint: disable=import-outside-toplevel, cyclic-import - import msticpy + import msticpy # noqa: PLC0415 msticpy.pivot = pivot @@ -920,9 +900,7 @@ def _check_and_reload_pkg( if pkg_version < required_version: _err_output(_MISSING_PKG_WARN.format(package=pkg_name)) # sourcery skip: swap-if-expression - resp = ( - input("Install the package now? (y/n)") if not unit_testing() else "y" - ) # nosec + resp = input("Install the package now? (y/n)") if not unit_testing() else "y" # nosec if resp.casefold().startswith("y"): warn_mssg.append(f"{pkg_name} was installed or upgraded.") pkg_spec = f"{pkg_name}>={required_version}" diff --git a/msticpy/init/nbmagics.py b/msticpy/init/nbmagics.py index e41b3b55b..e7250e688 100644 --- a/msticpy/init/nbmagics.py +++ b/msticpy/init/nbmagics.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- """msticpy IPython magics.""" + import re -from typing import List, Tuple # pylint: enable=unused-import from IPython import get_ipython @@ -48,9 +48,7 @@ class Base64Magic(Magics): @line_cell_magic @magic_arguments.magic_arguments() - @magic_arguments.argument( - "--out", "-o", help="The variable to return the results in" - ) + @magic_arguments.argument("--out", "-o", help="The variable to return the results in") @magic_arguments.argument( "--pretty", "-p", @@ -125,15 +123,13 @@ def __init__(self, shell): @line_cell_magic @magic_arguments.magic_arguments() - @magic_arguments.argument( - "--out", "-o", help="The variable to return the results in" - ) + @magic_arguments.argument("--out", "-o", help="The variable to return the results in") @magic_arguments.argument( "--ioc_types", "-i", help="The types of IoC to search for (comma-separated string)", ) - def ioc(self, line="", cell=None) -> List[Tuple[str, List[str]]]: + def ioc(self, line="", cell=None) -> list[tuple[str, list[str]]]: """ Ioc Extract IPython magic extension. 
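The nbmagics hunk above is typical of the typing migration in this patch: `typing.List`/`Tuple` give way to builtin generics, which are valid at runtime on Python 3.10+. A minimal sketch of the resulting style, using a hypothetical helper rather than msticpy code:

```python
# Builtin generics (PEP 585) and union syntax (PEP 604) need no
# imports from typing on Python 3.10 and later.
def ioc_summary(results: list[tuple[str, list[str]]]) -> dict[str, int] | None:
    """Count matches per IoC type (hypothetical helper, not msticpy API)."""
    if not results:
        return None
    return {ioc_type: len(matches) for ioc_type, matches in results}
```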
diff --git a/msticpy/init/pivot.py b/msticpy/init/pivot.py index 9ebf26ad4..e85368647 100644 --- a/msticpy/init/pivot.py +++ b/msticpy/init/pivot.py @@ -6,24 +6,19 @@ """Pivot functions main module.""" import contextlib -import warnings +from collections.abc import Callable, Iterable from datetime import datetime, timedelta, timezone from importlib import import_module from pathlib import Path from types import ModuleType -from typing import Any, Callable, Dict, Iterable, Optional, Type +from typing import Any from .._version import VERSION from ..common.timespan import TimeSpan +from ..common.utility.types import SingletonClass from ..context.tilookup import TILookup from ..data.core.data_providers import QueryProvider from ..datamodel import entities - -with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=DeprecationWarning) - from ..datamodel import pivot as legacy_pivot - -from ..common.utility.types import SingletonClass from ..nbwidgets.query_time import QueryTime from . import pivot_init @@ -53,9 +48,9 @@ class Pivot: def __init__( self, - namespace: Dict[str, Any] = None, + namespace: dict[str, Any] = None, providers: Iterable[Any] = None, - timespan: Optional[TimeSpan] = None, + timespan: TimeSpan | None = None, ): """ Instantiate a Pivot environment. @@ -80,13 +75,13 @@ def __init__( self.timespan = timespan # acquire current providers - self._providers: Dict[str, Any] = {} + self._providers: dict[str, Any] = {} self._param_providers = providers self._param_namespace = namespace def reload_pivots( self, - namespace: Dict[str, Any] = None, + namespace: dict[str, Any] = None, providers: Iterable[Any] = None, clear_existing: bool = False, ): @@ -132,7 +127,7 @@ def reload_pivots( def _get_all_providers( self, - namespace: Dict[str, Any] = None, + namespace: dict[str, Any] = None, providers: Iterable[Any] = None, ): self._providers["TILookup"] = ( @@ -166,8 +161,8 @@ def add_query_provider(self, prov: QueryProvider): @staticmethod def _get_provider_by_type( - provider_type: Type, - namespace: Dict[str, Any] = None, + provider_type: type, + namespace: dict[str, Any] = None, providers: Iterable[Any] = None, ) -> Any: if providers: @@ -186,14 +181,16 @@ def _get_provider_by_type( def _get_def_pivot_reg(): try: # pylint: disable=import-outside-toplevel - from importlib.resources import files # type: ignore[attr-defined] + from importlib.resources import ( # noqa: PLC0415 + files, + ) return files("msticpy").joinpath(_DEF_PIVOT_REG_FILE) except ImportError: return Path(__file__).parent.parent.joinpath(_DEF_PIVOT_REG_FILE) @property - def providers(self) -> Dict[str, Any]: + def providers(self) -> dict[str, Any]: """ Return the current set of loaded providers. @@ -223,7 +220,7 @@ def get_provider(self, name: str) -> Any: """ return self._providers.get(name) - def edit_query_time(self, timespan: Optional[TimeSpan] = None): + def edit_query_time(self, timespan: TimeSpan | None = None): """ Display a QueryTime widget to get the timespan. @@ -291,7 +288,7 @@ def timespan(self, value: Any): return self._query_time.set_time(timespan) - def set_timespan(self, value: Optional[Any] = None, **kwargs): + def set_timespan(self, value: Any | None = None, **kwargs): """ Set the pivot timespan. 
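`_get_def_pivot_reg` in the pivot.py hunk above prefers `importlib.resources.files()` and falls back to a module-relative path. A rough sketch of that pattern, assuming the package's data files ship on disk next to its modules:

```python
from pathlib import Path

def find_data_file(package: str, filename: str) -> Path:
    """Locate a packaged data file (sketch, not the msticpy helper)."""
    try:
        from importlib.resources import files
    except ImportError:
        # Older interpreters: resolve relative to this module instead.
        return Path(__file__).parent.joinpath(filename)
    # files() returns a Traversable; for a normal on-disk package it
    # can be converted to a concrete filesystem Path.
    return Path(str(files(package).joinpath(filename)))
```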
@@ -330,7 +327,7 @@ def reset_timespan(self): @staticmethod def register_pivot_providers( pivot_reg_path: str, - namespace: Dict[str, Any] = None, + namespace: dict[str, Any] = None, def_container: str = "custom", force_container: bool = False, ): @@ -366,7 +363,7 @@ def register_pivot_providers( def add_pivot_function( func: Callable[[Any], Any], pivot_reg: "PivotRegistration" = None, - container: Optional[str] = None, + container: str | None = None, **kwargs, ): """ @@ -442,8 +439,3 @@ def remove_pivot_funcs(entity: str): def browse(): """Return PivotBrowser.""" return PivotBrowser() - - -# add link in datamodel for legacy location -setattr(legacy_pivot, "Pivot", Pivot) -setattr(legacy_pivot, "PivotRegistration", PivotRegistration) diff --git a/msticpy/init/pivot_core/pivot_browser.py b/msticpy/init/pivot_core/pivot_browser.py index 2c8b6329e..f61d29a91 100644 --- a/msticpy/init/pivot_core/pivot_browser.py +++ b/msticpy/init/pivot_core/pivot_browser.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- """Pivot browser widget.""" -from typing import Dict, List import ipywidgets as widgets from IPython import get_ipython @@ -78,13 +77,13 @@ class PivotBrowser: def __init__(self): """Create an instance of the Pivot browser.""" - self._text: Dict[str, widgets.Widget] = {} - self._select: Dict[str, widgets.Widget] = {} - self._layout: Dict[str, widgets.Widget] = {} - self._html: Dict[str, widgets.Widget] = {} - self._btn: Dict[str, widgets.Widget] = {} + self._text: dict[str, widgets.Widget] = {} + self._select: dict[str, widgets.Widget] = {} + self._layout: dict[str, widgets.Widget] = {} + self._html: dict[str, widgets.Widget] = {} + self._btn: dict[str, widgets.Widget] = {} - self.piv_entities: Dict[str, List[str]] = _get_entities_with_pivots() + self.piv_entities: dict[str, list[str]] = _get_entities_with_pivots() self._create_select_controls() self._create_help_controls() diff --git a/msticpy/init/pivot_core/pivot_container.py b/msticpy/init/pivot_core/pivot_container.py index 0c18e51ef..f026d1b34 100644 --- a/msticpy/init/pivot_core/pivot_container.py +++ b/msticpy/init/pivot_core/pivot_container.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Pivot function hierarchy attribute class.""" + from ..._version import VERSION from ...common.data_types import ObjectContainer diff --git a/msticpy/init/pivot_core/pivot_magic_core.py b/msticpy/init/pivot_core/pivot_magic_core.py index d95bd6b0c..f7abb95ea 100644 --- a/msticpy/init/pivot_core/pivot_magic_core.py +++ b/msticpy/init/pivot_core/pivot_magic_core.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- """Txt2df core code.""" + from __future__ import annotations import argparse @@ -81,12 +82,12 @@ def run_txt2df(line: str, cell: str, local_ns: dict | None) -> pd.DataFrame: return pd.DataFrame() cell_text = io.StringIO(cell) warn_args: dict[str, str | bool] - if _PD_VER < Version("1.3.0"): # type: ignore + if _PD_VER < Version("1.3.0"): warn_args = {"warn_bad_lines": True} else: warn_args = {"on_bad_lines": "warn"} try: - parsed_df = pd.read_csv( # type: ignore + parsed_df = pd.read_csv( cell_text, header=0 if args.headers else None, sep=args.sep, @@ -98,7 +99,7 @@ def run_txt2df(line: str, cell: str, local_ns: dict | None) -> pd.DataFrame: except ParserError: # try again without headers cell_text = io.StringIO(cell) - parsed_df = pd.read_csv( # type: ignore + parsed_df = pd.read_csv( cell_text, sep=args.sep, skipinitialspace=True, diff --git a/msticpy/init/pivot_core/pivot_pd_accessor.py b/msticpy/init/pivot_core/pivot_pd_accessor.py index becae8437..47d330bb5 100644 --- a/msticpy/init/pivot_core/pivot_pd_accessor.py +++ b/msticpy/init/pivot_core/pivot_pd_accessor.py @@ -5,21 +5,23 @@ # -------------------------------------------------------------------------- """Pandas DataFrame accessor for Pivot functions.""" +from __future__ import annotations + import contextlib import json import re import warnings +from collections.abc import Callable, Iterable from datetime import datetime from json import JSONDecodeError from numbers import Number -from typing import Callable, Dict, Iterable, Set, Union import numpy as np import pandas as pd from IPython.core.display import HTML from IPython.core.getipython import get_ipython from IPython.display import display -from packaging.version import parse as parse_version +from packaging.version import parse as parse_version # pylint: disable=no-name-in-module from ..._version import VERSION @@ -161,7 +163,9 @@ def tee(self, var_name: str, clobber: bool = False) -> pd.DataFrame: """ if self._ip and var_name: if var_name in self._ip.ns_table["user_local"] and not clobber: - warnings.warn(f"Did not overwrite existing {var_name} in namespace") + warnings.warn( + f"Did not overwrite existing {var_name} in namespace", stacklevel=2 + ) else: self._ip.ns_table["user_local"][var_name] = self._df return self._df @@ -208,7 +212,7 @@ def tee_exec(self, df_func: str, *args, **kwargs) -> pd.DataFrame: def filter_cols( self, - cols: Union[str, Iterable[str]], + cols: str | Iterable[str], match_case: bool = False, sort_cols: bool = False, ) -> pd.DataFrame: @@ -235,7 +239,7 @@ def filter_cols( """ curr_cols = self._df.columns - filt_cols: Set[str] = set() + filt_cols: set[str] = set() if isinstance(cols, str): filt_cols.update(_name_match(curr_cols, cols, match_case)) elif isinstance(cols, list): @@ -253,7 +257,7 @@ def filter_cols( def filter( self, - expr: Union[str, Number], + expr: str | Number, match_case: bool = False, numeric_col: bool = False, ) -> pd.DataFrame: @@ -291,9 +295,7 @@ def filter( text_cols = self._df.select_dtypes(include=[object, "string"]) return self._df[ text_cols.apply( - lambda col: col.str.contains( - expr, regex=True, case=match_case, na=False - ) + lambda col: col.str.contains(expr, regex=True, case=match_case, na=False) ).any(axis=1) ] if isinstance(expr, Number) or numeric_col: @@ -307,7 +309,7 @@ def filter( raise TypeError("expr '{expr}' must be a string or numeric type.") def sort( - self, cols: Union[str, Iterable[str], Dict[str, str]], ascending: bool = 
None + self, cols: str | Iterable[str] | dict[str, str], ascending: bool = None ) -> pd.DataFrame: """ Sort output by column expression. @@ -368,13 +370,11 @@ def sort( continue # look for regex matches for col name df_match_cols = [ - df_cols[s_col] - for s_col in df_cols - if re.match(col, s_col, re.IGNORECASE) + df_cols[s_col] for s_col in df_cols if re.match(col, s_col, re.IGNORECASE) ] # we might get multiple matches if df_match_cols: - sort_cols.update({df_col: col_dict[col] for df_col in df_match_cols}) + sort_cols.update(dict.fromkeys(df_match_cols, col_dict[col])) continue raise ValueError( f"'{col}' column in sort list did not match any columns in input data." @@ -383,7 +383,7 @@ def sort( asc_param = ascending if ascending is not None else list(sort_cols.values()) return self._df.sort_values(list(sort_cols.keys()), ascending=asc_param) - def list_to_rows(self, cols: Union[str, Iterable[str]]) -> pd.DataFrame: + def list_to_rows(self, cols: str | Iterable[str]) -> pd.DataFrame: """ Expand a list column to individual rows. @@ -418,7 +418,7 @@ def list_to_rows(self, cols: Union[str, Iterable[str]]) -> pd.DataFrame: ) return data - def parse_json(self, cols: Union[str, Iterable[str]]) -> pd.DataFrame: + def parse_json(self, cols: str | Iterable[str]) -> pd.DataFrame: """ Convert JSON string columns to Python types. @@ -490,7 +490,7 @@ def _json_safe_conv(val): return val -def _extract_values(data: Union[dict, list, str], key_name: str = "") -> dict: +def _extract_values(data: dict | list | str, key_name: str = "") -> dict: """ Recursively extracts column values from the given key's values. diff --git a/msticpy/init/pivot_core/pivot_pipeline.py b/msticpy/init/pivot_core/pivot_pipeline.py index 6e7a130d4..310c9fbed 100644 --- a/msticpy/init/pivot_core/pivot_pipeline.py +++ b/msticpy/init/pivot_core/pivot_pipeline.py @@ -4,8 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Pivot pipeline class.""" + from collections import namedtuple -from typing import Any, Dict, Iterable, List, Optional +from collections.abc import Iterable +from typing import Any import attr import pandas as pd @@ -39,11 +41,11 @@ class PipelineStep: name: str step_type: str = attr.ib(validator=attr.validators.in_(_STEP_TYPES)) - function: Optional[str] = None - entity: Optional[str] = None - comment: Optional[str] = None - pos_params: List[str] = Factory(list) - params: Dict[str, Any] = Factory(dict) + function: str | None = None + entity: str | None = None + comment: str | None = None + pos_params: list[str] = Factory(list) + params: dict[str, Any] = Factory(dict) def get_exec_step(self) -> PipelineExecStep: """ @@ -95,8 +97,7 @@ def get_exec_step(self) -> PipelineExecStep: def _get_param_string(self) -> str: """Return text representation of keyword params.""" pos_params = [ - f"'{param}'" if isinstance(param, str) else str(param) - for param in self.pos_params + f"'{param}'" if isinstance(param, str) else str(param) for param in self.pos_params ] params_str = [ f"{p_name}='{p_val}'" @@ -145,8 +146,8 @@ class Pipeline: def __init__( self, name: str, - description: Optional[str] = None, - steps: Optional[Iterable[PipelineStep]] = None, + description: str | None = None, + steps: Iterable[PipelineStep] | None = None, ): """ Create Pipeline instance. 
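The `sort()` hunk above swaps a comprehension with a constant value for `dict.fromkeys`, which is equivalent and reads more directly. A quick illustration (the column names are invented):

```python
match_cols = ["TimeGenerated", "Account"]
ascending = True
# dict.fromkeys builds {key: value} for every key with one shared
# value, exactly what the replaced comprehension did.
assert dict.fromkeys(match_cols, ascending) == {c: ascending for c in match_cols}
```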
@@ -163,7 +164,7 @@ def __init__( """ self.name = name self.description = description - self.steps: List[PipelineStep] = [] + self.steps: list[PipelineStep] = [] if steps: self.steps.extend(iter(steps)) @@ -184,7 +185,7 @@ def __repr__(self) -> str: ) @classmethod - def parse_pipeline(cls, pipeline: Dict[str, Dict[str, Any]]) -> "Pipeline": + def parse_pipeline(cls, pipeline: dict[str, dict[str, Any]]) -> "Pipeline": """ Parse single pipeline from dictionary. @@ -208,13 +209,11 @@ def parse_pipeline(cls, pipeline: Dict[str, Dict[str, Any]]) -> "Pipeline": pl_name, pl_dict = next(iter(pipeline.items())) if pl_dict and isinstance(pl_dict, dict): steps = [PipelineStep(**step) for step in pl_dict.get("steps", [])] - return cls( - name=pl_name, description=pl_dict.get("description"), steps=steps - ) + return cls(name=pl_name, description=pl_dict.get("description"), steps=steps) raise ValueError("Dictionary could not be parsed.") @staticmethod - def parse_pipelines(pipelines: Dict[str, Dict[str, Any]]) -> Iterable["Pipeline"]: + def parse_pipelines(pipelines: dict[str, dict[str, Any]]) -> Iterable["Pipeline"]: """ Parse dict of pipelines. @@ -264,9 +263,7 @@ def to_yaml(self) -> str: steps = [attr.asdict(step) for step in self.steps] return yaml.dump({self.name: {"description": self.description, "steps": steps}}) - def run( - self, data: pd.DataFrame, verbose: bool = True, debug: bool = False - ) -> Optional[Any]: + def run(self, data: pd.DataFrame, verbose: bool = True, debug: bool = False) -> Any | None: """ Run the pipeline on the supplied DataFrame. @@ -302,9 +299,7 @@ def run( else: exec_kws = {} func = _get_pd_accessor_func(pipeline_result, exec_action.accessor) - pipeline_result = func( - *exec_action.pos_params, **exec_action.params, **exec_kws - ) + pipeline_result = func(*exec_action.pos_params, **exec_action.params, **exec_kws) return pipeline_result diff --git a/msticpy/init/pivot_core/pivot_register.py b/msticpy/init/pivot_core/pivot_register.py index 9b3b8939b..736d0a1bb 100644 --- a/msticpy/init/pivot_core/pivot_register.py +++ b/msticpy/init/pivot_core/pivot_register.py @@ -4,12 +4,14 @@ # license information. # -------------------------------------------------------------------------- """Pivot helper functions .""" + from __future__ import annotations import warnings from collections import abc +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any import attr import pandas as pd @@ -191,9 +193,7 @@ def pivot_lookup(*args, **kwargs) -> pd.DataFrame: "Try again with a single row/value as input.", "E.g. func(data=df.iloc[N], column=...)", ) - result_df = _iterate_func( - target_func, input_df, input_column, pivot_reg, **kwargs - ) + result_df = _iterate_func(target_func, input_df, input_column, pivot_reg, **kwargs) else: result_df = target_func(**param_dict, **kwargs) # type: ignore merge_key = pivot_reg.func_out_column_name or input_column @@ -213,16 +213,14 @@ def pivot_lookup(*args, **kwargs) -> pd.DataFrame: ).drop(columns="src_row_index", errors="ignore") return result_df - setattr( - pivot_lookup, - "pivot_properties", - attr.asdict(pivot_reg, filter=(lambda _, val: val is not None)), + pivot_lookup.pivot_properties = attr.asdict( # type: ignore[attr-defined] + pivot_reg, filter=lambda _, val: val is not None ) return pivot_lookup def get_join_params( - func_kwargs: dict[str, Any] + func_kwargs: dict[str, Any], ) -> tuple[str | None, str | None, str | None, bool]: """ Get join parameters from kwargs. 
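In the pivot_register hunk above, `setattr(pivot_lookup, "pivot_properties", ...)` becomes a plain attribute assignment. Function objects accept arbitrary attributes, so the two are equivalent; a small sketch with an invented stand-in function:

```python
def pivot_lookup(value: str) -> str:
    """Stand-in for the generated pivot function."""
    return value.casefold()

# Direct assignment replaces setattr(); the ignore comment mirrors the
# patch, since type checkers don't track ad-hoc function attributes.
pivot_lookup.pivot_properties = {"src_func_name": "lookup"}  # type: ignore[attr-defined]
```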
@@ -251,7 +249,8 @@ def get_join_params( "If you are specifying explicit join keys " "you must specify 'right_on' parameter with the " + "name of the output column to join on. " - + "Results will joined on index." + + "Results will joined on index.", + stacklevel=2, ) if not left_on: col_keys = list(func_kwargs.keys() - {"start", "end", "data"}) @@ -265,7 +264,8 @@ def get_join_params( "Could not infer 'left' join column from source data. " + "Please specify 'left_on' parameter with the " + "name of the source column to join on. " - + "Results will joined on index." + + "Results will joined on index.", + stacklevel=2, ) return join_type, left_on, right_on, join_ignore_case @@ -312,7 +312,7 @@ def join_result( result_df, left_on=left_on, right_on=right_on, - how=how, # type: ignore + how=how, suffixes=("_src", "_res"), ) @@ -325,7 +325,7 @@ def join_result( result_df, left_on=left_on, right_on=right_on, - how=how, # type: ignore + how=how, suffixes=("_src", "_res"), ).drop(columns=[left_on, right_on]) @@ -388,14 +388,11 @@ def _check_valid_settings_for_input(input_value: Any, pivot_reg: PivotRegistrati isinstance(input_value, pd.DataFrame) or ( # pylint: disable=isinstance-second-argument-not-valid-type - isinstance(input_value, pd.DataFrame) - and not isinstance(input_value, str) + isinstance(input_value, pd.DataFrame) and not isinstance(input_value, str) # pylint: enable=isinstance-second-argument-not-valid-type ) ): - raise ValueError( - f"This function does not accept inputs of {type(input_value)}" - ) + raise ValueError(f"This function does not accept inputs of {type(input_value)}") def _arg_to_dframe(arg_val, col_name: str = "param_value"): @@ -433,7 +430,7 @@ def _create_input_df(input_value, pivot_reg, parent_kwargs): # to using the function input value arg. input_column = pivot_reg.func_df_col_param_name or pivot_reg.func_input_value_arg # If input_value is already a DF, this call just returns the original DF - input_df = _arg_to_dframe(input_value, input_column) # type: ignore + input_df = _arg_to_dframe(input_value, input_column) if isinstance(input_value, pd.DataFrame): # If the original input_value is a DataFrame @@ -480,7 +477,7 @@ def _iterate_func(target_func, input_df, input_column, pivot_reg, **kwargs): results = [] # Add any static parameters to all_rows_kwargs all_rows_kwargs = kwargs.copy() - all_rows_kwargs.update((pivot_reg.func_static_params or {})) + all_rows_kwargs.update(pivot_reg.func_static_params or {}) res_key_col_name = pivot_reg.func_out_column_name or pivot_reg.func_input_value_arg for row_index, row in enumerate(input_df[[input_column]].itertuples(index=False)): diff --git a/msticpy/init/pivot_core/pivot_register_reader.py b/msticpy/init/pivot_core/pivot_register_reader.py index 94a3c14b5..fa26133e8 100644 --- a/msticpy/init/pivot_core/pivot_register_reader.py +++ b/msticpy/init/pivot_core/pivot_register_reader.py @@ -4,11 +4,13 @@ # license information. 
# -------------------------------------------------------------------------- """Reads pivot registration config files.""" + from __future__ import annotations import importlib import warnings -from typing import Any, Callable, Generator +from collections.abc import Callable, Generator +from typing import Any import yaml @@ -146,13 +148,13 @@ def add_unbound_pivot_function( def _read_reg_file(file_path: str) -> Generator[PivotRegistration, Any, None]: """Read the yaml file and return generator of PivotRegistrations.""" - with open(file_path, "r", encoding="utf-8") as f_handle: + with open(file_path, encoding="utf-8") as f_handle: # use safe_load instead load pivot_regs = yaml.safe_load(f_handle) for entry_name, settings in pivot_regs.get("pivot_providers").items(): try: - yield PivotRegistration( # type: ignore[call-arg] + yield PivotRegistration( src_config_path=file_path, src_config_entry=entry_name, **settings ) except TypeError as err: @@ -204,7 +206,8 @@ def _get_func_from_class(src_module, namespace, piv_reg): except Exception as err: # pylint: disable=broad-except warnings.warn( f"Could not create instance of class {src_class.__name__}. " - + f"Exception was {err}" + + f"Exception was {err}", + stacklevel=2, ) return None # get the function from the object diff --git a/msticpy/init/pivot_init/pivot_data_queries.py b/msticpy/init/pivot_init/pivot_data_queries.py index 5e15537e6..7eb8b8dfc 100644 --- a/msticpy/init/pivot_init/pivot_data_queries.py +++ b/msticpy/init/pivot_init/pivot_data_queries.py @@ -4,11 +4,15 @@ # license information. # -------------------------------------------------------------------------- """Pivot query functions class.""" + +from __future__ import annotations + import itertools import warnings from collections import abc, defaultdict, namedtuple +from collections.abc import Callable, Iterable from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any import pandas as pd @@ -19,14 +23,16 @@ from ..pivot_core.pivot_container import PivotContainer from ..pivot_core.pivot_register import get_join_params, join_result +if TYPE_CHECKING: + from ...data.core.data_providers import QueryProvider + from ...data.core.query_source import QuerySource + __version__ = VERSION __author__ = "Ian Hellen" ParamAttrs = namedtuple("ParamAttrs", "type, query, family, required") -QueryParams = namedtuple( - "QueryParams", "all, required, full_required, param_attrs, table" -) +QueryParams = namedtuple("QueryParams", "all, required, full_required, param_attrs, table") PivQuerySettings = namedtuple( "PivQuerySettings", "short_name, direct_func_entities, assigned_entities" ) @@ -96,8 +102,8 @@ class PivotQueryFunctions: def __init__( self, - query_provider: "QueryProvider", # type: ignore # noqa: F821 - ignore_reqd: List[str] = None, + query_provider: QueryProvider, + ignore_reqd: list[str] | None = None, ): # sourcery skip: remove-unnecessary-cast """ Instantiate PivotQueryFunctions class. 
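pivot_data_queries above moves the `QueryProvider`/`QuerySource` imports under `TYPE_CHECKING`, which works because `from __future__ import annotations` postpones annotation evaluation. A self-contained sketch of the idiom (the imported name is illustrative):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime, so it
    # cannot introduce an import cycle or slow module import.
    from decimal import Decimal

def double(value: Decimal) -> Decimal:
    # The annotation stays a string until a checker resolves it.
    return value * 2
```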
@@ -113,8 +119,8 @@ def __init__( """ self.__class__.current = self self._provider = query_provider - self.param_usage: Dict[str, List[ParamAttrs]] = defaultdict(list) - self.query_params: Dict[str, QueryParams] = {} + self.param_usage: dict[str, list[ParamAttrs]] = defaultdict(list) + self.query_params: dict[str, QueryParams] = {} # specify any parameters to exclude from our list ignore_params = set(ignore_reqd) if ignore_reqd else _DEF_IGNORE_PARAM @@ -139,7 +145,7 @@ def __init__( # details of the function/query parameters self.query_params[f"{family}.{src_name}"] = QueryParams( all=list(q_source.params), - required=list((set(q_source.required_params) - ignore_params)), + required=list(set(q_source.required_params) - ignore_params), full_required=list(q_source.required_params), param_attrs={ param: ParamAttrs( @@ -154,7 +160,7 @@ def __init__( ) @property - def instance_name(self) -> Optional[str]: + def instance_name(self) -> str | None: """ Return instance name, if any for provider. @@ -167,9 +173,7 @@ def instance_name(self) -> Optional[str]: """ return self._provider.instance - def get_query_settings( - self, family: str, query: str - ) -> "QuerySource": # type: ignore # noqa: F821 + def get_query_settings(self, family: str, query: str) -> QuerySource: """ Get the QuerySource for the named `family` and `query`. @@ -227,7 +231,7 @@ def get_query_pivot_settings(self, family: str, query: str) -> PivQuerySettings: def get_queries_and_types_for_param( self, param: str - ) -> Iterable[Tuple[str, str, str, Callable[[Any], Any]]]: + ) -> Iterable[tuple[str, str, str, Callable[[Any], Any]]]: """ Get queries and parameter data types for `param`. @@ -257,7 +261,7 @@ def get_queries_and_types_for_param( def get_queries_for_param( self, param: str - ) -> Iterable[Tuple[str, str, Callable[[Any], Any]]]: + ) -> Iterable[tuple[str, str, Callable[[Any], Any]]]: """ Get the list of queries for a parameter. @@ -284,7 +288,7 @@ def get_queries_for_param( ) ] - def get_params(self, query_func_name: str) -> Optional[QueryParams]: + def get_params(self, query_func_name: str) -> QueryParams | None: """ Get the parameters for a query function. @@ -303,7 +307,7 @@ def get_params(self, query_func_name: str) -> Optional[QueryParams]: """ return self.query_params.get(query_func_name) - def get_param_attrs(self, param_name: str) -> List[ParamAttrs]: + def get_param_attrs(self, param_name: str) -> list[ParamAttrs]: """ Get the attributes for a parameter name. @@ -329,7 +333,7 @@ def get_param_attrs(self, param_name: str) -> List[ParamAttrs]: # Map of query parameter names to entities and the entity attrib # corresponding to the query parameter value -PARAM_ENTITY_MAP: Dict[str, List[Tuple[Type[entities.Entity], str]]] = { +PARAM_ENTITY_MAP: dict[str, list[tuple[type[entities.Entity], str]]] = { "account_name": [(entities.Account, "Name")], "host_name": [(entities.Host, "fqdn")], "process_name": [(entities.Process, "ProcessFilePath")], @@ -359,8 +363,8 @@ def get_param_attrs(self, param_name: str) -> List[ParamAttrs]: def add_data_queries_to_entities( - provider: "QueryProvider", # type: ignore # noqa: F821 - get_timespan: Optional[Callable[[], TimeSpan]], + provider: QueryProvider, + get_timespan: Callable[[], TimeSpan] | None, ): """ Add data queries from `provider` to entities. 
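`param_usage` above is a `defaultdict(list)` keyed by query parameter name. A toy version of how such an index is built (the family and query names are invented):

```python
from collections import defaultdict

param_usage: dict[str, list[str]] = defaultdict(list)
queries = [("Alerts", "list_alerts", "host_name"), ("Net", "flows", "host_name")]
for family, query, param in queries:
    # Missing keys spring into existence as empty lists.
    param_usage[param].append(f"{family}.{query}")
assert param_usage["host_name"] == ["Alerts.list_alerts", "Net.flows"]
```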
@@ -376,11 +380,7 @@ def add_data_queries_to_entities( """ q_funcs = PivotQueryFunctions(provider) - if ( - provider.instance - and provider.instance != "Default" - and not _use_v1_query_naming() - ): + if provider.instance and provider.instance != "Default" and not _use_v1_query_naming(): container_name = f"{provider.environment}_{provider.instance.casefold()}" else: container_name = provider.environment @@ -399,7 +399,7 @@ def add_data_queries_to_entities( def add_queries_to_entities( prov_qry_funcs: PivotQueryFunctions, container: str, - get_timespan: Optional[Callable[[], TimeSpan]], + get_timespan: Callable[[], TimeSpan] | None, ): """ Add data queries to entities. @@ -440,19 +440,16 @@ def add_queries_to_entities( if param in func_params.all and ent == entity_cls } # Build the map of param names to entity attributes - attr_map = { - param: ent_attr for param, (_, ent_attr) in param_entities.items() - } + attr_map = {param: ent_attr for param, (_, ent_attr) in param_entities.items()} # Wrap the function cls_func = _create_pivot_func( - func, func_params.param_attrs, attr_map, get_timespan # type:ignore + func, + func_params.param_attrs, + attr_map, + get_timespan, # type: ignore[arg-type] ) # add a properties dict to the function - setattr( - cls_func, - "pivot_properties", - _create_piv_properties(name, param_entities, container), - ) + cls_func.pivot_properties = _create_piv_properties(name, param_entities, container) q_piv_settings = prov_qry_funcs.get_query_pivot_settings(family, name) func_name = _format_func_name(name, family, func_params, q_piv_settings) @@ -477,7 +474,7 @@ def add_queries_to_entities( def _get_pivot_instance(): """Get the timespan access function from Pivot global instance.""" # pylint: disable=import-outside-toplevel, cyclic-import - from ..pivot import Pivot + from ..pivot import Pivot # noqa: PLC0415 return Pivot() @@ -506,8 +503,8 @@ def _create_piv_properties(name, param_entities, container): def _create_pivot_func( func: Callable[[Any], pd.DataFrame], - func_params: Dict[str, ParamAttrs], - param_attrib_map: Dict[str, str], + func_params: dict[str, ParamAttrs], + param_attrib_map: dict[str, str], get_timespan: Callable[[], TimeSpan], ): """ @@ -570,7 +567,7 @@ def wrapped_query_func(*args, **kwargs): def _create_data_func_exec( - func: Callable[[Any], pd.DataFrame], func_params: Dict[str, ParamAttrs] + func: Callable[[Any], pd.DataFrame], func_params: dict[str, ParamAttrs] ) -> Callable[[Any], pd.DataFrame]: """ Wrap func to issue single or multiple calls to query. @@ -640,7 +637,8 @@ def call_data_query(**kwargs): warnings.warn( "Cannot do an index merge on this result set. " + "Please use an explicit column join using 'left_on' " - + "and 'right_on' join columns." + + "and 'right_on' parameters.", + stacklevel=2, ) return result_df.drop(columns="src_row_index", errors="ignore") # The inputs are some mix of simple values and/or iterables. 
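Several hunks in this patch (here in `call_data_query`, earlier in `tee` and `get_join_params`) add `stacklevel=2` to `warnings.warn`. A minimal sketch of why:

```python
import warnings

def merge_results() -> None:
    # stacklevel=2 attributes the warning to merge_results' caller, so
    # the reported file and line point at user code, not this module.
    warnings.warn(
        "Cannot do an index merge on this result set.",
        UserWarning,
        stacklevel=2,
    )
```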
@@ -681,14 +679,14 @@ def _exec_query_for_df(func, func_kwargs, func_params, parent_kwargs): def _check_df_params_require_iter( - func_params: Dict[str, ParamAttrs], + func_params: dict[str, ParamAttrs], src_df: pd.DataFrame, - func_kwargs: Dict[str, Any], + func_kwargs: dict[str, Any], **kwargs, -) -> Tuple[Dict[str, Any], Dict[str, Any]]: +) -> tuple[dict[str, Any], dict[str, Any]]: """Return params that require iteration and those that don't.""" - list_params: Dict[str, Any] = {} - df_iter_params: Dict[str, Any] = {} + list_params: dict[str, Any] = {} + df_iter_params: dict[str, Any] = {} for kw_name, arg in kwargs.items(): if kw_name in _DEF_IGNORE_PARAM: continue @@ -727,7 +725,7 @@ def _exec_query_for_values(func, func_kwargs, func_params, parent_kwargs): # iteration so ignore these and run queries per row row_results = [] # zip the value lists into tuples - for row in zip(*(var_iter_params.values())): + for row in zip(*(var_iter_params.values()), strict=False): # build a single-line dict of {param1: row_value1...} col_param_dict = {param: row[idx] for idx, param in enumerate(var_iter_params)} row_results.append(func(**simple_params, **col_param_dict, **func_kwargs)) @@ -735,11 +733,11 @@ def _exec_query_for_values(func, func_kwargs, func_params, parent_kwargs): def _check_var_params_require_iter( - func_params: Dict[str, ParamAttrs], func_kwargs: Dict[str, Any], **kwargs -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + func_params: dict[str, ParamAttrs], func_kwargs: dict[str, Any], **kwargs +) -> tuple[dict[str, Any], dict[str, Any]]: """Return params that require iteration and don't.""" - simple_params: Dict[str, Any] = {} - var_iter_params: Dict[str, Any] = {} + simple_params: dict[str, Any] = {} + var_iter_params: dict[str, Any] = {} for kw_name, arg in kwargs.items(): if kw_name in _DEF_IGNORE_PARAM: continue diff --git a/msticpy/init/pivot_init/pivot_ti_provider.py b/msticpy/init/pivot_init/pivot_ti_provider.py index 3cbad2d79..90dfd3013 100644 --- a/msticpy/init/pivot_init/pivot_ti_provider.py +++ b/msticpy/init/pivot_init/pivot_ti_provider.py @@ -4,8 +4,9 @@ # license information. 
# -------------------------------------------------------------------------- """Pivot TI Provider helper functions.""" + from collections import defaultdict -from typing import Callable, Dict, Set, Tuple, Type +from collections.abc import Callable import pandas as pd @@ -21,7 +22,7 @@ IOC_TYPES = {"ipv4", "ipv6", "dns", "file_hash", "url"} -TI_ENTITY_ATTRIBS: Dict[str, Tuple[Type, str]] = { +TI_ENTITY_ATTRIBS: dict[str, tuple[type, str]] = { "ipv4": (entities.IpAddress, "Address"), "ipv6": (entities.IpAddress, "Address"), "ip": (entities.IpAddress, "Address"), @@ -68,7 +69,7 @@ def add_ioc_queries_to_entities(ti_lookup: TILookup, container: str = "ti", **kw def create_ti_pivot_funcs(ti_lookup: TILookup): """Create the TI Pivot functions.""" ioc_type_supp = _get_supported_ioc_types(ti_lookup) - ioc_queries: Dict[str, Dict[str, Callable[..., pd.DataFrame]]] = defaultdict(dict) + ioc_queries: dict[str, dict[str, Callable[..., pd.DataFrame]]] = defaultdict(dict) # Add functions for ioc types that will call all providers # Non-IP types @@ -86,7 +87,7 @@ def register_ti_pivot_providers(ti_lookup: TILookup, pivot: "Pivot"): # type: i ti_prov.register_pivots(PivotRegistration, pivot) -def _get_supported_ioc_types(ti_lookup: TILookup) -> Dict[str, Set[str]]: +def _get_supported_ioc_types(ti_lookup: TILookup) -> dict[str, set[str]]: return { ti_prov_name: set(ti_prov.supported_types) & IOC_TYPES for ti_prov_name, ti_prov in ti_lookup.loaded_providers.items() @@ -95,7 +96,7 @@ def _get_supported_ioc_types(ti_lookup: TILookup) -> Dict[str, Set[str]]: def _create_lookup_func( ti_lookup: TILookup, ioc, ioc_name, providers -) -> Tuple[str, str, Callable[..., pd.DataFrame]]: +) -> tuple[str, str, Callable[..., pd.DataFrame]]: suffix = f"_{ioc_name}" short_func_name = f"lookup{suffix}" func_name = f"{short_func_name}_{ioc_name}" @@ -104,7 +105,7 @@ def _create_lookup_func( # use IoC name if ioc_type is None entity_cls, entity_attr = TI_ENTITY_ATTRIBS[ioc or ioc_name] - pivot_reg = PivotRegistration( # type: ignore[call-arg] + pivot_reg = PivotRegistration( src_func_name=ti_lookup.lookup_iocs.__name__, input_type="dataframe", entity_map={entity_cls.__name__: entity_attr}, diff --git a/msticpy/init/pivot_init/vt_pivot.py b/msticpy/init/pivot_init/vt_pivot.py index dbafd7f84..098a6c365 100644 --- a/msticpy/init/pivot_init/vt_pivot.py +++ b/msticpy/init/pivot_init/vt_pivot.py @@ -7,7 +7,6 @@ from enum import Flag, auto from functools import partial -from typing import Dict, Optional, Tuple, Union from ..._version import VERSION from ...common.provider_settings import get_provider_settings @@ -84,7 +83,7 @@ class VTAPIScope(Flag): "referrer_urls": VTAPIScope.PRIVATE, } -PIVOT_ENTITY_CATS: Dict[str, Tuple[str, Dict[str, VTAPIScope]]] = { +PIVOT_ENTITY_CATS: dict[str, tuple[str, dict[str, VTAPIScope]]] = { "File": ("file", FILE_RELATIONSHIPS), "IpAddress": ("ip_address", IP_RELATIONSHIPS), "Dns": ("domain", DOMAIN_RELATIONSHIPS), @@ -126,7 +125,7 @@ def init(): # pylint: disable=no-member -def add_pivot_functions(api_scope: Optional[str] = None): +def add_pivot_functions(api_scope: str | None = None): """ Add VT3 relationship functions as pivot functions. 
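vt_pivot filters relationships by VT API scope with `enum.Flag` bitwise tests (`rel_scope & scope`). A condensed sketch of that mechanism; the relationship names are partly invented:

```python
from enum import Flag, auto

class APIScope(Flag):
    PUBLIC = auto()
    PRIVATE = auto()
    ALL = PUBLIC | PRIVATE

RELATIONSHIPS = {"resolutions": APIScope.PUBLIC, "referrer_urls": APIScope.PRIVATE}
scope = APIScope.PUBLIC
# Bitwise AND is truthy when the relationship's scope overlaps the
# requested scope, mirroring the generator expression in the hunk.
available = [rel for rel, rel_scope in RELATIONSHIPS.items() if rel_scope & scope]
assert available == ["resolutions"]
```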
@@ -150,7 +149,7 @@ def add_pivot_functions(api_scope: Optional[str] = None): # pylint: disable=no-member -def _create_pivots(api_scope: Union[str, VTAPIScope, None]): +def _create_pivots(api_scope: str | VTAPIScope | None): if api_scope is None: scope = _get_vt_api_scope() elif isinstance(api_scope, str): @@ -175,9 +174,7 @@ def _create_pivots(api_scope: Union[str, VTAPIScope, None]): scope = VTAPIScope.ALL ent_funcs = {} for entity, (vt_type, category) in PIVOT_ENTITY_CATS.items(): - ent_relations = ( - rel for rel, rel_scope in category.items() if rel_scope & scope - ) + ent_relations = (rel for rel, rel_scope in category.items() if rel_scope & scope) func_dict = {} for relationship in ent_relations: f_part = partial( diff --git a/msticpy/init/user_config.py b/msticpy/init/user_config.py index 28ac6f0c1..dc1067eb0 100644 --- a/msticpy/init/user_config.py +++ b/msticpy/init/user_config.py @@ -44,10 +44,11 @@ is to connect after loading. You can skip the connect step by add connect: False to the entry. """ + import textwrap from contextlib import redirect_stdout from io import StringIO -from typing import Any, Dict, Tuple +from typing import Any from .._version import VERSION from ..common.pkg_config import get_config @@ -58,7 +59,7 @@ __author__ = "Ian Hellen" -def load_user_defaults() -> Dict[str, object]: +def load_user_defaults() -> dict[str, object]: """ Load providers from user defaults in msticpyconfig.yaml. @@ -132,9 +133,7 @@ def _load_components(user_defaults, namespace=None): return prov_dict -def _load_az_workspaces( - prov_name: str, azsent_prov_entry: Dict[str, Any] -) -> Dict[str, Any]: +def _load_az_workspaces(prov_name: str, azsent_prov_entry: dict[str, Any]) -> dict[str, Any]: az_provs = {} for ws_name, ws_settings in azsent_prov_entry.items(): if not ws_settings: @@ -157,7 +156,7 @@ def _load_az_workspaces( return az_provs -def _load_provider(prov_name: str, qry_prov_entry: Dict[str, Any]) -> Tuple[str, Any]: +def _load_provider(prov_name: str, qry_prov_entry: dict[str, Any]) -> tuple[str, Any]: alias = qry_prov_entry.get("alias", prov_name) connect = qry_prov_entry.get("connect", True) obj_name = f"qry_{alias.lower()}" @@ -172,22 +171,20 @@ def _load_provider(prov_name: str, qry_prov_entry: Dict[str, Any]) -> Tuple[str, # pylint: disable=import-outside-toplevel def _load_ti_lookup(comp_settings=None, **kwargs): del comp_settings, kwargs - from ..context.tilookup import TILookup + from ..context.tilookup import TILookup # noqa: PLC0415 return "ti_lookup", TILookup() def _load_geoip_lookup(comp_settings=None, **kwargs): del kwargs - provider = ( - comp_settings.get("provider") if isinstance(comp_settings, dict) else None - ) + provider = comp_settings.get("provider") if isinstance(comp_settings, dict) else None if provider == "GeoLiteLookup": - from ..context.geoip import GeoLiteLookup + from ..context.geoip import GeoLiteLookup # noqa: PLC0415 return "geoip", GeoLiteLookup() if provider == "IpStackLookup": - from ..context.geoip import IPStackLookup + from ..context.geoip import IPStackLookup # noqa: PLC0415 return "geoip", IPStackLookup() return None, None @@ -196,9 +193,7 @@ def _load_geoip_lookup(comp_settings=None, **kwargs): def _load_notebooklets(comp_settings=None, **kwargs): nbinit_params = {} if comp_settings and isinstance(comp_settings, dict): - prov_name, prov_args = next( - iter(comp_settings.get("query_provider", {}).items()) - ) + prov_name, prov_args = next(iter(comp_settings.get("query_provider", {}).items())) if prov_name: nbinit_params = 
{"query_provider": prov_name} if prov_args: @@ -212,7 +207,7 @@ def _load_notebooklets(comp_settings=None, **kwargs): providers = [f"+{prov}" for prov in providers] nbinit_params.update({"providers": providers, "namespace": namespace}) try: - import msticnb + import msticnb # noqa: PLC0415 msticnb.init(**nbinit_params) return "nb", msticnb @@ -225,7 +220,7 @@ def _load_notebooklets(comp_settings=None, **kwargs): def _load_azure_data(comp_settings=None, **kwargs): del kwargs - from ..context.azure.azure_data import AzureData + from ..context.azure.azure_data import AzureData # noqa: PLC0415 az_data = AzureData() connect = comp_settings.pop("connect", True) @@ -238,7 +233,7 @@ def _load_azure_data(comp_settings=None, **kwargs): def _load_azsent_api(comp_settings=None, **kwargs): del kwargs - from ..context.azure.sentinel_core import MicrosoftSentinel + from ..context.azure.sentinel_core import MicrosoftSentinel # noqa: PLC0415 res_id = comp_settings.pop("res_id", None) if res_id: diff --git a/msticpy/lazy_importer.py b/msticpy/lazy_importer.py index 1c0452f03..4f9632f89 100644 --- a/msticpy/lazy_importer.py +++ b/msticpy/lazy_importer.py @@ -4,9 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Lazy importer for msticpy sub-packages.""" + import importlib +from collections.abc import Callable, Iterable from types import ModuleType -from typing import Callable, Iterable, Tuple from ._version import VERSION @@ -16,7 +17,7 @@ def lazy_import( importer_name: str, import_list: Iterable[str] -) -> Tuple[ModuleType, Callable, Callable]: +) -> tuple[ModuleType, Callable, Callable]: """ Return the importing module and a callable for lazy importing. @@ -71,7 +72,8 @@ def __getattr__(name: str): # appropriate for direct imports. try: imported = importlib.import_module( - mod_name, module.__spec__.parent # type: ignore + mod_name, + module.__spec__.parent, # type: ignore ) except ImportError as imp_err: message = f"cannot import name '{mod_name}' from '{importer_name}'" diff --git a/msticpy/nbtools/__init__.py b/msticpy/nbtools/__init__.py index 327f240a9..06231acc5 100644 --- a/msticpy/nbtools/__init__.py +++ b/msticpy/nbtools/__init__.py @@ -3,89 +3,4 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -""" -nbtools module - Notebook Security Tools. - -This is a collection of modules with functionality (mostly) specific to -notebooks. It also houses some visualization modules that will migrate -to the vis sub-package. - -- nbinit - notebook initialization -- azure_ml_tools - configuration and helpers for AML workspaces -- nbwidgets - ipywidgets-based UI components for infosec notebooks -- nbdisplay - miscellaneous display functions TBM to vis - -""" -# flake8: noqa: F403 -# pylint: disable=W0401 -# import importlib -# from typing import Any - -# from .. 
import nbwidgets -from .._version import VERSION -from ..lazy_importer import lazy_import - -# from ..common import utility as utils -# from ..common.wsconfig import WorkspaceConfig -# from ..vis import nbdisplay -# from .security_alert import SecurityAlert - -# try: -# from IPython import get_ipython - -# from ..init import nbmagics -# except ImportError as err: -# pass - -# pylint: enable=W0401 - -__version__ = VERSION - -# _DEFAULT_IMPORTS = {"nbinit": "msticpy.init.nbinit"} - -_LAZY_IMPORTS = { - "msticpy.init.nbinit", - "msticpy.common.utility as utils", - "msticpy.common.wsconfig.WorkspaceConfig", - "msticpy.nbtools.security_alert.SecurityAlert", - "msticpy.nbwidgets", - "msticpy.vis.nbdisplay", -} - -# def __getattr__(attrib: str) -> Any: -# """ -# Import and return an attribute of nbtools. - -# Parameters -# ---------- -# attrib : str -# The attribute name - -# Returns -# ------- -# Any -# The attribute value. - -# Raises -# ------ -# AttributeError -# No attribute found. - -# """ -# if attrib in _DEFAULT_IMPORTS: -# module = importlib.import_module(_DEFAULT_IMPORTS[attrib]) -# return module -# raise AttributeError(f"msticpy has no attribute {attrib}") - -# from .vtlookupv3 import VT3_AVAILABLE - -# vtlookupv3: Any -# if VT3_AVAILABLE: -# from .vtlookupv3 import vtlookupv3 -# else: -# # vtlookup3 will not load if vt package not installed -# vtlookupv3 = ImportPlaceholder( # type: ignore -# "vtlookupv3", ["vt-py", "vt-graph-api", "nest_asyncio"] -# ) - -module, __getattr__, __dir__ = lazy_import(__name__, _LAZY_IMPORTS) +"""NBTools package.""" diff --git a/msticpy/nbtools/data_viewer.py b/msticpy/nbtools/data_viewer.py deleted file mode 100644 index 86c0ca787..000000000 --- a/msticpy/nbtools/data_viewer.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module data_viewer.py has moved. - -See :py:mod:`msticpy.vis.data_viewer` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.data_viewer import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.data_viewer\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/entityschema.py b/msticpy/nbtools/entityschema.py deleted file mode 100644 index e9cfa5ee1..000000000 --- a/msticpy/nbtools/entityschema.py +++ /dev/null @@ -1,14 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
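The `lazy_import` helper used in the `nbtools/__init__.py` code removed above is an application of PEP 562, which lets a module define `__getattr__` and `__dir__` so sub-modules are imported only on first attribute access. A minimal sketch of the underlying pattern, with hypothetical package names (`mypkg.*` is a stand-in, and this simplified version handles modules only, not attributes):

import importlib

_LAZY_MODULES = {"widgets": "mypkg.widgets", "vis": "mypkg.vis"}

def __getattr__(name: str):
    """Import a sub-module on first access and cache it in the namespace."""
    if name in _LAZY_MODULES:
        module = importlib.import_module(_LAZY_MODULES[name])
        globals()[name] = module  # cache so __getattr__ isn't called again
        return module
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

def __dir__():
    """Advertise lazy names alongside the real module globals."""
    return sorted(list(globals()) + list(_LAZY_MODULES))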
-# -------------------------------------------------------------------------- -"""Placeholder for old entity_schema module.""" - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Ian Hellen" - -# pylint: disable=wildcard-import, unused-wildcard-import -from ..datamodel.entities import * # noqa: F403, F401 diff --git a/msticpy/nbtools/foliummap.py b/msticpy/nbtools/foliummap.py deleted file mode 100644 index 0eec1f00a..000000000 --- a/msticpy/nbtools/foliummap.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module foliummap.py has moved. - -See :py:mod:`msticpy.vis.foliummap` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.foliummap import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.foliummap\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/morph_charts.py b/msticpy/nbtools/morph_charts.py deleted file mode 100644 index 6a82ac20e..000000000 --- a/msticpy/nbtools/morph_charts.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module morph_charts.py has moved. - -See :py:mod:`msticpy.vis.morph_charts` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.morph_charts import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.morph_charts\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/nbdisplay.py b/msticpy/nbtools/nbdisplay.py deleted file mode 100644 index fcaf4159e..000000000 --- a/msticpy/nbtools/nbdisplay.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module nbdisplay.py has moved. - -See :py:mod:`msticpy.vis.nbdisplay` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.nbdisplay import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.nbdisplay\n" - "Please change your import to reflect this new location." 
- "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/nbwidgets.py b/msticpy/nbtools/nbwidgets.py deleted file mode 100644 index 2d1f73f62..000000000 --- a/msticpy/nbtools/nbwidgets.py +++ /dev/null @@ -1,27 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module nbtools.nbwidgets has moved. - -See :py:mod:`msticpy.nbwidgets` -""" -import warnings - -from .._version import VERSION - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..nbwidgets import * # noqa: F401 - -__version__ = VERSION -__author__ = "Ian Hellen" - -WARN_MSSG = ( - "This module has moved to msticpy.nbwidgets\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/observationlist.py b/msticpy/nbtools/observationlist.py deleted file mode 100644 index 1e6bcad6c..000000000 --- a/msticpy/nbtools/observationlist.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module observationlist.py has moved. - -See :py:mod:`msticpy.analysis.observationlist` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..analysis.observationlist import * - -WARN_MSSG = ( - "This module has moved to msticpy.analysis.observationlist\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/process_tree.py b/msticpy/nbtools/process_tree.py deleted file mode 100644 index ee2fd5169..000000000 --- a/msticpy/nbtools/process_tree.py +++ /dev/null @@ -1,42 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module process_tree.py has moved. 
- -See :py:mod:`msticpy.vis.process_tree` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.process_tree_utils import get_ancestors # noqa F401 -from ..transform.process_tree_utils import ( - get_children, - get_descendents, - get_parent, - get_process, - get_process_key, - get_root, - get_root_tree, - get_roots, - get_siblings, - get_summary_info, - get_tree_depth, -) -from ..vis.process_tree import * -from ..vis.process_tree import build_process_tree, infer_schema - -WARN_MSSG = ( - "This module has moved to msticpy.vis.process_tree\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/security_alert.py b/msticpy/nbtools/security_alert.py index e90cad9cd..fe1c64211 100644 --- a/msticpy/nbtools/security_alert.py +++ b/msticpy/nbtools/security_alert.py @@ -4,9 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Module for SecurityAlert class.""" + import json from json import JSONDecodeError -from typing import Any, Dict, List +from typing import Any import pandas as pd from deprecated.sphinx import deprecated @@ -35,9 +36,9 @@ def __init__(self, src_row: pd.Series = None): super().__init__(src_row=src_row) # add entities to dictionary to remove dups - self._src_entities: Dict[int, Entity] = {} + self._src_entities: dict[int, Entity] = {} - self.extended_properties: Dict[str, Any] = {} + self.extended_properties: dict[str, Any] = {} if src_row is not None: if "Entities" in src_row: self._extract_entities(src_row) @@ -47,20 +48,18 @@ def __init__(self, src_row: pd.Series = None): self.extended_properties = src_row.ExtendedProperties elif isinstance(src_row.ExtendedProperties, str): try: - self.extended_properties = json.loads( - src_row.ExtendedProperties - ) + self.extended_properties = json.loads(src_row.ExtendedProperties) except JSONDecodeError: pass self._find_os_family() @property - def entities(self) -> List[Entity]: + def entities(self) -> list[Entity]: """Return a list of the Security Alert entities.""" return list(self._src_entities.values()) @property - def query_params(self) -> Dict[str, Any]: + def query_params(self) -> dict[str, Any]: """ Query parameters derived from alert. 
@@ -71,10 +70,7 @@ def query_params(self) -> Dict[str, Any]: """ params_dict = super().query_params - if ( - "system_alert_id" not in params_dict - or params_dict["system_alert_id"] is None - ): + if "system_alert_id" not in params_dict or params_dict["system_alert_id"] is None: params_dict["system_alert_id"] = self._ids["SystemAlertId"] return params_dict @@ -106,8 +102,7 @@ def __str__(self): if self.extended_properties: str_rep = [ - f"ExtProp: {prop}: {val}" - for prop, val in self.extended_properties.items() + f"ExtProp: {prop}: {val}" for prop, val in self.extended_properties.items() ] alert_props = alert_props + "\n" + "\n".join(str_rep) @@ -144,8 +139,7 @@ def _resolve_entity_refs(self): ref_props_multi = { name: prop for name, prop in entity.properties.items() - if isinstance(prop, list) - and any(elem for elem in prop if "$ref" in elem) + if isinstance(prop, list) and any(elem for elem in prop if "$ref" in elem) } for prop_name, prop_val in ref_props_multi.items(): for idx, elem in enumerate(prop_val): @@ -159,7 +153,7 @@ def _resolve_entity_refs(self): edge_attrs={"name": prop_name}, ) - def _extract_entities(self, src_row): # noqa: MC0001 + def _extract_entities(self, src_row): input_entities = [] if isinstance(src_row.ExtendedProperties, str): diff --git a/msticpy/nbtools/security_alert_graph.py b/msticpy/nbtools/security_alert_graph.py index c55089e36..e4d5511a8 100644 --- a/msticpy/nbtools/security_alert_graph.py +++ b/msticpy/nbtools/security_alert_graph.py @@ -8,6 +8,7 @@ Creates an entity graph for the alert. """ + import networkx as nx import pandas as pd @@ -109,9 +110,7 @@ def add_related_alerts(related_alerts: pd.DataFrame, alertgraph: nx.Graph) -> nx related_alerts.apply(lambda x: _add_alert_node(related_alerts_graph, x), axis=1) if alert_host_node: related_alerts.apply( - lambda x: _add_related_alert_edges( - related_alerts_graph, x, alert_host_node - ), + lambda x: _add_related_alert_edges(related_alerts_graph, x, alert_host_node), axis=1, ) return related_alerts_graph @@ -207,8 +206,8 @@ def _get_name_and_description(entity, os_family="Windows"): e_name, e_description = _get_account_name_desc(entity) elif entity["Type"] == "host-logon-session": e_name = "host-logon-session" - e_description = f'Logon session {entity["SessionId"]}\n' - e_description = e_description + f'(Start time: {entity["StartTimeUtc"]}' + e_description = f"Logon session {entity['SessionId']}\n" + e_description = e_description + f"(Start time: {entity['StartTimeUtc']}" elif entity["Type"] == "process": e_name, e_description = _get_process_name_desc(entity) elif entity["Type"] == "file": @@ -275,11 +274,7 @@ def _get_file_name_desc(entity): def _get_process_name_desc(entity): if "ProcessFilePath" in entity: path = entity.ProcessFilePath - elif ( - "ImageFile" in entity - and entity["ImageFile"] - and "FullPath" in entity["ImageFile"] - ): + elif "ImageFile" in entity and entity["ImageFile"] and "FullPath" in entity["ImageFile"]: path = entity["ImageFile"]["FullPath"] else: path = "unknown" diff --git a/msticpy/nbtools/security_base.py b/msticpy/nbtools/security_base.py index b6c0a293a..6dca19824 100644 --- a/msticpy/nbtools/security_base.py +++ b/msticpy/nbtools/security_base.py @@ -4,13 +4,14 @@ # license information. 
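The quote changes in the `_get_name_and_description` hunk above are a formatter artifact worth a gloss: before Python 3.12 (PEP 701), an f-string could not reuse its own quote character inside a replacement field, so with a 3.10 floor the formatter flips to double quotes outside and single quotes inside rather than nesting. A standalone illustration (the entity values are invented):

entity = {"SessionId": "0x3e7", "StartTimeUtc": "2024-01-01T00:00:00Z"}

# Formatter output style: double quotes outside, single quotes inside.
desc = f"Logon session {entity['SessionId']}\n"
desc += f"(Start time: {entity['StartTimeUtc']}"

# Reusing the outer quote inside the braces, as below, is a SyntaxError
# before Python 3.12 (PEP 701), which is why the rewrite avoids it:
# desc = f"Logon session {entity["SessionId"]}\n"
print(desc)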
# -------------------------------------------------------------------------- """Module for SecurityAlert class.""" + from __future__ import annotations import html import re from collections import Counter from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any import pandas as pd from deprecated.sphinx import deprecated @@ -23,7 +24,7 @@ __version__ = VERSION __author__ = "Ian Hellen" -_ID_PROPERTIES: List[str] = [ +_ID_PROPERTIES: list[str] = [ "AzSubscriptionId", "AzResourceId", "WorkspaceId", @@ -55,11 +56,11 @@ def __init__(self, src_row: pd.Series = None): self._source_data: pd.Series = ( src_row if src_row is not None else pd.Series(dtype="object") ) - self._custom_query_params: Dict[str, Any] = {} - self._entities: List[Entity] = [] + self._custom_query_params: dict[str, Any] = {} + self._entities: list[Entity] = [] # Extract and cache alert ID properties - self._ids: Dict[str, str] = {} + self._ids: dict[str, str] = {} if self._source_data is not None: for id_property in _ID_PROPERTIES: if id_property in self._source_data: @@ -104,9 +105,7 @@ def __str__(self): def __repr__(self) -> str: """Return repr of item.""" if self.properties: - params = ", ".join( - [f"{name}={val}" for name, val in self.properties.items()] - ) + params = ", ".join([f"{name}={val}" for name, val in self.properties.items()]) if len(params) > 80: params = params[:80] + "..." return f"{self.__class__.__name__}({params})" @@ -134,7 +133,7 @@ def _repr_html_(self) -> str: # Properties @property - def entities(self) -> List[Entity]: + def entities(self) -> list[Entity]: """ Return a list of the Alert or Event entities. @@ -147,7 +146,7 @@ def entities(self) -> List[Entity]: return self._entities @property - def properties(self) -> Dict[str, Any]: + def properties(self) -> dict[str, Any]: """ Return a dictionary of the Alert or Event properties. @@ -165,7 +164,7 @@ def hostname(self) -> str | None: return self.primary_host.HostName if self.primary_host is not None else None @property - def computer(self) -> Optional[str]: + def computer(self) -> str | None: """ Return the Computer name of the host associated with the alert. @@ -174,7 +173,7 @@ def computer(self) -> Optional[str]: return self.primary_host.computer if self.primary_host is not None else None @property - def ids(self) -> Dict[str, str]: + def ids(self) -> dict[str, str]: """Return a collection of Identity properties for the alert.""" return self._ids @@ -205,7 +204,7 @@ def is_in_azure_sub(self) -> bool: return "AzSubscriptionId" in self._ids and "AzResourceId" in self._ids @property - def primary_host(self) -> Optional[Union[Host, Entity]]: + def primary_host(self) -> Host | Entity | None: """ Return the primary host entity (if any) associated with this object. @@ -221,7 +220,7 @@ def primary_host(self) -> Optional[Union[Host, Entity]]: return None @property - def primary_process(self) -> Optional[Union[Process, Entity]]: + def primary_process(self) -> Process | Entity | None: """ Return the primary process entity (if any) associated with this object. @@ -248,7 +247,7 @@ def primary_process(self) -> Optional[Union[Process, Entity]]: return procs_with_parent[0] if procs_with_parent else procs[0] @property - def primary_account(self) -> Optional[Union[Process, Entity]]: + def primary_account(self) -> Process | Entity | None: """ Return the primary account entity (if any) associated with this object. 
@@ -262,7 +261,7 @@ def primary_account(self) -> Optional[Union[Process, Entity]]: return accts[0] if accts else None @property - def query_params(self) -> Dict[str, Any]: + def query_params(self) -> dict[str, Any]: """ Query parameters derived from alert. @@ -327,7 +326,7 @@ def origin_time(self) -> datetime: """Return the datetime of event.""" return self.TimeGenerated - def get_logon_id(self, account: Account = None) -> Optional[Union[str, int]]: + def get_logon_id(self, account: Account = None) -> str | int | None: """ Get the logon Id for the alert or the account, if supplied. @@ -346,9 +345,7 @@ def get_logon_id(self, account: Account = None) -> Optional[Union[str, int]]: """ for session in [ - e - for e in self.entities - if e["Type"] in ["host-logon-session", "hostlogonsession"] + e for e in self.entities if e["Type"] in ["host-logon-session", "hostlogonsession"] ]: if account is None or session["Account"] == account: return session["SessionId"] @@ -366,10 +363,7 @@ def subscription_filter(self, operator="=="): if self.is_in_log_analytics: return "true" if self.is_in_azure_sub: - return ( - f"AzureResourceSubscriptionId {operator} " - f"'{self._ids['AzSubscriptionId']}'" - ) + return f"AzureResourceSubscriptionId {operator} '{self._ids['AzSubscriptionId']}'" if self.is_in_workspace: return f"WorkspaceId {operator} '{self._ids['WorkspaceId']}'" @@ -403,7 +397,7 @@ def host_filter(self, operator="=="): return f"AgentId {operator} '{self._ids['AgentId']}'" return None - def get_entities_of_type(self, entity_type: str) -> List[Entity]: + def get_entities_of_type(self, entity_type: str) -> list[Entity]: """ Return entity collection for a give entity type. @@ -469,9 +463,7 @@ def to_html(self, show_entities: bool = False) -> str: if show_entities and self.entities: entity_title = "

<br/><h3>Entities:</h3><br/>" - entity_html = "<br/>".join( - [self._format_entity(ent) for ent in self.entities] - ) + entity_html = "<br/>
".join([self._format_entity(ent) for ent in self.entities]) html_doc = html_doc + entity_title + entity_html else: e_counts = Counter([ent["Type"] for ent in self.entities]) @@ -510,9 +502,7 @@ def _find_os_family(self): break else: for proc in [ - e - for e in self.entities - if e["Type"] == "process" and "ImageFile" in e + e for e in self.entities if e["Type"] == "process" and "ImageFile" in e ]: file = proc["ImageFile"] if "Directory" in file and "/" in file["Directory"]: @@ -521,7 +511,7 @@ def _find_os_family(self): break @staticmethod - def _get_subscription_from_resource(resource_id) -> Optional[str]: + def _get_subscription_from_resource(resource_id) -> str | None: """Extract subscription Id from resource string.""" sub_regex = r"^/subscriptions/([^/]+)/" sub_ids = re.findall(sub_regex, resource_id, re.RegexFlag.I) diff --git a/msticpy/nbtools/security_event.py b/msticpy/nbtools/security_event.py index e9478498f..9a9f17fba 100644 --- a/msticpy/nbtools/security_event.py +++ b/msticpy/nbtools/security_event.py @@ -4,7 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Module for SecurityEvent class.""" -from typing import Any, Dict, List + +from typing import Any import pandas as pd from deprecated.sphinx import deprecated @@ -36,7 +37,7 @@ def __init__(self, src_row: pd.Series = None): :param src_row: Pandas series containing single security event """ - self._source_data = src_row # type: ignore + self._source_data = src_row super().__init__(src_row=src_row) @@ -45,7 +46,7 @@ def __init__(self, src_row: pd.Series = None): # Properties @property - def entities(self) -> List[Entity]: + def entities(self) -> list[Entity]: """ Return the list of entities extracted from the event. @@ -58,7 +59,7 @@ def entities(self) -> List[Entity]: return list(self._entities) @property - def query_params(self) -> Dict[str, Any]: + def query_params(self) -> dict[str, Any]: """ Query parameters derived from alert. diff --git a/msticpy/nbtools/ti_browser.py b/msticpy/nbtools/ti_browser.py deleted file mode 100644 index 56a4f74bb..000000000 --- a/msticpy/nbtools/ti_browser.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module ti_browser.py has moved. - -See :py:mod:`msticpy.vis.ti_browser` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.ti_browser import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.ti_browser\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/timeline.py b/msticpy/nbtools/timeline.py deleted file mode 100644 index 5c5938fe1..000000000 --- a/msticpy/nbtools/timeline.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -""" -Deprecated - module timeline.py has moved. - -See :py:mod:`msticpy.vis.timeline` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.timeline import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.timeline\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/timeline_duration.py b/msticpy/nbtools/timeline_duration.py deleted file mode 100644 index 1c0169086..000000000 --- a/msticpy/nbtools/timeline_duration.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module timeline_duration.py has moved. - -See :py:mod:`msticpy.vis.timeline_duration` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.timeline_duration import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.timeline_duration\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/timeline_pd_accessor.py b/msticpy/nbtools/timeline_pd_accessor.py deleted file mode 100644 index 3e0909f2b..000000000 --- a/msticpy/nbtools/timeline_pd_accessor.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module timeline_pd_accessor.py has moved. - -See :py:mod:`msticpy.vis.timeline_pd_accessor` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.timeline_pd_accessor import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.timeline_pd_accessor\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/timeseries.py b/msticpy/nbtools/timeseries.py deleted file mode 100644 index b5a73c4aa..000000000 --- a/msticpy/nbtools/timeseries.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module timeseries.py has moved. 
- -See :py:mod:`msticpy.vis.timeseries` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..vis.timeseries import * - -WARN_MSSG = ( - "This module has moved to msticpy.vis.timeseries\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/utility.py b/msticpy/nbtools/utility.py deleted file mode 100644 index c22f7e14f..000000000 --- a/msticpy/nbtools/utility.py +++ /dev/null @@ -1,19 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""Deprecated path for common.utility.py.""" -import warnings - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import -from ..common.utility import * -from ..common.utility import md, md_warn - -WARN_MSSG = ( - "This module has moved to msticpy.common.utility\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbtools/wsconfig.py b/msticpy/nbtools/wsconfig.py deleted file mode 100644 index 78f31bd17..000000000 --- a/msticpy/nbtools/wsconfig.py +++ /dev/null @@ -1,18 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""Deprecated path for common.wsconfig.py.""" -import warnings - -# flake8: noqa: F401 -# pylint: disable=unused-import -from ..common.wsconfig import WorkspaceConfig - -WARN_MSSG = ( - "This module has moved to msticpy.common.wsconfig\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/nbwidgets/core.py b/msticpy/nbwidgets/core.py index 7b2c66146..d93658f3b 100644 --- a/msticpy/nbwidgets/core.py +++ b/msticpy/nbwidgets/core.py @@ -4,9 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" + +from __future__ import annotations + from abc import ABC from enum import IntEnum -from typing import Any, Dict, List, Optional, ClassVar +from typing import Any, ClassVar from weakref import WeakValueDictionary from IPython.display import display @@ -21,7 +24,7 @@ # pylint: disable=too-few-public-methods -class RegisteredWidget(ABC): +class RegisteredWidget(ABC): # noqa: B024 """ Register widget in the widget registry. @@ -34,21 +37,21 @@ class RegisteredWidget(ABC): the same cell after entering values. 
""" - ALLOWED_KWARGS: ClassVar[List[str]] = [ + ALLOWED_KWARGS: ClassVar[list[str]] = [ "id_vals", "val_attrs", "nb_params", "name_space", "register", ] - _NB_PARAMS: ClassVar[Dict[str, str]] = {} + _NB_PARAMS: ClassVar[dict[str, str]] = {} def __init__( self, - id_vals: Optional[List[Any]] = None, - val_attrs: Optional[List[str]] = None, - nb_params: Optional[Dict[str, str]] = None, - name_space: Dict[str, Any] = globals(), + id_vals: list[Any] | None = None, + val_attrs: list[str] | None = None, + nb_params: dict[str, str] | None = None, + name_space: dict[str, Any] = globals(), # noqa: B008 register: bool = True, **kwargs, ): @@ -99,9 +102,7 @@ def __init__( # one that was recovered from the widget registry # set it from the nb_param value wgt_internal_name = self._NB_PARAMS.get(attr, attr) - if nb_param in name_space and not getattr( - self, wgt_internal_name, None - ): + if nb_param in name_space and not getattr(self, wgt_internal_name, None): setattr(self, wgt_internal_name, name_space[nb_param]) @@ -145,7 +146,7 @@ def parse_time_unit(unit_str: str) -> TimeUnit: return TimeUnit.MINUTE -def default_max_buffer(max_default: Optional[int], default: int, unit: TimeUnit) -> int: +def default_max_buffer(max_default: int | None, default: int, unit: TimeUnit) -> int: """Return the max time buffer for a give time unit.""" mag_default = abs(int(default * 4)) if max_default is not None: @@ -160,7 +161,7 @@ def default_max_buffer(max_default: Optional[int], default: int, unit: TimeUnit) return max(240, mag_default) -def default_before_after(default: Optional[int], unit: TimeUnit) -> int: +def default_before_after(default: int | None, unit: TimeUnit) -> int: """Return default before and after bounds for a TimeUnit.""" if default is not None: return abs(default) diff --git a/msticpy/nbwidgets/get_environment_key.py b/msticpy/nbwidgets/get_environment_key.py index 8d7bea2c3..7e49bf486 100644 --- a/msticpy/nbwidgets/get_environment_key.py +++ b/msticpy/nbwidgets/get_environment_key.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" + import os import ipywidgets as widgets @@ -85,9 +86,7 @@ def __init__( value=True, description="Save as environment var", disabled=False ) self._w_save_button.on_click(self._on_save_button_clicked) - self._hbox = widgets.HBox( - [self._w_text, self._w_save_button, self._w_check_save] - ) + self._hbox = widgets.HBox([self._w_text, self._w_save_button, self._w_check_save]) if auto_display: self.display() diff --git a/msticpy/nbwidgets/get_text.py b/msticpy/nbwidgets/get_text.py index 5e114d9cf..269ae51f1 100644 --- a/msticpy/nbwidgets/get_text.py +++ b/msticpy/nbwidgets/get_text.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" -from typing import Optional import ipywidgets as widgets from ipywidgets import Layout @@ -27,7 +26,7 @@ class GetText(RegisteredWidget, IPyDisplayMixin): def __init__( self, - default: Optional[str] = None, + default: str | None = None, description: str = "Enter the value: ", auto_display: bool = False, **kwargs, diff --git a/msticpy/nbwidgets/lookback.py b/msticpy/nbwidgets/lookback.py index 0f19e9cf9..840c6684e 100644 --- a/msticpy/nbwidgets/lookback.py +++ b/msticpy/nbwidgets/lookback.py @@ -4,8 +4,8 @@ # license information. 
# -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" + from datetime import datetime, timedelta, timezone -from typing import Optional import ipywidgets as widgets from ipywidgets import Layout @@ -28,11 +28,11 @@ class Lookback(IPyDisplayMixin): # pylint: disable=too-many-arguments def __init__( self, - default: Optional[int] = None, + default: int | None = None, description: str = "Select time ({units}) to look back", origin_time: datetime = None, - min_value: Optional[int] = None, - max_value: Optional[int] = None, + min_value: int | None = None, + max_value: int | None = None, units: str = "hour", auto_display: bool = False, **kwargs, @@ -64,9 +64,7 @@ def __init__( """ # default to now - self.origin_time = ( - datetime.now(timezone.utc) if origin_time is None else origin_time - ) + self.origin_time = datetime.now(timezone.utc) if origin_time is None else origin_time description = kwargs.pop("label", description) self._time_unit = parse_time_unit(units) diff --git a/msticpy/nbwidgets/option_buttons.py b/msticpy/nbwidgets/option_buttons.py index 65ce2123c..c56080976 100644 --- a/msticpy/nbwidgets/option_buttons.py +++ b/msticpy/nbwidgets/option_buttons.py @@ -4,8 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" + import asyncio -from typing import Any, Iterable, Optional +from collections.abc import Iterable +from typing import Any import ipywidgets as widgets from IPython.display import display @@ -43,9 +45,9 @@ class OptionButtons(IPyDisplayMixin): def __init__( self, - description: Optional[str] = "Select an option to continue", - buttons: Optional[Iterable[str]] = None, - default: Optional[str] = None, + description: str | None = "Select an option to continue", + buttons: Iterable[str] | None = None, + default: str | None = None, timeout: int = 0, debug: bool = False, **kwargs, @@ -79,7 +81,7 @@ def __init__( self._desc_label = widgets.Label(value=description) self._timer_label = widgets.Label(layout=widgets.Layout(left="10px")) self.default = default or next(iter(buttons)).casefold() - self.value: Optional[str] = None + self.value: str | None = None self.timeout = timeout self._completion: Any = None diff --git a/msticpy/nbwidgets/progress.py b/msticpy/nbwidgets/progress.py index 8d7205116..fe3fd3f3b 100644 --- a/msticpy/nbwidgets/progress.py +++ b/msticpy/nbwidgets/progress.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" + import ipywidgets as widgets from .._version import VERSION diff --git a/msticpy/nbwidgets/query_time.py b/msticpy/nbwidgets/query_time.py index 10557f9d3..c68218f07 100644 --- a/msticpy/nbwidgets/query_time.py +++ b/msticpy/nbwidgets/query_time.py @@ -4,6 +4,7 @@ # license information. 
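A pattern repeated through these widget hunks (`option_buttons` above, `select_alert` and `select_item` below) is moving `Callable` and `Iterable` imports from `typing` to `collections.abc`. PEP 585 (Python 3.9) deprecated the `typing` aliases for these, and with the package floor now at 3.10 the canonical forms and builtin generics are used directly. A compact sketch of the post-3.10 idiom (hypothetical function):

from collections.abc import Callable, Iterable

def apply_all(funcs: Iterable[Callable[[int], int]], value: int) -> list[int]:
    """Apply each callable to value and collect the results."""
    return [func(value) for func in funcs]

print(apply_all([lambda x: x + 1, lambda x: x * 2], 10))  # [11, 20]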
# -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" + from __future__ import annotations from datetime import datetime, timedelta, timezone @@ -302,16 +303,12 @@ def _get_time_parameters( ) -> None: """Process different init time parameters from kwargs.""" if timespan: - self._query_end = self.origin_time = self._ensure_timezone_aware( - timespan.end - ) + self._query_end = self.origin_time = self._ensure_timezone_aware(timespan.end) self._query_start = self._ensure_timezone_aware(timespan.start) elif start and end: timespan = TimeSpan(start=start, end=end) self._query_start = self._ensure_timezone_aware(timespan.start) - self._query_end = self.origin_time = self._ensure_timezone_aware( - timespan.end - ) + self._query_end = self.origin_time = self._ensure_timezone_aware(timespan.end) else: self.before = default_before_after(before, self._time_unit) self.after = default_before_after(after, self._time_unit) @@ -330,8 +327,7 @@ def _get_time_parameters( self.after = after if self.before == 0: self.before = before or int( - (self._query_end - self._query_start).total_seconds() - / self._time_unit.value, + (self._query_end - self._query_start).total_seconds() / self._time_unit.value, ) # Utility functions diff --git a/msticpy/nbwidgets/select_alert.py b/msticpy/nbwidgets/select_alert.py index 636df1c39..7d8f41c42 100644 --- a/msticpy/nbwidgets/select_alert.py +++ b/msticpy/nbwidgets/select_alert.py @@ -8,8 +8,9 @@ import contextlib import json import random +from collections.abc import Callable from json import JSONDecodeError -from typing import Any, Callable, List, Optional, Tuple +from typing import Any import ipywidgets as widgets import pandas as pd @@ -52,8 +53,8 @@ class SelectAlert(IPyDisplayMixin): def __init__( self, alerts: pd.DataFrame, - action: Callable[..., Optional[Tuple]] = None, - columns: List[str] = None, + action: Callable[..., tuple | None] = None, + columns: list[str] = None, auto_display: bool = False, id_col: str = "SystemAlertId", **kwargs, @@ -95,9 +96,7 @@ def __init__( columns = columns or ["AlertName", "ProductName"] self.disp_columns = list({col for col in columns if col in alerts.columns}) if not self.disp_columns: - raise ValueError( - f"Display columns {','.join(columns)} not found in alerts." 
- ) + raise ValueError(f"Display columns {','.join(columns)} not found in alerts.") self._select_items = self._get_select_options( alerts, self.time_col, self.id_col, self.disp_columns ) @@ -120,7 +119,7 @@ def __init__( # setup to use updatable display objects rand_id = random.randint(0, 999999) # nosec self._output_id = f"{self.__class__.__name__}_{rand_id}" - self._disp_elems: List[Any] = [] + self._disp_elems: list[Any] = [] # set up observer callbacks self._w_filter_alerts.observe(self._update_options, names="value") @@ -204,12 +203,10 @@ def _get_alert(self, alert_id): alert["ExtendedProperties"], str ): with contextlib.suppress(JSONDecodeError): - alert["ExtendedProperties"] = json.loads( - (alert["ExtendedProperties"]) - ) + alert["ExtendedProperties"] = json.loads(alert["ExtendedProperties"]) if "Entities" in alert.index and isinstance(alert["Entities"], str): with contextlib.suppress(JSONDecodeError): - alert["Entities"] = json.loads((alert["Entities"])) + alert["Entities"] = json.loads(alert["Entities"]) return alert return None @@ -217,9 +214,7 @@ def _select_top_alert(self): """Select the first alert by default.""" top_alert = self.alerts.iloc[0] if self.default_alert: - top_alert = self.alerts[ - self.alerts[self.id_col] == self.default_alert - ].iloc[0] + top_alert = self.alerts[self.alerts[self.id_col] == self.default_alert].iloc[0] if not top_alert.empty: self._w_select_alert.index = 0 self.alert_id = top_alert[self.id_col] @@ -236,7 +231,7 @@ def _run_action(self, change=None): if output_objs is None: self._clear_display() return - if not isinstance(output_objs, (tuple, list)): + if not isinstance(output_objs, tuple | list): output_objs = [output_objs] display_objs = bool(self._disp_elems) for idx, out_obj in enumerate(output_objs): @@ -258,9 +253,7 @@ def _clear_display(self): # pylint: disable=too-many-instance-attributes -@deprecated( - reason="Superceded by SelectAlert. Will be removed in v2.0.0.", version="0.5.2" -) +@deprecated(reason="Superceded by SelectAlert. Will be removed in v2.0.0.", version="0.5.2") class AlertSelector(SelectAlert): """ AlertSelector. @@ -287,7 +280,7 @@ def __init__( self, alerts: pd.DataFrame, action: Callable[..., None] = None, - columns: List[str] = None, + columns: list[str] = None, auto_display: bool = False, ): """ @@ -316,9 +309,7 @@ def __init__( def display(self): """Display the interactive widgets.""" self._select_top_alert() - display( - widgets.VBox([self._w_filter_alerts, self._w_select_alert, self._w_output]) - ) + display(widgets.VBox([self._w_filter_alerts, self._w_select_alert, self._w_output])) def _run_action(self, change=None): del change diff --git a/msticpy/nbwidgets/select_item.py b/msticpy/nbwidgets/select_item.py index 31124b2da..cbf4b76db 100644 --- a/msticpy/nbwidgets/select_item.py +++ b/msticpy/nbwidgets/select_item.py @@ -4,8 +4,10 @@ # license information. 
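The `isinstance(output_objs, tuple | list)` rewrite in the `_run_action` hunk above (mirrored in `select_item` below) leans on PEP 604 unions: since Python 3.10, `X | Y` builds a `types.UnionType` that `isinstance` and `issubclass` accept directly, so the tuple-of-types form is no longer required. A standalone sketch (hypothetical function):

def as_list(result: object) -> list:
    """Normalize a scalar-or-sequence result to a list."""
    if not isinstance(result, tuple | list):  # same check as (tuple, list)
        return [result]
    return list(result)

print(as_list("single"))    # ['single']
print(as_list(("a", "b")))  # ['a', 'b']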
# -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" + import random -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from collections.abc import Callable +from typing import Any import ipywidgets as widgets from deprecated.sphinx import deprecated @@ -34,8 +36,8 @@ class SelectItem(IPyDisplayMixin): def __init__( self, description: str = "Select an item", - options: Union[List[str], Dict[str, Any]] = None, - action: Callable[..., Optional[Tuple]] = None, + options: list[str] | dict[str, Any] | None = None, + action: Callable[..., tuple | None] | None = None, value: str = "", **kwargs, ): @@ -130,7 +132,7 @@ def __init__( # setup to use updatable display objects rand_id = random.randint(0, 999999) # nosec self._output_id = f"{self.__class__.__name__}_{rand_id}" - self._disp_elems: List[Any] = [] + self._disp_elems: list[Any] = [] if auto_display: self.display() @@ -182,9 +184,7 @@ def _filter_options(self, change): return self._wgt_select.options = self._get_filtered_options(change["new"]) - def _get_filtered_options( - self, substring: str = "" - ) -> List[Union[str, Tuple[str, str]]]: + def _get_filtered_options(self, substring: str = "") -> list[str | tuple[str, str]]: """Return optionally filtered list of option tuples.""" if self.options is None: return [] @@ -205,7 +205,7 @@ def _run_action(self, change=None): if output_objs is None: self._clear_display() return - if not isinstance(output_objs, (tuple, list)): + if not isinstance(output_objs, tuple | list): output_objs = [output_objs] display_objs = dict(enumerate(self._disp_elems)) for idx, out_obj in enumerate(output_objs): @@ -229,9 +229,7 @@ def _show_top_item(self): self._run_action() -@deprecated( - reason="Superceded by SelectItem. Will be removed in v2.0.0.", version="0.5.2" -) +@deprecated(reason="Superceded by SelectItem. Will be removed in v2.0.0.", version="0.5.2") class SelectString(SelectItem): """Selection list from list or dict.""" @@ -239,9 +237,9 @@ class SelectString(SelectItem): def __init__( self, description: str = "Select an item", - item_list: List[str] = None, + item_list: list[str] = None, action: Callable[..., None] = None, - item_dict: Dict[str, str] = None, + item_dict: dict[str, str] = None, auto_display: bool = False, height: str = "100px", width: str = "50%", diff --git a/msticpy/nbwidgets/select_subset.py b/msticpy/nbwidgets/select_subset.py index ba98f032c..1cdd6811b 100644 --- a/msticpy/nbwidgets/select_subset.py +++ b/msticpy/nbwidgets/select_subset.py @@ -4,7 +4,8 @@ # license information. 
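The `@deprecated(...)` decorators reflowed in the hunks above come from the third-party `Deprecated` package (`deprecated.sphinx.deprecated`), which wraps the target to emit a `DeprecationWarning` when used and prepends a Sphinx `.. deprecated::` note to its docstring. A usage sketch, assuming the `Deprecated` package is installed (the class and message below are invented):

from deprecated.sphinx import deprecated

@deprecated(reason="Superseded by NewWidget.", version="0.5.2")
class OldWidget:
    """Legacy widget kept only for backward compatibility."""

w = OldWidget()  # emits DeprecationWarning (subject to warning filters)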
# -------------------------------------------------------------------------- """Module for pre-defined widget layouts.""" -from typing import Any, Dict, List, Union + +from typing import Any import ipywidgets as widgets @@ -21,8 +22,8 @@ class SelectSubset(IPyDisplayMixin): def __init__( self, - source_items: Union[Dict[str, str], List[Any]], - default_selected: Union[Dict[str, str], List[Any]] = None, + source_items: dict[str, str] | list[Any], + default_selected: dict[str, str] | list[Any] | None = None, display_filter: bool = True, auto_display: bool = True, ): @@ -91,9 +92,7 @@ def __init__( self._b_del_all.on_click(self._on_btn_del_all) self._b_add_all.on_click(self._on_btn_add_all) - v_box = widgets.VBox( - [self._b_add_all, self._b_add, self._b_del, self._b_del_all] - ) + v_box = widgets.VBox([self._b_add_all, self._b_add, self._b_del, self._b_del_all]) self.layout = widgets.HBox([self._source_list, v_box, self._select_list]) if self._display_filter: self.layout = widgets.VBox([self._w_filter, self.layout]) @@ -101,12 +100,12 @@ def __init__( self.display() @property - def value(self) -> List[Any]: + def value(self) -> list[Any]: """Return currently selected value or values.""" return self.selected_values @property - def selected_items(self) -> List[Any]: + def selected_items(self) -> list[Any]: """ Return a list of the selected items. @@ -122,7 +121,7 @@ def selected_items(self) -> List[Any]: return list(self._select_list.options) @property - def selected_values(self) -> List[Any]: + def selected_values(self) -> list[Any]: """ Return list of selected values. @@ -135,9 +134,7 @@ def selected_values(self) -> List[Any]: List of selected item values. """ - if self._select_list.options and isinstance( - self._select_list.options[0], tuple - ): + if self._select_list.options and isinstance(self._select_list.options[0], tuple): return [item[1] for item in self._select_list.options] return self.selected_items @@ -145,11 +142,7 @@ def _update_options(self, change): """Filter the alert list by substring.""" if change is not None and "new" in change: self._source_list.options = sorted( - { - i - for i in self.src_items - if str(change["new"]).lower() in str(i).lower() - } + {i for i in self.src_items if str(change["new"]).lower() in str(i).lower()} ) # pylint: disable=not-an-iterable @@ -161,11 +154,11 @@ def _on_btn_add(self, button): selected_set.add(self._src_dict[selected]) else: selected_set.add(selected) - self._select_list.options = sorted(list(selected_set)) + self._select_list.options = sorted(selected_set) def _on_btn_add_all(self, button): del button - self._select_list.options = sorted(list(set(self._source_list.options))) + self._select_list.options = sorted(set(self._source_list.options)) def _on_btn_del(self, button): del button @@ -178,16 +171,16 @@ def _on_btn_del(self, button): selected_set.remove(self._src_dict[selected]) else: selected_set.remove(selected) - self._select_list.options = sorted(list(selected_set)) + self._select_list.options = sorted(selected_set) if not self._select_list.options: return # try to set the index to the next item in the list if cur_index < len(self._select_list.options): next_item = cur_index or 0 - self._select_list.index = tuple([next_item]) + self._select_list.index = (next_item,) else: last_item = max(len(self._select_list.options) - 1, 0) - self._select_list.index = tuple([last_item]) + self._select_list.index = (last_item,) # pylint: enable=not-an-iterable diff --git a/msticpy/sectools/__init__.py b/msticpy/sectools/__init__.py deleted 
file mode 100644 index 1e80bfc9f..000000000 --- a/msticpy/sectools/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -MSTICPy sectools. - -.. warning: This sub-package is deprecated. - All functionality has been removed from this sub-package and moved - to other sub-packages: - -- TI providers -> msticpy.context.tiproviders - (including vtlookup and vtlookupv3) -- auditdextract -> msticpy.transform -- base64unpack -> msticpy.transform -- cmd_line -> msticpy.context -- domain_utils -> msticpy.context -- eventcluster -> msticpy.analysis -- geoip -> msticpy.context -- iocextract -> msticpy.transform -- ip_utils -> msticpy.context -- proc_tree_builder -> msticpy.transform -- proc_tree_build_mde -> msticpy.transform -- proc_tree_build_winlx -> msticpy.transform -- proc_tree_schema -> msticpy.transform -- proc_tree_utils -> msticpy.transform -- sectools_magics -> msticpy.init.nbmagics -- syslog_utils -> msticpy.analysis - -The sectools sub-package will be removed in version 2.0.0 - -""" -from .._version import VERSION -from ..lazy_importer import lazy_import - -__version__ = VERSION - -_LAZY_IMPORTS = { - "msticpy.context.geoip.GeoLiteLookup", - "msticpy.context.geoip.IPStackLookup", - "msticpy.context.geoip.geo_distance", - "msticpy.context.tilookup.TILookup", - "msticpy.transform.base64unpack as base64", - "msticpy.transform.iocextract.IoCExtract", -} - -module, __getattr__, __dir__ = lazy_import(__name__, _LAZY_IMPORTS) diff --git a/msticpy/sectools/auditdextract.py b/msticpy/sectools/auditdextract.py deleted file mode 100644 index 6234acf0d..000000000 --- a/msticpy/sectools/auditdextract.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module auditdextract.py has moved. - -See :py:mod:`msticpy.transform.auditdextract` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.auditdextract import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.auditdextract\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/base64unpack.py b/msticpy/sectools/base64unpack.py deleted file mode 100644 index 3b439ee14..000000000 --- a/msticpy/sectools/base64unpack.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module base64unpack.py has moved. 
- -See :py:mod:`msticpy.transform.base64unpack` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.base64unpack import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.base64unpack\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/cmd_line.py b/msticpy/sectools/cmd_line.py deleted file mode 100644 index 4e4aaa1e8..000000000 --- a/msticpy/sectools/cmd_line.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module cmd_line.py has moved. - -See :py:mod:`msticpy.transform.cmd_line` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.cmd_line import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.cmd_line\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/domain_utils.py b/msticpy/sectools/domain_utils.py deleted file mode 100644 index 5cdc7a1df..000000000 --- a/msticpy/sectools/domain_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module domain_utils.py has moved. - -See :py:mod:`msticpy.context.domain_utils` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..context.domain_utils import * - -WARN_MSSG = ( - "This module has moved to msticpy.analysis.domain_utils\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/eventcluster.py b/msticpy/sectools/eventcluster.py deleted file mode 100644 index 066a1bd1d..000000000 --- a/msticpy/sectools/eventcluster.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module eventcluster.py has moved. 
- -See :py:mod:`msticpy.analysis.eventcluster` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Ian Hellen" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..analysis.eventcluster import * - -WARN_MSSG = ( - "This module has moved to msticpy.analysis.eventcluster\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/geoip.py b/msticpy/sectools/geoip.py deleted file mode 100644 index 938cfd6e7..000000000 --- a/msticpy/sectools/geoip.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module geoip.py has moved. - -See :py:mod:`msticpy.context.geoip` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..context.geoip import * - -WARN_MSSG = ( - "This module has moved to msticpy.analysis.geoip\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/iocextract.py b/msticpy/sectools/iocextract.py deleted file mode 100644 index addaaccc5..000000000 --- a/msticpy/sectools/iocextract.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module iocextract.py has moved. - -See :py:mod:`msticpy.transform.iocextract` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.iocextract import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.iocextract\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/ip_utils.py b/msticpy/sectools/ip_utils.py deleted file mode 100644 index 9f9e89ba3..000000000 --- a/msticpy/sectools/ip_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module ip_utils.py has moved. 
- -See :py:mod:`msticpy.context.ip_utils` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..context.ip_utils import * - -WARN_MSSG = ( - "This module has moved to msticpy.analysis.ip_utils\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/proc_tree_build_mde.py b/msticpy/sectools/proc_tree_build_mde.py deleted file mode 100644 index 35b884ed9..000000000 --- a/msticpy/sectools/proc_tree_build_mde.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module proc_tree_build_mde.py has moved. - -See :py:mod:`msticpy.transform.proc_tree_build_mde` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.proc_tree_build_mde import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.proc_tree_build_mde\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/proc_tree_build_winlx.py b/msticpy/sectools/proc_tree_build_winlx.py deleted file mode 100644 index 220e90cf4..000000000 --- a/msticpy/sectools/proc_tree_build_winlx.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module proc_tree_build_winlx.py has moved. - -See :py:mod:`msticpy.transform.proc_tree_build_winlx` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.proc_tree_build_winlx import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.proc_tree_build_winlx\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/proc_tree_builder.py b/msticpy/sectools/proc_tree_builder.py deleted file mode 100644 index 8dd85d3e0..000000000 --- a/msticpy/sectools/proc_tree_builder.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module proc_tree_builder.py has moved. 
- -See :py:mod:`msticpy.transform.proc_tree_builder` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.proc_tree_builder import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.proc_tree_builder\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/proc_tree_schema.py b/msticpy/sectools/proc_tree_schema.py deleted file mode 100644 index 3ecbc94db..000000000 --- a/msticpy/sectools/proc_tree_schema.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module proc_tree_schema.py has moved. - -See :py:mod:`msticpy.transform.proc_tree_schema` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.proc_tree_schema import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.proc_tree_schema\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/proc_tree_utils.py b/msticpy/sectools/proc_tree_utils.py deleted file mode 100644 index 3422845fb..000000000 --- a/msticpy/sectools/proc_tree_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module process_tree_utils.py has moved. - -See :py:mod:`msticpy.transform.process_tree_utils` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..transform.process_tree_utils import * - -WARN_MSSG = ( - "This module has moved to msticpy.transform.process_tree_utils\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/sectools_magics.py b/msticpy/sectools/sectools_magics.py deleted file mode 100644 index e3733d3d1..000000000 --- a/msticpy/sectools/sectools_magics.py +++ /dev/null @@ -1,23 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module sectools_magics.py has moved. 
- -See :py:mod:`msticpy.init.nb_magics` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - -WARN_MSSG = ( - "This module has moved to msticpy.init.nb_magics\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/syslog_utils.py b/msticpy/sectools/syslog_utils.py deleted file mode 100644 index 4d92b43f1..000000000 --- a/msticpy/sectools/syslog_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module syslog_utils.py has moved. - -See :py:mod:`msticpy.analysis.syslog_utils` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..analysis.syslog_utils import * - -WARN_MSSG = ( - "This module has moved to msticpy.analysis.syslog_utils\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/tilookup.py b/msticpy/sectools/tilookup.py deleted file mode 100644 index fbf860e8f..000000000 --- a/msticpy/sectools/tilookup.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module tilookup.py has moved. - -See :py:mod:`msticpy.context.tilookup` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ..context.tilookup import * - -WARN_MSSG = ( - "This module has moved to msticpy.context.tilookup\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/tiproviders/__init__.py b/msticpy/sectools/tiproviders/__init__.py deleted file mode 100644 index f600d0384..000000000 --- a/msticpy/sectools/tiproviders/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -"""Deprecated location for TI Providers.""" - -from ..._version import VERSION - -__version__ = VERSION -__author__ = "Ian Hellen" diff --git a/msticpy/sectools/tiproviders/ti_provider_base.py b/msticpy/sectools/tiproviders/ti_provider_base.py deleted file mode 100644 index 3010c80bb..000000000 --- a/msticpy/sectools/tiproviders/ti_provider_base.py +++ /dev/null @@ -1,29 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module ti_provider_base.py has moved. - -See :py:mod:`msticpy.context.tiproviders.ti_provider_base` -""" -import warnings - -from ..._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ...context.tiproviders.result_severity import ResultSeverity as TISeverity -from ...context.tiproviders.ti_provider_base import * - -WARN_MSSG = ( - "This module has moved to " - "msticpy.context.tiproviders.ti_provider_base\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/vtlookup.py b/msticpy/sectools/vtlookup.py deleted file mode 100644 index e628033b0..000000000 --- a/msticpy/sectools/vtlookup.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module vtlookup.py has moved. - -See :py:mod:`msticpy.context.vtlookupv3.vtlookup` -""" -import warnings - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=unused-import, unused-wildcard-import, wildcard-import -from ..context.vtlookupv3.vtlookup import * - -WARN_MSSG = ( - "This module has moved to msticpy.context.vtlookupv3.vtlookup\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/vtlookupv3/__init__.py b/msticpy/sectools/vtlookupv3/__init__.py deleted file mode 100644 index 52ecdcdce..000000000 --- a/msticpy/sectools/vtlookupv3/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""VirusTotal V3 Subpackage.""" diff --git a/msticpy/sectools/vtlookupv3/vtfile_behavior.py b/msticpy/sectools/vtlookupv3/vtfile_behavior.py deleted file mode 100644 index c85985dd2..000000000 --- a/msticpy/sectools/vtlookupv3/vtfile_behavior.py +++ /dev/null @@ -1,29 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. 
All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module vtfile_behavior.py has moved. - -See :py:mod:`msticpy.context.vtlookupv3.vtfile_behavior` -""" -import warnings - -from ..._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ...context.vtlookupv3.vtfile_behavior import * - -WARN_MSSG = ( - "This module has moved to " - "msticpy.context.vtlookupv3.vtfile_behavior\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/vtlookupv3/vtlookupv3.py b/msticpy/sectools/vtlookupv3/vtlookupv3.py deleted file mode 100644 index 84bae28a7..000000000 --- a/msticpy/sectools/vtlookupv3/vtlookupv3.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module vtlookupv3.py has moved. - -See :py:mod:`msticpy.context.vtlookupv3.vtlookupv3` -""" -import warnings - -from ..._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ...context.vtlookupv3.vtlookupv3 import * - -WARN_MSSG = ( - "This module has moved to msticpy.context.vtlookupv3.vtlookupv3\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/sectools/vtlookupv3/vtobject_browser.py b/msticpy/sectools/vtlookupv3/vtobject_browser.py deleted file mode 100644 index 90cacfa8a..000000000 --- a/msticpy/sectools/vtlookupv3/vtobject_browser.py +++ /dev/null @@ -1,29 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Deprecated - module vtobject_browser.py has moved. - -See :py:mod:`msticpy.vis.vtobject_browser` -""" -import warnings - -from ..._version import VERSION - -__version__ = VERSION -__author__ = "Pete Bryan" - - -# flake8: noqa: F403, F401 -# pylint: disable=wildcard-import, unused-wildcard-import, unused-import -from ...vis.vtobject_browser import * - -WARN_MSSG = ( - "This module has moved to " - "msticpy.context.vtlookupv3.vtobject_browser\n" - "Please change your import to reflect this new location." - "This will be removed in MSTICPy v2.2.0" -) -warnings.warn(WARN_MSSG, category=DeprecationWarning) diff --git a/msticpy/transform/auditdextract.py b/msticpy/transform/auditdextract.py index f8922d3a2..252b3796a 100644 --- a/msticpy/transform/auditdextract.py +++ b/msticpy/transform/auditdextract.py @@ -13,10 +13,12 @@ line arguments into a single string). This is still a work-in-progress. 
""" + import codecs import re +from collections.abc import Mapping from datetime import datetime, timezone -from typing import Any, Dict, List, Mapping, Optional, Set, Tuple +from typing import Any import pandas as pd @@ -25,7 +27,7 @@ try: # pylint: disable=unused-import - from ..analysis import cluster_auditd # type: ignore + from ..analysis import cluster_auditd except ImportError: def cluster_auditd(*args, **kwargs): # type: ignore @@ -40,14 +42,14 @@ def cluster_auditd(*args, **kwargs): # type: ignore # Constants # Fields that we know are frequently encoded -_ENCODED_PARAMS: Dict[str, Set[str]] = { +_ENCODED_PARAMS: dict[str, set[str]] = { "EXECVE": {"a0", "a1", "a2", "a3", "arch"}, "PROCTITLE": {"proctitle"}, "USER_CMD": {"cmd"}, } # USER_START message schema -_USER_START: Dict[str, Optional[str]] = { +_USER_START: dict[str, str | None] = { "pid": "int", "uid": "int", "auid": "int", @@ -62,7 +64,7 @@ def cluster_auditd(*args, **kwargs): # type: ignore } # Message types schema -_FIELD_DEFS: Dict[str, Dict[str, Optional[str]]] = { +_FIELD_DEFS: dict[str, dict[str, str | None]] = { "SYSCALL": { "success": None, "ppid": "int", @@ -106,7 +108,7 @@ def cluster_auditd(*args, **kwargs): # type: ignore @export -def unpack_auditd(audit_str: List[Dict[str, str]]) -> Mapping[str, Mapping[str, Any]]: +def unpack_auditd(audit_str: list[dict[str, str]]) -> Mapping[str, Mapping[str, Any]]: """ Unpack an Audit message and returns a dictionary of fields. @@ -121,7 +123,7 @@ def unpack_auditd(audit_str: List[Dict[str, str]]) -> Mapping[str, Mapping[str, The extracted message fields and values """ - event_dict: Dict[str, Dict[str, Any]] = {} + event_dict: dict[str, dict[str, Any]] = {} # The audit_str should be a list of dicts - '{EXECVE : {'p1': 'foo', p2: 'bar'...}, # PATH: {'a1': 'xyz',....}} @@ -129,7 +131,7 @@ def unpack_auditd(audit_str: List[Dict[str, str]]) -> Mapping[str, Mapping[str, # process a single message type, splitting into type name # and contents for rec_key, rec_val in record.items(): - rec_dict: Dict[str, Optional[str]] = {} + rec_dict: dict[str, str | None] = {} # Get our field mapping for encoded params for this # mssg_type (rec_key) encoded_fields_map = _ENCODED_PARAMS.get(rec_key, None) @@ -151,7 +153,7 @@ def unpack_auditd(audit_str: List[Dict[str, str]]) -> Mapping[str, Mapping[str, # Mypy thinks codecs.decode returns a str so # incorrectly issues a type warning - in this case it # will return a bytes string. - field_value = codecs.decode( # type: ignore + field_value = codecs.decode( bytes(rec_split[1], "utf-8"), "hex" ).decode("utf-8") except ValueError: @@ -170,7 +172,7 @@ def unpack_auditd(audit_str: List[Dict[str, str]]) -> Mapping[str, Mapping[str, return event_dict -def _extract_event(message_dict: Mapping[str, Any]) -> Tuple[str, Mapping[str, Any]]: +def _extract_event(message_dict: Mapping[str, Any]) -> tuple[str, Mapping[str, Any]]: """ Assemble discrete messages sharing the same message Id into a single event. 
@@ -187,7 +189,7 @@ def _extract_event(message_dict: Mapping[str, Any]) -> Tuple[str, Mapping[str, A """ # Handle process executions specially if "SYSCALL" in message_dict and "EXECVE" in message_dict: - proc_create_dict: Dict[str, Any] = {} + proc_create_dict: dict[str, Any] = {} for mssg_type in ["SYSCALL", "CWD", "EXECVE", "PROCTITLE"]: if mssg_type not in message_dict or mssg_type not in _FIELD_DEFS: continue @@ -195,13 +197,11 @@ def _extract_event(message_dict: Mapping[str, Any]) -> Tuple[str, Mapping[str, A if mssg_type == "EXECVE": args = int(proc_create_dict.get("argc", 1)) - arg_strs = [ - proc_create_dict.get(f"a{arg_idx}", "") for arg_idx in range(args) - ] + arg_strs = [proc_create_dict.get(f"a{arg_idx}", "") for arg_idx in range(args)] proc_create_dict["cmdline"] = " ".join(arg_strs) return "SYSCALL_EXECVE", proc_create_dict - event_dict: Dict[str, Any] = {} + event_dict: dict[str, Any] = {} for mssg_type, _ in message_dict.items(): if mssg_type in _FIELD_DEFS: _extract_mssg_value(mssg_type, message_dict, event_dict) @@ -216,7 +216,7 @@ def _extract_event(message_dict: Mapping[str, Any]) -> Tuple[str, Mapping[str, A def _extract_mssg_value( mssg_type: str, message_dict: Mapping[str, Mapping[str, Any]], - event_dict: Dict[str, Any], + event_dict: dict[str, Any], ): """ Extract field/value from the message dictionary. @@ -302,9 +302,7 @@ def extract_events_to_df( # If the provided table has auditd messages as a string format and # extract key elements. if isinstance(data[input_column].head(1)[0], str): - data["mssg_id"] = data.apply( - lambda x: _extract_timestamp(x[input_column]), axis=1 - ) + data["mssg_id"] = data.apply(lambda x: _extract_timestamp(x[input_column]), axis=1) data[input_column] = data.apply( lambda x: _parse_audit_message(x[input_column]), axis=1 ) @@ -349,9 +347,7 @@ def extract_events_to_df( # extract real timestamp from mssg_id tmp_df["TimeStamp"] = tmp_df.apply( - lambda x: datetime.fromtimestamp( - float(x["mssg_id"].split(":")[0]), tz=timezone.utc - ), + lambda x: datetime.fromtimestamp(float(x["mssg_id"].split(":")[0]), tz=timezone.utc), axis=1, ) if "TimeGenerated" in tmp_df: @@ -386,9 +382,7 @@ def get_event_subset(data: pd.DataFrame, event_type: str) -> pd.DataFrame: data['EventType'] == event_type """ - return ( - data[data["EventType"] == event_type].dropna(axis=1, how="all").infer_objects() - ) + return data[data["EventType"] == event_type].dropna(axis=1, how="all").infer_objects() @export @@ -433,9 +427,7 @@ def read_from_file( ) # extract message ID into separate column - df_raw["mssg_id"] = df_raw.apply( - lambda x: _extract_timestamp(x["raw_data"]), axis=1 - ) + df_raw["mssg_id"] = df_raw.apply(lambda x: _extract_timestamp(x["raw_data"]), axis=1) # pylint: disable=unsupported-assignment-operation, no-member # Pack message type and content into a dictionary: # {'mssg_type: ['item1=x, item2=y....]} @@ -445,9 +437,7 @@ def read_from_file( # Group the data by message id string and concatenate the message content # dictionaries in a list. - df_grouped_cols = ( - df_raw.groupby(["mssg_id"]).agg({"AuditdMessage": list}).reset_index() - ) + df_grouped_cols = df_raw.groupby(["mssg_id"]).agg({"AuditdMessage": list}).reset_index() # pylint: enable=unsupported-assignment-operation, no-member # pass this DataFrame to the event extractor. 
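The reformatted `TimeStamp` lambda above relies on the structure of auditd message IDs, `<epoch-seconds>:<serial>` (as in `msg=audit(1556043628.910:1234)`), so the first colon-separated field converts directly to a timezone-aware datetime. A minimal sketch with invented IDs:

```python
from datetime import datetime, timezone

import pandas as pd

# Hypothetical mssg_id values in auditd's "<epoch-seconds>:<serial>" form.
data = pd.DataFrame({"mssg_id": ["1556043628.910:1234", "1556043629.664:1235"]})

# Same conversion as the one-line lambda in extract_events_to_df.
data["TimeStamp"] = data.apply(
    lambda x: datetime.fromtimestamp(float(x["mssg_id"].split(":")[0]), tz=timezone.utc),
    axis=1,
)
print(data["TimeStamp"])
```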
@@ -459,7 +449,7 @@ def read_from_file( ) -def _parse_audit_message(audit_str: str) -> Dict[str, List[str]]: +def _parse_audit_message(audit_str: str) -> dict[str, list[str]]: """ Parse an auditd message string into Dict format required by unpack_auditd. @@ -507,7 +497,7 @@ def _extract_timestamp(audit_str: str) -> str: # pylint: disable=too-many-branches @export -def generate_process_tree( # noqa: MC0001 +def generate_process_tree( audit_data: pd.DataFrame, branch_depth: int = 4, processes: pd.DataFrame = None ) -> pd.DataFrame: """ diff --git a/msticpy/transform/base64unpack.py b/msticpy/transform/base64unpack.py index 11d64ff8f..69f04fb9a 100644 --- a/msticpy/transform/base64unpack.py +++ b/msticpy/transform/base64unpack.py @@ -31,12 +31,12 @@ import io import re import tarfile -import warnings import zipfile from collections import namedtuple +from collections.abc import Callable, Iterable # pylint: disable=unused-import -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any import pandas as pd @@ -92,7 +92,7 @@ # we use this to store a set of strings that match the B64 regex but # that we were unable to decode - so that we don't end up in an # infinite loop -_UNDECODABLE_STRINGS: Set[str] = set() +_UNDECODABLE_STRINGS: set[str] = set() # When True prints see more verbose execution # (set from 'trace' parameter to unpack_items) @@ -103,11 +103,11 @@ _STRIP_TAGS = r"</?decoded[^>]*>" -def _get_trace_setting() -> Callable[[Optional[bool]], bool]: +def _get_trace_setting() -> Callable[[bool | None], bool]: """Closure for holding trace setting.""" _trace = False - def _trace_enabled(trace: Optional[bool] = None) -> bool: + def _trace_enabled(trace: bool | None = None) -> bool: nonlocal _trace if trace is not None: _trace = trace @@ -120,11 +120,11 @@ def _trace_enabled(trace: Optional[bool] = None) -> bool: GET_TRACE = _get_trace_setting() -def _get_utf16_setting() -> Callable[[Optional[bool]], bool]: +def _get_utf16_setting() -> Callable[[bool | None], bool]: """Closure for holding utf16 decoding setting.""" _utf16 = False - def _utf16_enabled(utf16: Optional[bool] = None) -> bool: + def _utf16_enabled(utf16: bool | None = None) -> bool: nonlocal _utf16 if utf16 is not None: _utf16 = utf16 @@ -221,7 +221,7 @@ def unpack_items( @export def unpack( input_string: str, trace: bool = False, utf16: bool = False -) -> Tuple[str, pd.DataFrame]: +) -> tuple[str, pd.DataFrame]: """ Base64 decode an input string. @@ -318,7 +318,7 @@ def unpack_df( GET_UTF16(utf16) output_df = pd.DataFrame(columns=BinaryRecord._fields) - row_results: List[pd.DataFrame] = [] + row_results: list[pd.DataFrame] = [] rows_with_b64_match = data[data[column].str.contains(_BASE64_REGEX_NG)] for input_row in rows_with_b64_match[[column]].itertuples(): (decoded_string, output_frame) = _decode_b64_string_recursive(input_row[1]) @@ -338,7 +338,7 @@ def _decode_b64_string_recursive( max_recursion: int = 20, current_depth: int = 1, item_prefix: str = "", -) -> Tuple[str, pd.DataFrame]: +) -> tuple[str, pd.DataFrame]: """Recursively decode and unpack an encoded string.""" _debug_print_trace("_decode_b64_string_recursive: ", max_recursion) _debug_print_trace("processing input: ", input_string[:200]) @@ -418,9 +418,7 @@ def _decode_b64_string_recursive( if decode_success: # stuff that we have already decoded may also contain further # base64 encoded strings - prefix = ( - f"{item_prefix}.{fragment_index}." if item_prefix else f"{fragment_index}."
- ) + prefix = f"{item_prefix}.{fragment_index}." if item_prefix else f"{fragment_index}." next_level_string, child_records = _decode_b64_string_recursive( decoded_string, item_prefix=prefix, @@ -442,7 +440,7 @@ def _add_to_results( current_depth: int, item_prefix: str, fragment_index: int, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Add current set of decoding results to collection.""" new_rows = [] for bin_record in binary_items: @@ -474,7 +472,7 @@ def _decode_and_format_b64_string( item_prefix: str = "", current_depth: int = 1, current_index: int = 1, -) -> Tuple[str, Optional[List[BinaryRecord]]]: +) -> tuple[str, list[BinaryRecord] | None]: """Decode string and return displayable content plus list of decoded artifacts.""" # Check if we recognize this as a known file type (_, f_type) = _is_known_b64_prefix(b64encoded_string) @@ -505,9 +503,7 @@ def _decode_and_format_b64_string( _debug_print_trace("_decode_b64_binary returned multiple records") # Build child display strings - for child_index, (child_name, child_rec) in enumerate( - output_files.items(), start=1 - ): + for child_index, (child_name, child_rec) in enumerate(output_files.items(), start=1): _debug_print_trace("Child_decode: ", child_rec) child_index_string = f"{item_prefix}{current_index}.{child_index}" disp_string = _format_single_record( @@ -617,7 +613,7 @@ def _get_byte_encoding(bytes_array: bytes) -> BinaryRecord: def _is_known_b64_prefix( input_string: str, -) -> Union[Tuple[str, str], Tuple[None, None]]: +) -> tuple[str, str] | tuple[None, None]: """If this is known file type return the prefix and file type.""" first160chars = input_string[0:160].replace("\n", "").replace("\r", "") for prefix, file_type in _BASE64_HEADER_TYPES.items(): @@ -633,7 +629,7 @@ def _is_known_b64_prefix( def _decode_b64_binary( input_string: str, file_type: str = None -) -> Optional[Dict[str, BinaryRecord]]: +) -> dict[str, BinaryRecord] | None: """Examine input string for known binaries and decode and unpack.""" if not file_type: (_, f_type) = _is_known_b64_prefix(input_string) @@ -651,7 +647,7 @@ def _decode_b64_binary( def _unpack_and_hash_b64_binary( input_bytes: bytes, file_type: str = None -) -> Optional[Dict[str, BinaryRecord]]: +) -> dict[str, BinaryRecord] | None: """ If this is a known archive type extract the contents. @@ -714,7 +710,7 @@ def _get_hashes_and_printable_string(extracted_file: bytes) -> BinaryRecord: def _get_items_from_archive( binary: bytes, archive_type: str = "zip" -) -> Tuple[str, Dict[str, bytes]]: +) -> tuple[str, dict[str, bytes]]: """Extract contained files from an archive type.""" _debug_print_trace("_get_items_from_archive type: ", archive_type) if archive_type == "zip": @@ -727,7 +723,7 @@ def _get_items_from_archive( @export -def get_items_from_gzip(binary: bytes) -> Tuple[str, Dict[str, bytes]]: +def get_items_from_gzip(binary: bytes) -> tuple[str, dict[str, bytes]]: """ Return decompressed gzip contents. @@ -747,7 +743,7 @@ def get_items_from_gzip(binary: bytes) -> Tuple[str, Dict[str, bytes]]: @export -def get_items_from_zip(binary: bytes) -> Tuple[str, Dict[str, bytes]]: +def get_items_from_zip(binary: bytes) -> tuple[str, dict[str, bytes]]: """ Return dictionary of zip contents. @@ -772,7 +768,7 @@ def get_items_from_zip(binary: bytes) -> Tuple[str, Dict[str, bytes]]: @export -def get_items_from_tar(binary: bytes) -> Tuple[str, Dict[str, bytes]]: +def get_items_from_tar(binary: bytes) -> tuple[str, dict[str, bytes]]: """ Return dictionary of tar file contents. 
@@ -790,7 +786,7 @@ def get_items_from_tar(binary: bytes) -> Tuple[str, Dict[str, bytes]]: file_obj = io.BytesIO(binary) # Open tarfile with tarfile.open(mode="r", fileobj=file_obj) as tar: - archive_dict: Dict[str, bytes] = {} + archive_dict: dict[str, bytes] = {} # Iterate over every member for item in tar.getnames(): tar_file = tar.extractfile(item) @@ -799,7 +795,7 @@ def get_items_from_tar(binary: bytes) -> Tuple[str, Dict[str, bytes]]: @export -def get_hashes(binary: bytes) -> Dict[str, str]: +def get_hashes(binary: bytes) -> dict[str, str]: """ Return md5, sha1 and sha256 hashes of input byte string. @@ -831,7 +827,7 @@ def get_hashes(binary: bytes) -> Dict[str, str]: return hash_dict -def _binary_to_bytesio(binary: Union[bytes, io.BytesIO]) -> memoryview: +def _binary_to_bytesio(binary: bytes | io.BytesIO) -> memoryview: if isinstance(binary, io.BytesIO): return binary.getbuffer() return io.BytesIO(binary).getbuffer() @@ -846,70 +842,4 @@ def _b64_string_pad(string: str) -> str: return f"{string}{'A' * padding}" -# pylint: disable=too-few-public-methods -@pd.api.extensions.register_dataframe_accessor("mp_b64") -class B64ExtractAccessor: - """Base64 Unpack pandas extension.""" - - def __init__(self, pandas_obj): - """Initialize the extension.""" - self._df = pandas_obj - - def extract(self, column, **kwargs) -> pd.DataFrame: - """ - Base64 decode strings taken from a pandas dataframe. - - Parameters - ---------- - data : pd.DataFrame - dataframe containing column to decode - column : str - Name of dataframe text column - trace : bool, optional - Show additional status (the default is None) - utf16 : bool, optional - Attempt to decode UTF16 byte strings - - Returns - ------- - pd.DataFrame - Decoded string and additional metadata in dataframe - - Notes - ----- - Items that decode to utf-8 or utf-16 strings will be returned as decoded - strings replaced in the original string. If the encoded string is a - known binary type it will identify the file type and return the hashes - of the file. If any binary types are known archives (zip, tar, gzip) it - will unpack the contents of the archive. - For any binary it will return the decoded file as a byte array, and as a - printable list of byte values. - - The columns of the output DataFrame are: - - - decoded string: this is the input string with any decoded sections - replaced by the results of the decoding - - reference : this is an index that matches an index number in the - decoded string (e.g. 
diff --git a/msticpy/transform/cmd_line.py b/msticpy/transform/cmd_line.py --- a/msticpy/transform/cmd_line.py +++ b/msticpy/transform/cmd_line.py @@ -73,9 +73,7 @@ def risky_cmd_line( """ if cmd_field not in events.columns: - raise MsticpyException( - f"The provided dataset does not contain the {cmd_field} field" - ) + raise MsticpyException(f"The provided dataset does not contain the {cmd_field} field") if detection_rules is None: detection_rules = str( Path(__file__) @@ -85,12 +83,9 @@ def risky_cmd_line( events[cmd_field] = events[cmd_field].replace("", np.nan) activity = ( - events[["TimeGenerated", cmd_field]] - .dropna() - .set_index("TimeGenerated") - .to_dict() + events[["TimeGenerated", cmd_field]].dropna().set_index("TimeGenerated").to_dict() ) - with open(detection_rules, "r", encoding="utf-8") as json_file: + with open(detection_rules, encoding="utf-8") as json_file: rules = json.load(json_file) # Decode any Base64 encoded commands so we can match on them as well @@ -105,16 +100,15 @@ def risky_cmd_line( if b64_regex.match(message): b64match = b64_regex.search(message) b64string = unpack(input_string=b64match[1]) # type: ignore - b64string = b64string[1]["decoded_string"].to_string() # type: ignore + b64string = b64string[1]["decoded_string"].to_string() if re.match(detection, message): risky_actions.update({date: message}) else: pass + elif re.match(detection, message): + risky_actions.update({date: message}) else: - if re.match(detection, message): - risky_actions.update({date: message}) - else: - pass + pass return risky_actions @@ -162,10 +156,7 @@ def cmd_speed( actions = cmd_events.dropna(subset=[cmd_field]).reset_index() df_len = len(actions.index) - (events + 1) while df_len >= 0: - delta = ( - actions["TimeGenerated"][(df_len + events)] - - actions["TimeGenerated"][df_len] - ) + delta = actions["TimeGenerated"][(df_len + events)] - actions["TimeGenerated"][df_len] if delta < dt.timedelta(seconds=time): suspicious_actions.append( {df_len: [actions[df_len : (df_len + events)], delta]} # noqa: E203 diff --git a/msticpy/transform/iocextract.py b/msticpy/transform/iocextract.py index 777a6b32d..bf61d1387 100644 --- a/msticpy/transform/iocextract.py +++ b/msticpy/transform/iocextract.py @@ -22,10 +22,10 @@ regular expressions used at runtime.
""" + from __future__ import annotations import re -import warnings from collections import defaultdict from enum import Enum from typing import Any @@ -241,7 +241,7 @@ def __init__(self: IoCExtract, defanged: bool = True) -> None: # inline import due to circular dependency # pylint: disable=import-outside-toplevel - from ..context.domain_utils import DomainValidator + from ..context.domain_utils import DomainValidator # noqa: PLC0415 # pylint: enable=import-outside-toplevel self._dom_validator = DomainValidator() @@ -646,11 +646,7 @@ def get_ioc_type(self, observable: str) -> str: return IoCType.unknown.name return next( - ( - ioc_type - for ioc_type, match_set in results.items() - if observable in match_set - ), + (ioc_type for ioc_type, match_set in results.items() if observable in match_set), IoCType.unknown.name, ) @@ -724,9 +720,7 @@ def _check_decode_url(self, match_str, rgx_def, match_pos, iocs_found): ) @staticmethod - def _add_highest_pri_match( - iocs_found: dict, current_match: str, current_def: IoCPattern - ): + def _add_highest_pri_match(iocs_found: dict, current_match: str, current_def: IoCPattern): # if we already found a match for this item and the previous # ioc type is more specific then don't add this to the results if ( @@ -736,65 +730,3 @@ def _add_highest_pri_match( return iocs_found[current_match] = (current_def.ioc_type, current_def.priority) - - -# pylint: disable=too-few-public-methods -@pd.api.extensions.register_dataframe_accessor("mp_ioc") -class IoCExtractAccessor: - """Pandas api extension for IoC Extractor.""" - - def __init__(self, pandas_obj): - """Instantiate pandas extension class.""" - self._df = pandas_obj - self._ioc = IoCExtract() - - def extract(self, columns, **kwargs): - """ - Extract IoCs from either a pandas DataFrame. - - Parameters - ---------- - columns : list - The list of columns to use as source strings, - - Other Parameters - ---------------- - ioc_types : list, optional - Restrict matching to just specified types. - (default is all types) - include_paths : bool, optional - Whether to include path matches (which can be noisy) - (the default is false - excludes 'windows_path' - and 'linux_path'). If `ioc_types` is specified - this parameter is ignored. - - Returns - ------- - pd.DataFrame - DataFrame of observables - - Notes - ----- - Extract takes a pandas DataFrame as input. - The results will be returned as a new - DataFrame with the following columns: - - IoCType: the mnemonic used to distinguish different IoC Types - - Observable: the actual value of the observable - - SourceIndex: the index of the row in the input DataFrame from - which the source for the IoC observable was extracted. - - IoCType Pattern selection - The default list is: ['ipv4', 'ipv6', 'dns', 'url', - 'md5_hash', 'sha1_hash', 'sha256_hash'] plus any - user-defined types. - 'windows_path', 'linux_path' are excluded unless `include_paths` - is True or explicitly included in `ioc_paths`. - - """ - warn_message = ( - "This accessor method has been deprecated.\n" - "Please use df.mp.ioc_extract() method instead." - "This will be removed in MSTICPy v2.2.0" - ) - warnings.warn(warn_message, category=DeprecationWarning) - return self._ioc.extract_df(data=self._df, columns=columns, **kwargs) diff --git a/msticpy/transform/network.py b/msticpy/transform/network.py index 6a84e1f12..3cd79a985 100644 --- a/msticpy/transform/network.py +++ b/msticpy/transform/network.py @@ -4,11 +4,12 @@ # license information. 
# -------------------------------------------------------------------------- """Module for converting DataFrame to Networkx graph.""" -from typing import Callable, Dict, Iterable, Optional, Union + +from collections.abc import Callable, Iterable +from typing import Literal import networkx as nx import pandas as pd -from typing_extensions import Literal from .._version import VERSION @@ -24,9 +25,9 @@ def df_to_networkx( data: pd.DataFrame, source_col: str, target_col: str, - source_attrs: Optional[Iterable[str]] = None, - target_attrs: Optional[Iterable[str]] = None, - edge_attrs: Optional[Iterable[str]] = None, + source_attrs: Iterable[str] | None = None, + target_attrs: Iterable[str] | None = None, + edge_attrs: Iterable[str] | None = None, graph_type: GraphType = "graph", ): """ @@ -56,9 +57,7 @@ def df_to_networkx( """ create_as = nx.DiGraph if graph_type == "digraph" else nx.Graph - _verify_columns( - data, source_col, target_col, source_attrs, target_attrs, edge_attrs - ) + _verify_columns(data, source_col, target_col, source_attrs, target_attrs, edge_attrs) # remove any source or target rows that are NaN data = data.dropna(axis=0, subset=[source_col, target_col]) nx_graph = nx.from_pandas_edgelist( @@ -78,14 +77,14 @@ def _set_node_attributes( data: pd.DataFrame, graph: nx.Graph, column: str, - attrib_cols: Optional[Iterable[str]], + attrib_cols: Iterable[str] | None, node_role: NodeRole, ): """Set node attributes from column values.""" all_cols = [column, *attrib_cols] if attrib_cols else [column] # Create an 'agg' dictionary to apply to DataFrame - agg_dict: Dict[str, Union[str, Callable]] = ( - {col: _pd_unique_list for col in attrib_cols} if attrib_cols else {} + agg_dict: dict[str, str | Callable] = ( + dict.fromkeys(attrib_cols, _pd_unique_list) if attrib_cols else {} ) # Add these two items as attributes agg_dict.update({"node_role": "first", "node_type": "first"}) @@ -112,9 +111,7 @@ def _pd_unique_list(series: pd.Series): return ", ".join([str(attrib) for attrib in unique_vals]) -def _verify_columns( - data, source_col, target_col, source_attrs, target_attrs, edge_attrs -): +def _verify_columns(data, source_col, target_col, source_attrs, target_attrs, edge_attrs): """Check specified columns are in data.""" missing_columns = { **_verify_column(data, "source_col", source_col), diff --git a/msticpy/transform/proc_tree_build_mde.py b/msticpy/transform/proc_tree_build_mde.py index 815d6af66..6a858e6fa 100644 --- a/msticpy/transform/proc_tree_build_mde.py +++ b/msticpy/transform/proc_tree_build_mde.py @@ -4,7 +4,6 @@ # license information. 
# -------------------------------------------------------------------------- """Process tree builder routines for MDE process data.""" -from typing import Dict, Tuple, Union import numpy as np import pandas as pd @@ -110,7 +109,7 @@ def _add_proc_key( def _extract_missing_parents( - data: pd.DataFrame, col_mapping: Dict[str, str], debug: bool = False + data: pd.DataFrame, col_mapping: dict[str, str], debug: bool = False ) -> pd.DataFrame: """Return parent processes that are not in the created process set.""" # save the source index @@ -144,10 +143,8 @@ def _extract_missing_parents( # print(non_par_cols) # merge the original data with the parent rows - merged_parents = data.filter( - regex="Initiating.*|parent_key|src_index" - ).merge( # parents - data.filter(non_par_cols), # type: ignore + merged_parents = data.filter(regex="Initiating.*|parent_key|src_index").merge( # parents + data.filter(non_par_cols), left_on=Col.parent_key, right_on=Col.proc_key, suffixes=("_child", "_par"), @@ -164,9 +161,7 @@ def _extract_missing_parents( .drop(columns=["InitiatingProcessFileName"]) ) missing_parents["CreatedProcessFilePath"] = ( - missing_parents.CreatedProcessFilePath - + "\\" - + missing_parents.CreatedProcessName + missing_parents.CreatedProcessFilePath + "\\" + missing_parents.CreatedProcessName ) missing_parents = _sort_df_by_time(missing_parents) if debug: @@ -209,10 +204,10 @@ def _split_file_path( path_col: str = "CreatedProcessFilePath", file_col: str = "CreatedProcessName", separator: str = "\\", -) -> Dict[str, Union[str, float]]: +) -> dict[str, str | float]: """Split file path in to folder/stem.""" - f_path: Union[str, float] = np.nan - f_stem: Union[str, float] = np.nan + f_path: str | float = np.nan + f_stem: str | float = np.nan try: f_path, _, f_stem = input_path.rpartition(separator) except AttributeError: @@ -223,9 +218,7 @@ def _split_file_path( def _extract_missing_gparents(data): """Return grandparent processes for any procs not in Createdprocesses.""" missing_gps = ( - data[~data.parent_key.isin(data.proc_key)] - .filter(regex=".*Parent.*") - .drop_duplicates() + data[~data.parent_key.isin(data.proc_key)].filter(regex=".*Parent.*").drop_duplicates() ) missing_gps_file_split = missing_gps.apply( lambda proc: _split_file_path(proc.CreatedProcessParentName), @@ -260,7 +253,7 @@ def _extract_missing_gparents(data): return missing_gps -def _get_par_child_col_mapping(data: pd.DataFrame) -> Dict[str, str]: +def _get_par_child_col_mapping(data: pd.DataFrame) -> dict[str, str]: """Return a mapping between parent and child column names.""" created_proc_cols = _remove_col_prefix(data, "Created") init_proc_cols = _remove_col_prefix(data, "Initiating") @@ -268,7 +261,7 @@ def _get_par_child_col_mapping(data: pd.DataFrame) -> Dict[str, str]: return {**init_proc_col_mapping, **_MDE_NON_STD_COL_MAP} -def _remove_col_prefix(data: pd.DataFrame, prefix: str) -> Dict[str, str]: +def _remove_col_prefix(data: pd.DataFrame, prefix: str) -> dict[str, str]: """Return a mapping of column stems and columns with `prefix`.""" return { col.replace(prefix, ""): col @@ -278,8 +271,8 @@ def _remove_col_prefix(data: pd.DataFrame, prefix: str) -> Dict[str, str]: def _map_columns( - created_cols: Dict[str, str], init_cols: Dict[str, str] -) -> Tuple[Dict[str, str], Dict[str, str]]: + created_cols: dict[str, str], init_cols: dict[str, str] +) -> tuple[dict[str, str], dict[str, str]]: """Return Initiating -> Created column mapping.""" col_mapping = {} unmapped = {} @@ -355,8 +348,8 @@ def 
convert_mde_schema_to_internal( data["CreatedProcessParentId"] = data[schema.parent_id] # Put a value in parent procs with no name - null_proc_parent = data[schema.parent_name] == "" # type: ignore - data.loc[null_proc_parent, schema.parent_name] = "unknown" # type: ignore + null_proc_parent = data[schema.parent_name] == "" + data.loc[null_proc_parent, schema.parent_name] = "unknown" # Extract InitiatingProc folder path - remove stem data["InitiatingProcessFolderPath"] = data.InitiatingProcessFolderPath.apply( @@ -368,8 +361,6 @@ def convert_mde_schema_to_internal( if isinstance(arg_value, str) and arg_value in _SENTINEL_MDE_MAP: plot_args[arg_name] = _SENTINEL_MDE_MAP[arg_value] if isinstance(arg_value, list): - plot_args[arg_name] = [ - _SENTINEL_MDE_MAP.get(field, field) for field in arg_value - ] + plot_args[arg_name] = [_SENTINEL_MDE_MAP.get(field, field) for field in arg_value] return data.rename(columns=_SENTINEL_MDE_MAP) diff --git a/msticpy/transform/proc_tree_build_winlx.py b/msticpy/transform/proc_tree_build_winlx.py index 1d343dd3d..8831568b9 100644 --- a/msticpy/transform/proc_tree_build_winlx.py +++ b/msticpy/transform/proc_tree_build_winlx.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Process Tree builder for Windows security and Linux auditd events.""" + from dataclasses import asdict -from typing import Tuple import pandas as pd @@ -53,10 +53,10 @@ def extract_process_tree( """ # Clean data - procs_cln, schema = _clean_proc_data(procs, schema) # type: ignore + procs_cln, schema = _clean_proc_data(procs, schema) # Merge parent-child - merged_procs = _merge_parent_by_time(procs_cln, schema) # type: ignore + merged_procs = _merge_parent_by_time(procs_cln, schema) if debug: _check_merge_status(procs_cln, merged_procs, schema) @@ -85,13 +85,11 @@ def extract_process_tree( def _clean_proc_data( procs: pd.DataFrame, - schema: "ProcSchema", # type: ignore # noqa: F821 -) -> Tuple[pd.DataFrame, ProcSchema]: + schema: ProcSchema, +) -> tuple[pd.DataFrame, ProcSchema]: """Return cleaned process data.""" procs = ensure_df_datetimes(procs, columns=schema.time_stamp) - procs_cln = ( - procs.drop_duplicates().sort_values(schema.time_stamp, ascending=True).copy() - ) + procs_cln = procs.drop_duplicates().sort_values(schema.time_stamp, ascending=True).copy() # Filter out any non-process events if schema.event_id_column and schema.event_id_identifier: @@ -102,7 +100,7 @@ def _clean_proc_data( if schema.logon_id not in procs_cln.columns: schema = ProcSchema(**(asdict(schema))) - schema.logon_id = None # type: ignore + schema.logon_id = None if schema.logon_id: procs_cln[Col.EffectiveLogonId] = procs_cln[schema.logon_id] @@ -121,14 +119,14 @@ def _clean_proc_data( if schema.parent_name: no_pproc = procs_cln[schema.parent_name] == "" procs_cln.loc[no_pproc, schema.parent_name] = "unknown" - procs_cln[Col.parent_proc_lc] = procs_cln[schema.parent_name].str.lower() # type: ignore + procs_cln[Col.parent_proc_lc] = procs_cln[schema.parent_name].str.lower() procs_cln[Col.source_index] = procs_cln.index return procs_cln, schema def _num_cols_to_str( procs_cln: pd.DataFrame, - schema: "ProcSchema", # type: ignore # noqa: F821 + schema: ProcSchema, ) -> pd.DataFrame: """ Change any numeric columns in our core schema to strings. @@ -140,9 +138,7 @@ def _num_cols_to_str( into a single string. 
""" # Change float/int cols in our core schema to force int - schema_cols = [ - col for col in asdict(schema).values() if col and col in procs_cln.columns - ] + schema_cols = [col for col in asdict(schema).values() if col and col in procs_cln.columns] force_int_cols = { col: "int" for col, col_type in procs_cln[schema_cols].dtypes.to_dict().items() @@ -161,7 +157,7 @@ def _num_cols_to_str( def _merge_parent_by_time( procs: pd.DataFrame, - schema: "ProcSchema", # type: ignore # noqa: F821 + schema: ProcSchema, ) -> pd.DataFrame: """Merge procs with parents using merge_asof.""" parent_procs = ( @@ -204,7 +200,8 @@ def _merge_parent_by_time( def _extract_inferred_parents( - merged_procs: pd.DataFrame, schema: "ProcSchema" # type: ignore # noqa: F821 + merged_procs: pd.DataFrame, + schema: ProcSchema, ) -> pd.DataFrame: """Find any inferred parents and creates rows for them.""" tz_aware = merged_procs.iloc[0][schema.time_stamp].tz @@ -212,9 +209,7 @@ def _extract_inferred_parents( # Fill in missing values for root processes root_procs_crit = merged_procs[Col.source_index_par].isna() - merged_procs.loc[root_procs_crit, "NewProcessId_par"] = merged_procs[ - schema.parent_id - ] + merged_procs.loc[root_procs_crit, "NewProcessId_par"] = merged_procs[schema.parent_id] parent_col_name = schema.parent_name or "ParentName" if schema.parent_name: merged_procs.loc[root_procs_crit, Col.new_process_lc_par] = merged_procs[ @@ -312,9 +307,7 @@ def _check_merge_status(procs, merged_procs, schema): print("These two should add up to top line") row_dups = len(rows_with_dups2) print("Rows with dups", row_dups) - row_nodups = len( - merged_procs[~merged_procs[Col.source_index].isin(rows_with_dups2)] - ) + row_nodups = len(merged_procs[~merged_procs[Col.source_index].isin(rows_with_dups2)]) print("Rows with no dups", row_nodups) print(row_dups, "+", row_nodups, "=", row_dups + row_nodups) @@ -333,20 +326,14 @@ def _check_inferred_parents(procs, procs_par): def _check_proc_keys(merged_procs_par, schema): """Diagnostic for _assign_proc_keys.""" - crit1 = merged_procs_par[Col.timestamp_orig_par].isin( - merged_procs_par[schema.time_stamp] - ) - crit2 = merged_procs_par[Col.EffectiveLogonId].isin( - merged_procs_par[schema.logon_id] - ) + crit1 = merged_procs_par[Col.timestamp_orig_par].isin(merged_procs_par[schema.time_stamp]) + crit2 = merged_procs_par[Col.EffectiveLogonId].isin(merged_procs_par[schema.logon_id]) c2a = None if schema.target_logon_id: c2a = merged_procs_par[Col.EffectiveLogonId].isin( merged_procs_par[schema.target_logon_id] ) - crit3 = merged_procs_par[Col.parent_proc_lc].isin( - merged_procs_par[Col.new_process_lc] - ) + crit3 = merged_procs_par[Col.parent_proc_lc].isin(merged_procs_par[Col.new_process_lc]) crit4 = merged_procs_par[schema.process_id].isin(merged_procs_par[schema.parent_id]) crit5 = merged_procs_par[Col.parent_key].isin(merged_procs_par.index) crit6 = merged_procs_par[Col.parent_key].isna() diff --git a/msticpy/transform/proc_tree_builder.py b/msticpy/transform/proc_tree_builder.py index 98d30892d..0a702bfc7 100644 --- a/msticpy/transform/proc_tree_builder.py +++ b/msticpy/transform/proc_tree_builder.py @@ -4,7 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Process Tree Builder module for Process Tree Visualization.""" -from typing import Any, Dict, Optional, Union + +from __future__ import annotations + +from typing import Any import pandas as pd @@ -13,16 +16,11 @@ from . 
import proc_tree_build_winlx as winlx # pylint: disable=unused-import -from .proc_tree_schema import ProcSchema # noqa: F401 from .proc_tree_schema import ( # noqa: F401 - HX_PROCESSEVENT_SCH, - LX_EVENT_SCH, MDE_EVENT_SCH, MDE_INT_EVENT_SCH, - OSQUERY_EVENT_SCH, SUPPORTED_SCHEMAS, - SYSMON_PROCESS_CREATE_EVENT_SCH, - WIN_EVENT_SCH, + ProcSchema, # noqa: F401 ) from .proc_tree_schema import ColNames as Col from .process_tree_utils import get_summary_info @@ -33,7 +31,7 @@ def build_process_tree( procs: pd.DataFrame, - schema: Union[ProcSchema, Dict[str, Any]] = None, + schema: ProcSchema | dict[str, Any] | None = None, show_summary: bool = False, debug: bool = False, **kwargs, @@ -45,7 +43,7 @@ def build_process_tree( ---------- procs : pd.DataFrame Process events (Windows 4688 or Linux Auditd) - schema : Union[ProcSchema, Dict[str, Any]], optional + schema : ProcSchema | dict[str, Any] | None, optional The column schema to use, by default None. If supplied as a dict it must include definitions for the required fields in the ProcSchema class @@ -99,13 +97,13 @@ def build_process_tree( return proc_tree.sort_values(by=["path", schema.time_stamp], ascending=True) -def infer_schema(data: Union[pd.DataFrame, pd.Series]) -> Optional[ProcSchema]: +def infer_schema(data: pd.DataFrame | pd.Series) -> ProcSchema | None: """ Infer the correct schema to use for this data set. Parameters ---------- - data : Union[pd.DataFrame, pd.Series] + data : pd.DataFrame | pd.Series Data set to test Returns diff --git a/msticpy/transform/proc_tree_schema.py b/msticpy/transform/proc_tree_schema.py index 70c921d44..820dfbd0c 100644 --- a/msticpy/transform/proc_tree_schema.py +++ b/msticpy/transform/proc_tree_schema.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- """Process Tree Schema module for Process Tree Visualization.""" + from __future__ import annotations from dataclasses import MISSING, asdict, dataclass, field, fields @@ -29,7 +30,7 @@ class ProcessTreeSchemaException(MsticpyUserError): @dataclass -class ProcSchema: # pylint: disable=too-many-instance-attributes +class ProcSchema: # pylint: disable=too-many-instance-attributes # noqa: PLW1641 """ Property name lookup for Process event schema. @@ -64,8 +65,7 @@ def __eq__(self: Self, other: object) -> bool: self_dict: dict[str, Any] = asdict(self) return not any( - value and value != self_dict[field] - for field, value in asdict(other).items() + value and value != self_dict[field] for field, value in asdict(other).items() ) @property @@ -156,9 +156,7 @@ def blank_schema_dict(cls: type[Self]) -> dict[str, Any]: """Return blank schema dictionary.""" return { cls_field.name: ( - "required" - if (cls_field.default or cls_field.default == MISSING) - else None + "required" if (cls_field.default or cls_field.default == MISSING) else None ) for cls_field in fields(cls) } diff --git a/msticpy/transform/process_tree_utils.py b/msticpy/transform/process_tree_utils.py index d867d9992..8f12272f4 100644 --- a/msticpy/transform/process_tree_utils.py +++ b/msticpy/transform/process_tree_utils.py @@ -4,9 +4,10 @@ # license information. 
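# Note: a minimal sketch (not msticpy code) of the annotation style this
# PR converges on. With `from __future__ import annotations` (added to
# proc_tree_builder.py above) annotations are stored as strings and never
# evaluated, so PEP 604 unions and builtin generics are safe in
# signatures; a None default still needs an explicit `| None` member once
# Optional is gone.
from __future__ import annotations

from typing import Any

def build(schema: dict[str, Any] | None = None) -> list[str]:
    return list(schema or {})  # e.g. build({"proc": "Name"}) -> ["proc"]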
# -------------------------------------------------------------------------- """Process Tree Visualization.""" + import textwrap from collections import Counter -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, NamedTuple import pandas as pd @@ -35,7 +36,7 @@ def get_process_key(procs: pd.DataFrame, source_index: int) -> str: The process key of the process. """ - return procs[procs[Col.source_index] == source_index].iloc[0].name # type: ignore + return procs[procs[Col.source_index] == source_index].iloc[0].name # def build_process_key( # type: ignore # noqa: F821 @@ -85,7 +86,7 @@ def get_roots(procs: pd.DataFrame) -> pd.DataFrame: return procs[procs["IsRoot"]] -def get_process(procs: pd.DataFrame, source: Union[str, pd.Series]) -> pd.Series: +def get_process(procs: pd.DataFrame, source: str | pd.Series) -> pd.Series: """ Return the process event as a Series. @@ -114,9 +115,7 @@ def get_process(procs: pd.DataFrame, source: Union[str, pd.Series]) -> pd.Series raise ValueError("Unknown type for source parameter.") -def get_parent( - procs: pd.DataFrame, source: Union[str, pd.Series] -) -> Optional[pd.Series]: +def get_parent(procs: pd.DataFrame, source: str | pd.Series) -> pd.Series | None: """ Return the parent of the source process. @@ -139,7 +138,7 @@ def get_parent( return None -def get_root(procs: pd.DataFrame, source: Union[str, pd.Series]) -> pd.Series: +def get_root(procs: pd.DataFrame, source: str | pd.Series) -> pd.Series: """ Return the root process for the source process. @@ -162,7 +161,7 @@ def get_root(procs: pd.DataFrame, source: Union[str, pd.Series]) -> pd.Series: return root_proc.iloc[0] -def get_root_tree(procs: pd.DataFrame, source: Union[str, pd.Series]) -> pd.DataFrame: +def get_root_tree(procs: pd.DataFrame, source: str | pd.Series) -> pd.DataFrame: """ Return the process tree to which the source process belongs. @@ -203,7 +202,7 @@ def get_tree_depth(procs: pd.DataFrame) -> int: def get_children( - procs: pd.DataFrame, source: Union[str, pd.Series], include_source: bool = True + procs: pd.DataFrame, source: str | pd.Series, include_source: bool = True ) -> pd.DataFrame: """ Return the child processes for the source process. @@ -235,7 +234,7 @@ def get_children( def get_descendents( procs: pd.DataFrame, - source: Union[str, pd.Series], + source: str | pd.Series, include_source: bool = True, max_levels: int = -1, ) -> pd.DataFrame: @@ -265,7 +264,7 @@ def get_descendents( parent_keys = [proc.name] level = 0 current_index_name = procs.index.name - rem_procs: Optional[pd.DataFrame] = None + rem_procs: pd.DataFrame | None = None while max_levels == -1 or level < max_levels: if rem_procs is not None: # pylint: disable=unsubscriptable-object @@ -278,7 +277,7 @@ def get_descendents( if children.empty: break descendents.append(children) - parent_keys = children.index # type: ignore + parent_keys = children.index level += 1 if descendents: @@ -322,7 +321,7 @@ def get_ancestors(procs: pd.DataFrame, source, include_source=True) -> pd.DataFr def get_siblings( - procs: pd.DataFrame, source: Union[str, pd.Series], include_source: bool = True + procs: pd.DataFrame, source: str | pd.Series, include_source: bool = True ) -> pd.DataFrame: """ Return the processes that share the parent of the source process. 
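The conversions in these hunks are mechanical now that the code can assume Python 3.10+: typing.Union and typing.Optional become PEP 604 unions, and List/Dict/Tuple become builtin generics. A minimal equivalence check, illustrative rather than msticpy code:

from typing import Optional, Union

import pandas as pd

# Both spellings produce equal types at runtime on Python 3.10+.
assert Union[str, pd.Series] == (str | pd.Series)
assert Optional[pd.Series] == (pd.Series | None)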
@@ -344,13 +343,13 @@ def get_siblings( """ parent = get_parent(procs, source) proc = get_process(procs, source) - siblings = get_children(procs, parent, include_source=False) # type: ignore + siblings = get_children(procs, parent, include_source=False) if not include_source: return siblings[siblings.index != proc.name] return siblings -def get_summary_info(procs: pd.DataFrame) -> Dict[str, int]: +def get_summary_info(procs: pd.DataFrame) -> dict[str, int]: """ Return summary information about the process trees. @@ -365,7 +364,7 @@ def get_summary_info(procs: pd.DataFrame) -> Dict[str, int]: Summary statistic about the process tree """ - summary: Dict[str, Any] = {} + summary: dict[str, Any] = {} summary["Processes"] = len(procs) summary["RootProcesses"] = len(procs[procs["IsRoot"]]) summary["LeafProcesses"] = len(procs[procs["IsLeaf"]]) @@ -386,14 +385,14 @@ class TemplateLine(NamedTuple): """ - items: List[Tuple[str, str]] = [] + items: list[tuple[str, str]] = [] wrap: int = 80 def tree_to_text( procs: pd.DataFrame, - schema: Optional[Union[ProcSchema, Dict[str, str]]] = None, - template: Optional[List[TemplateLine]] = None, + schema: ProcSchema | dict[str, str] | None = None, + template: list[TemplateLine] | None = None, sort_column: str = "path", wrap_column: int = 0, ) -> str: @@ -427,11 +426,9 @@ def tree_to_text( """ if not schema and not template: - raise ValueError( - "One of 'schema' and 'template' must be supplied", "as parameters." - ) + raise ValueError("One of 'schema' and 'template' must be supplied", "as parameters.") template = template or _create_proctree_template(schema) # type: ignore - output: List[str] = [] + output: list[str] = [] for _, row in procs.sort_values(sort_column).iterrows(): depth_count = Counter(row.path).get("/", 0) header = _node_header(depth_count) @@ -439,8 +436,7 @@ def tree_to_text( # handle first row separately since it needs a header tmplt_line = template[0] out_line = " ".join( - f"{name}: {row[col]}" if name else f"{row[col]}" - for name, col in tmplt_line.items + f"{name}: {row[col]}" if name else f"{row[col]}" for name, col in tmplt_line.items ) indent = " " * len(header) + " " out_line = "\n".join( @@ -454,9 +450,7 @@ def tree_to_text( # process subsequent rows for tmplt_line in template[1:]: - out_line = " ".join( - f"{name}: {row[col]}" for name, col in tmplt_line.items - ) + out_line = " ".join(f"{name}: {row[col]}" for name, col in tmplt_line.items) out_line = "\n".join( textwrap.wrap( out_line, @@ -471,15 +465,13 @@ def tree_to_text( def _create_proctree_template( - schema: Union[ProcSchema, Dict[str, str]] -) -> List[TemplateLine]: + schema: ProcSchema | dict[str, str], +) -> list[TemplateLine]: """Create a template from the schema.""" if isinstance(schema, dict): schema = ProcSchema(**schema) - template_lines: List[TemplateLine] = [ - TemplateLine( - items=[("Process", schema.process_name), ("PID", schema.process_id)] - ), + template_lines: list[TemplateLine] = [ + TemplateLine(items=[("Process", schema.process_name), ("PID", schema.process_id)]), TemplateLine(items=[("Time", schema.time_stamp)]), ] if schema.cmd_line: diff --git a/msticpy/vis/__init__.py b/msticpy/vis/__init__.py index aa6f155b8..16c566a1f 100644 --- a/msticpy/vis/__init__.py +++ b/msticpy/vis/__init__.py @@ -22,6 +22,7 @@ - timeseries - timeseries analysis visualization """ + # flake8: noqa: F403 # pylint: disable=unused-import from . 
import mp_pandas_plot diff --git a/msticpy/vis/code_view.py b/msticpy/vis/code_view.py index 5010a0560..291657eb6 100644 --- a/msticpy/vis/code_view.py +++ b/msticpy/vis/code_view.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- """Display code with highlighting.""" -from typing import List, Optional from IPython.display import HTML, DisplayHandle, display from pygments import formatters, highlight, lexers, styles @@ -52,7 +51,7 @@ def to_html(code: str, language: str, style: str = "default", full: bool = True) # pylint: enable=no-member -def list_pygments_styles() -> List[str]: +def list_pygments_styles() -> list[str]: """ Return list of pygments styles available. @@ -71,7 +70,7 @@ def display_html( style: str = "stata-dark", full: bool = True, display_handle: bool = False, -) -> Optional[DisplayHandle]: +) -> DisplayHandle | None: """ Display pygments-formatted code. diff --git a/msticpy/vis/data_viewer.py b/msticpy/vis/data_viewer.py index 1eee5c0d8..f028858b2 100644 --- a/msticpy/vis/data_viewer.py +++ b/msticpy/vis/data_viewer.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Dataframe viewer.""" + from collections import namedtuple -from typing import Dict, List, Union import ipywidgets as widgets import pandas as pd @@ -45,9 +45,7 @@ class DataViewerBokeh: _DEF_HEIGHT = 550 - def __init__( - self, data: pd.DataFrame, selected_cols: List[str] = None, debug=False - ): + def __init__(self, data: pd.DataFrame, selected_cols: list[str] = None, debug=False): """ Initialize the DataViewer class. @@ -104,16 +102,14 @@ def __init__( @property def filtered_data(self) -> pd.DataFrame: """Return filtered dataframe.""" - return self.data_filter.filtered_dataframe[ # type: ignore - self.column_chooser.selected_columns - ] # type: ignore + return self.data_filter.filtered_dataframe[self.column_chooser.selected_columns] @property - def filters(self) -> Dict[str, FilterExpr]: + def filters(self) -> dict[str, FilterExpr]: """Return current filters as a dict.""" return self.data_filter.filters - def import_filters(self, filters: Dict[str, FilterExpr]): + def import_filters(self, filters: dict[str, FilterExpr]): """ Import filter set replacing current filters. 
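import_filters (whose body appears in a later hunk) rebuilds FilterExpr namedtuples from plain tuples, so exported filters round-trip through ordinary dicts. A hypothetical round-trip, with the field order taken from the import_filters docstring (column, inv, operator, expr):

from collections import namedtuple

FilterExpr = namedtuple("FilterExpr", ["column", "inv", "operator", "expr"])

saved = {"high_sev": ("Severity", False, "in", "high, warning")}
restored = {name: FilterExpr(*expr) for name, expr in saved.items()}
assert restored["high_sev"].operator == "in"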
@@ -165,12 +161,8 @@ def _apply_filter(self, btn): del btn if self._debug: print("_apply_filter") - self.data_table.view = CDSView( - filter=BooleanFilter(self.data_filter.bool_filters) - ) - self.data_table.height = self._calc_df_height( - self.data_filter.filtered_dataframe - ) + self.data_table.view = CDSView(filter=BooleanFilter(self.data_filter.bool_filters)) + self.data_table.height = self._calc_df_height(self.data_filter.filtered_dataframe) self._update_data_table() @@ -275,13 +267,11 @@ def __init__(self, data: pd.DataFrame): self._not_cb = widgets.Checkbox( description="not", value=False, **(_layout("60px", desc_width="initial")) ) - self._filter_value = widgets.Textarea( - description="Filter value", **(_layout("400px")) - ) + self._filter_value = widgets.Textarea(description="Filter value", **(_layout("400px"))) self._curr_filters = widgets.Select(description="Filters", **(_layout("500px"))) self._oper_label = widgets.Label(" in ") - self.filters: Dict[str, FilterExpr] = {} + self.filters: dict[str, FilterExpr] = {} self._curr_filters.observe(self._select_filter, names="value") self._col_select.observe(self._update_operators, names="value") @@ -336,7 +326,7 @@ def _ipython_display_(self): """Display in IPython.""" self.display() - def import_filters(self, filters: Dict[str, FilterExpr]): + def import_filters(self, filters: dict[str, FilterExpr]): """ Replace the current filters with `filters`. @@ -348,9 +338,7 @@ def import_filters(self, filters: Dict[str, FilterExpr]): column [str], inv [bool], operator [str], expr [str] """ - self.filters = { - f_name: FilterExpr(*f_expr) for f_name, f_expr in filters.items() - } + self.filters = {f_name: FilterExpr(*f_expr) for f_name, f_expr in filters.items()} self._curr_filters.options = list(filters.keys()) @property @@ -358,9 +346,7 @@ def bool_filters(self): """Return current set of boolean filters.""" df_filt = None for filt in self.filters.values(): - new_filt = self._make_filter( - filt.column, filt.operator, filt.expr, filt.inv - ) + new_filt = self._make_filter(filt.column, filt.operator, filt.expr, filt.inv) new_filt = new_filt.values if isinstance(new_filt, pd.Series) else new_filt df_filt = new_filt if df_filt is None else df_filt & new_filt return df_filt if df_filt is not None else self.data.index.isin(self.data.index) @@ -442,7 +428,9 @@ def _make_filter(self, col, operator, expr, not_true): return self._create_filter(col, operator, expr) # pylint: disable=too-many-return-statements - def _create_filter(self, col: str, operator: str, expr: str) -> pd.Series: + def _create_filter( # noqa: PLR0911 + self, col: str, operator: str, expr: str + ) -> pd.Series: if operator == "query": return pd.Series(self.data.index.isin(self.data.query(expr).index)) if operator in ("in", "between"): @@ -463,23 +451,20 @@ def _create_filter(self, col: str, operator: str, expr: str) -> pd.Series: return self.data[col] < test_expr if operator == "<=": return self.data[col] >= test_expr - raise TypeError( - f"Unsupported operator for operator {operator} and column {col}" - ) + raise TypeError(f"Unsupported operator for operator {operator} and column {col}") def _filter_in_or_between(self, col: str, operator: str, expr: str) -> pd.Series: """Return filter for `in` and `between` operators.""" - test_expr: List[Union[str, int, float]] + test_expr: list[str | int | float] if pd.api.types.is_string_dtype(self.data[col]): test_expr = [item.strip("\"' ") for item in expr.split(",")] elif pd.api.types.is_numeric_dtype(self.data[col]): test_expr = [ - 
int(item) if "." not in item else float(item) - for item in expr.split(",") + int(item) if "." not in item else float(item) for item in expr.split(",") ] elif pd.api.types.is_datetime64_any_dtype(self.data[col]): - test_expr = [pd.Timestamp(item.strip()) for item in expr.split(",")] # type: ignore + test_expr = [pd.Timestamp(item.strip()) for item in expr.split(",")] else: raise TypeError( f"Unsupported column type {self.data[col].dtype}", @@ -496,11 +481,11 @@ def _filter_in_or_between(self, col: str, operator: str, expr: str) -> pd.Series def _conv_expr_type(self, col: str, expr: str): """Convert string expression to required type.""" - test_expr: Union[str, int, float] + test_expr: str | int | float if pd.api.types.is_numeric_dtype(self.data[col]): test_expr = int(expr) if "." not in expr else float(expr) elif pd.api.types.is_datetime64_any_dtype(self.data[col]): - test_expr = pd.Timestamp(expr.strip()) # type: ignore + test_expr = pd.Timestamp(expr.strip()) elif pd.api.types.is_string_dtype(self.data[col]): test_expr = expr.strip("\"' ") else: diff --git a/msticpy/vis/data_viewer_panel.py b/msticpy/vis/data_viewer_panel.py index b554f7add..3c7179843 100644 --- a/msticpy/vis/data_viewer_panel.py +++ b/msticpy/vis/data_viewer_panel.py @@ -4,10 +4,12 @@ # license information. # -------------------------------------------------------------------------- """Dataframe viewer using Panel Tabulator.""" + +from collections.abc import Callable, Iterable from functools import partial from pprint import pformat from textwrap import wrap -from typing import Any, Callable, Dict, Iterable, List, Optional +from typing import Any import pandas as pd from IPython import get_ipython @@ -32,7 +34,7 @@ class DataViewer: _DEF_HEIGHT = 550 _DEFAULT_HIDDEN_COLS = ["TenantId"] - def __init__(self, data: pd.DataFrame, selected_cols: List[str] = None, **kwargs): + def __init__(self, data: pd.DataFrame, selected_cols: list[str] = None, **kwargs): """ Initialize the DataViewer class. @@ -97,10 +99,7 @@ def __init__(self, data: pd.DataFrame, selected_cols: List[str] = None, **kwargs hidden_cols = kwargs.pop("hidden_cols", None) self._hidden_columns = self._default_hidden_cols(selected_cols, hidden_cols) - if ( - not kwargs.pop("show_tenant_id", False) - and "TenantId" in self._hidden_columns - ): + if not kwargs.pop("show_tenant_id", False) and "TenantId" in self._hidden_columns: self._hidden_columns.remove("TenantId") # Create the tabulator control @@ -119,13 +118,10 @@ def __init__(self, data: pd.DataFrame, selected_cols: List[str] = None, **kwargs # Add the column chooser self.column_chooser = DataTableColumnChooser( data, - selected_cols=selected_cols - or list(set(data.columns) - set(self._hidden_columns)), # type: ignore + selected_cols=selected_cols or list(set(data.columns) - set(self._hidden_columns)), ) self.column_chooser.apply_button.on_click(self._update_columns) - self.accordion = pn.layout.Accordion( - ("Select columns", self.column_chooser.layout) - ) + self.accordion = pn.layout.Accordion(("Select columns", self.column_chooser.layout)) self._update_columns(btn=None) # set layout for the widget. 
self.layout = pn.layout.Column(self.data_table, self.accordion) @@ -154,15 +150,15 @@ def _update_columns(self, btn): self.accordion.active = [] def _create_row_formatter( - self, detail_columns: Optional[List[str]] = None - ) -> Optional[Callable]: + self, detail_columns: list[str] | None = None + ) -> Callable | None: """Build formatter function for row-details.""" if not detail_columns: return None row_view_cols = set(detail_columns) & set(self.data.columns) return partial(_display_column_details, columns=row_view_cols) - def _create_configuration(self, kwargs) -> Dict[str, Any]: + def _create_configuration(self, kwargs) -> dict[str, Any]: """Create Tabulator configuration dict to pass to JS Tabulator.""" return { "columnDefaults": {"maxWidth": kwargs.pop("max_col_width", 500)}, @@ -177,7 +173,7 @@ def _create_configuration(self, kwargs) -> Dict[str, Any]: }, } - def _default_hidden_cols(self, selected_cols, hidden_cols) -> List[str]: + def _default_hidden_cols(self, selected_cols, hidden_cols) -> list[str]: """Return list of of columns hidden by default.""" return [ hidden_col @@ -208,7 +204,7 @@ def __init__(self, data, selected_cols=None): self.layout = pn.layout.Column(self._col_select, self.apply_button) @property - def selected_columns(self) -> List[str]: + def selected_columns(self) -> list[str]: """Return a list of Bokeh column definitions for the DataFrame.""" return self._col_select.value @@ -217,7 +213,7 @@ def dataframe_columns(self): """Return the selected set of DataFrame columns.""" return self.data[self._reorder_cols(self.selected_columns)] - def _reorder_cols(self, columns: List[str]) -> List[str]: + def _reorder_cols(self, columns: list[str]) -> list[str]: """Return column list in original order.""" # order the columns as originally specified (or as the DF) col_init = [col for col in self._initial_cols if col in columns] diff --git a/msticpy/vis/entity_graph_tools.py b/msticpy/vis/entity_graph_tools.py index 895187862..c41e9dc42 100644 --- a/msticpy/vis/entity_graph_tools.py +++ b/msticpy/vis/entity_graph_tools.py @@ -7,7 +7,6 @@ from datetime import datetime, timezone from importlib.metadata import version -from typing import List, Optional, Union import networkx as nx import numpy as np @@ -17,7 +16,7 @@ from bokeh.models import Circle, HoverTool, Label, LayoutDOM # type: ignore from bokeh.plotting import figure, from_networkx from dateutil import parser -from packaging.version import Version, parse +from packaging.version import Version, parse # pylint: disable=no-name-in-module from .._version import VERSION from ..common.exceptions import MsticpyUserError @@ -51,7 +50,7 @@ class EntityGraph: def __init__( self, - entity: Union[Incident, Alert, pd.DataFrame, pd.Series, Entity, SecurityAlert], + entity: Incident | Alert | pd.DataFrame | pd.Series | Entity | SecurityAlert, ): """ Create a new instance of the entity graph. 
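The next hunk swaps isinstance(entity, (Incident, Alert)) for the PEP 604 form, which Python 3.10 accepts directly as the classinfo argument. A self-contained check with stand-in classes (not the msticpy entities):

class Incident: ...
class Alert: ...

ent = Alert()
assert isinstance(ent, Incident | Alert) == isinstance(ent, (Incident, Alert))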
@@ -65,7 +64,7 @@ def __init__( """ output_notebook() self.alertentity_graph = nx.Graph(id="IncidentGraph") - if isinstance(entity, (Incident, Alert)): + if isinstance(entity, Incident | Alert): self._add_incident_or_alert_node(entity) elif isinstance(entity, pd.DataFrame): self.add_incident(entity) @@ -74,7 +73,7 @@ def __init__( elif isinstance(entity, Entity): self._add_entity_node(entity) elif isinstance(entity, SecurityAlert): - entity = Alert(entity) # type: ignore + entity = Alert(entity) self._add_incident_or_alert_node(entity) def plot(self, hide: bool = False, timeline: bool = False, **kwargs) -> LayoutDOM: @@ -198,7 +197,7 @@ def add_entity(self, ent: Entity, attached_to: str = None): """ self._add_entity_node(ent, attached_to) - def add_incident(self, incident: Union[Incident, Alert, pd.DataFrame]): + def add_incident(self, incident: Incident | Alert | pd.DataFrame): """ Add another incident or set of incidents to the graph. @@ -212,7 +211,7 @@ def add_incident(self, incident: Union[Incident, Alert, pd.DataFrame]): if isinstance(incident, pd.DataFrame): for row in incident.iterrows(): if "name" in row[1]: - inc = Incident(src_event=row[1]) # type: ignore + inc = Incident(src_event=row[1]) elif "AlertName" in row[1]: inc = Alert(src_event=row[1]) # type: ignore self._add_incident_or_alert_node(inc) @@ -222,8 +221,8 @@ def add_incident(self, incident: Union[Incident, Alert, pd.DataFrame]): def add_note( self, name: str, - description: Optional[str] = None, - attached_to: Union[str, List] = None, + description: str | None = None, + attached_to: str | list | None = None, ): """ Add a node to the graph representing a note or comment. @@ -279,9 +278,7 @@ def add_link(self, source: str, target: str): self.alertentity_graph.add_edge(source, target) else: missing = [ - name - for name in [source, target] - if name not in self.alertentity_graph.nodes() + name for name in [source, target] if name not in self.alertentity_graph.nodes() ] raise MsticpyUserError(title=f"Node(s) {missing} not found in graph") @@ -309,9 +306,7 @@ def remove_link(self, source: str, target: str): ): self.alertentity_graph.remove_edge(source, target) else: - raise MsticpyUserError( - title=f"No edge exists between {source} and {target}" - ) + raise MsticpyUserError(title=f"No edge exists between {source} and {target}") def remove_node(self, name: str): """ @@ -346,7 +341,7 @@ def to_df(self) -> pd.DataFrame: ] return pd.DataFrame(node_list).replace("None", np.nan) - def _add_incident_or_alert_node(self, incident: Union[Incident, Alert, None]): + def _add_incident_or_alert_node(self, incident: Incident | Alert | None): """Check what type of entity is passed in and creates relevant graph.""" if isinstance(incident, Incident): self._add_incident_node(incident) @@ -370,9 +365,7 @@ def _add_alert_node(self, alert, incident_name=None): def _add_incident_node(self, incident): """Add an incident entity to the graph.""" - self.alertentity_graph = nx.compose( - self.alertentity_graph, incident.to_networkx() - ) + self.alertentity_graph = nx.compose(self.alertentity_graph, incident.to_networkx()) if incident.Alerts: for alert in incident.Alerts: self._add_alert_node(alert, incident.name_str) @@ -398,7 +391,7 @@ def graph(self) -> nx.Graph: return self.alertentity_graph -def _convert_to_tz_aware_ts(date_string: Optional[str]) -> Optional[datetime]: +def _convert_to_tz_aware_ts(date_string: str | None) -> datetime | None: """Convert a date string to a timezone aware datetime object.""" if date_string is None: return None @@ 
-424,7 +417,7 @@ def _dedupe_entities(alerts, ents) -> list: def plot_entitygraph( # pylint: disable=too-many-locals entity_graph: nx.Graph, node_size: int = 25, - font_size: Union[int, str] = 10, + font_size: int | str = 10, height: int = 800, width: int = 800, scale: int = 2, @@ -510,9 +503,7 @@ def plot_entitygraph( # pylint: disable=too-many-locals nx.set_node_attributes(entity_graph_for_plotting, node_attributes) for source_node, target_node in entity_graph.edges: - entity_graph_for_plotting.add_edge( - rev_index[source_node], rev_index[target_node] - ) + entity_graph_for_plotting.add_edge(rev_index[source_node], rev_index[target_node]) graph_renderer = from_networkx( entity_graph_for_plotting, nx.spring_layout, scale=scale, center=(0, 0) diff --git a/msticpy/vis/figure_dimension.py b/msticpy/vis/figure_dimension.py index 33a4ea65f..7ee7f9202 100644 --- a/msticpy/vis/figure_dimension.py +++ b/msticpy/vis/figure_dimension.py @@ -4,8 +4,10 @@ # license information. # -------------------------------------------------------------------------- """figure_dimension - helps set the width and height properties of a figure for plotting.""" + +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any from bokeh.plotting import figure @@ -32,13 +34,13 @@ def set_figure_size(fig: figure, width: int, height: int) -> figure: """ if hasattr(figure(), "height"): - setattr(fig, "height", height) + fig.height = height if hasattr(figure(), "plot_height"): - setattr(fig, "plot_height", height) + fig.plot_height = height # type: ignore[attr-defined] if hasattr(figure(), "width"): - setattr(fig, "width", width) + fig.width = width # type: ignore[attr-defined] if hasattr(figure(), "plot_width"): - setattr(fig, "plot_width", width) + fig.plot_width = width # type: ignore[attr-defined] return fig @@ -71,9 +73,7 @@ def set_figure_size_params(*args, **kwargs): # pylint: disable=comparison-with-callable if func == figure: param_mapper = ( - _BOKEH_3_FIG_PARAMS - if hasattr(func(), "height") - else _BOKEH_2_FIG_PARAMS + _BOKEH_3_FIG_PARAMS if hasattr(func(), "height") else _BOKEH_2_FIG_PARAMS ) func_kwargs = { diff --git a/msticpy/vis/foliummap.py b/msticpy/vis/foliummap.py index 3924dc3f2..2605dcd29 100644 --- a/msticpy/vis/foliummap.py +++ b/msticpy/vis/foliummap.py @@ -5,6 +5,7 @@ # license information. 
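# A condensed sketch of the compatibility idea in figure_dimension.py
# above: Bokeh 3 renamed plot_width/plot_height to width/height, so the
# attribute to assign is probed at runtime (the real module maps whole
# kwarg dicts via _BOKEH_2_FIG_PARAMS/_BOKEH_3_FIG_PARAMS).
from bokeh.plotting import figure

fig = figure()
width_attr = "width" if hasattr(fig, "width") else "plot_width"
setattr(fig, width_attr, 900)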
# -------------------------------------------------------------------------- """Folium map class.""" + from __future__ import annotations import contextlib @@ -12,7 +13,8 @@ import math import statistics as stats import sys -from typing import Any, Callable, Generator, Iterable +from collections.abc import Callable, Generator, Iterable +from typing import Any import folium import pandas as pd @@ -126,8 +128,8 @@ def add_ip_cluster( continue if ( not ( - isinstance(ip_entity.Location.Latitude, (int, float)) - and isinstance(ip_entity.Location.Longitude, (int, float)) + isinstance(ip_entity.Location.Latitude, int | float) + and isinstance(ip_entity.Location.Longitude, int | float) ) or math.isnan(ip_entity.Location.Latitude) or math.isnan(ip_entity.Location.Longitude) @@ -150,9 +152,7 @@ def add_ip_cluster( else: marker_target_map: folium.Map = self.folium_map marker.add_to(marker_target_map) - self.locations.append( - (ip_entity.Location.Latitude, ip_entity.Location.Longitude) - ) + self.locations.append((ip_entity.Location.Latitude, ip_entity.Location.Longitude)) def add_ips( self: Self, @@ -177,9 +177,7 @@ def add_ips( _, ip_entities = _GEO_LITE.lookup_ip(ip_addr_list=ip_addresses) self.add_ip_cluster(ip_entities=ip_entities, **kwargs) - def add_geoloc_cluster( - self: Self, geo_locations: Iterable[GeoLocation], **kwargs - ) -> None: + def add_geoloc_cluster(self: Self, geo_locations: Iterable[GeoLocation], **kwargs) -> None: """ Add a collection of GeoLocation objects to the map. @@ -200,9 +198,7 @@ def add_geoloc_cluster( ] self.add_ip_cluster(ip_entities=ip_entities, **kwargs) - def add_locations( - self: Self, locations: Iterable[tuple[float, float]], **kwargs - ) -> None: + def add_locations(self: Self, locations: Iterable[tuple[float, float]], **kwargs) -> None: """ Add a collection of lat/long tuples to the map. @@ -614,11 +610,7 @@ def _get_popup_text(ip_entity: IpAddress) -> str: str(line) for line in [ ip_entity.Address, - *( - list( - ip_entity.Location.properties.values() if ip_entity.Location else [] - ) - ), + *(list(ip_entity.Location.properties.values() if ip_entity.Location else [])), *(list(ip_entity.AdditionalData.items())), ] ) @@ -634,9 +626,7 @@ def _get_popup_text(ip_entity: IpAddress) -> str: ) else: - from typing import Dict, Union - - IconMapper = Union[Callable[[str], Dict[str, Any]], Dict[str, Any], None] + IconMapper = Callable[[str], dict[str, Any]] | dict[str, Any] | None # pylint: disable=too-many-locals, too-many-arguments @@ -1010,7 +1000,7 @@ def get_map_center(entities: Iterable[Entity], mode: str = "modal"): loc_props: list[str] = [ p_name for p_name, p_val in entities[0].properties.items() - if isinstance(p_val, (IpAddress, GeoLocation)) + if isinstance(p_val, IpAddress | GeoLocation) ] for entity, prop in itertools.product(entities, loc_props): if prop not in entity: @@ -1068,9 +1058,7 @@ def _extract_coords_loc_entities( ) -> list[tuple[float, float]]: """Return list of coordinate tuples from GeoLocation entities.""" return [ - (loc.Latitude, loc.Longitude) - for loc in loc_entities - if loc.Latitude and loc.Longitude + (loc.Latitude, loc.Longitude) for loc in loc_entities if loc.Latitude and loc.Longitude ] diff --git a/msticpy/vis/matrix_plot.py b/msticpy/vis/matrix_plot.py index 2e4db77c4..b5d432a78 100644 --- a/msticpy/vis/matrix_plot.py +++ b/msticpy/vis/matrix_plot.py @@ -4,8 +4,8 @@ # license information. 
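# matrix_plot's PlotParams below is an attrs class, so its field_list()
# is just reflection over attr.fields_dict. A toy equivalent with made-up
# field names, not msticpy code:
import attr

@attr.s(auto_attribs=True)
class Params:
    title: str | None = "Interaction Plot"
    height: int = 700

assert list(attr.fields_dict(Params)) == ["title", "height"]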
# -------------------------------------------------------------------------- """Bokeh matrix plot.""" + import math -from typing import List, Optional, Union import attr import numpy as np @@ -32,38 +32,38 @@ class PlotParams: """Plot params for time_duration.""" - title: Optional[str] = "Interaction Plot" - x: Optional[str] = None - x_col: Optional[str] = None - y: Optional[str] = None - y_col: Optional[str] = None + title: str | None = "Interaction Plot" + x: str | None = None + x_col: str | None = None + y: str | None = None + y_col: str | None = None intersect: bool = False height: int = 700 width: int = 900 color: str = "red" - value_col: Optional[str] = None + value_col: str | None = None dist_count: bool = False log_size: bool = False invert: bool = False - sort: Optional[Union[str, bool]] = None - sort_x: Optional[Union[str, bool]] = None - sort_y: Optional[Union[str, bool]] = None + sort: str | bool | None = None + sort_x: str | bool | None = None + sort_y: str | bool | None = None hide: bool = False - font_size: Optional[int] = None + font_size: int | None = None max_label_font_size: int = 11 @property - def x_column(self) -> Optional[str]: + def x_column(self) -> str | None: """Return the current x column value.""" return self.x or self.x_col @property - def y_column(self) -> Optional[str]: + def y_column(self) -> str | None: """Return the current y column value.""" return self.y or self.y_col @classmethod - def field_list(cls) -> List[str]: + def field_list(cls) -> list[str]: """Return field names as a list.""" return list(attr.fields_dict(cls).keys()) @@ -163,9 +163,7 @@ def plot_matrix(data: pd.DataFrame, **kwargs) -> LayoutDOM: plot_data = _prep_data(data, param) x_range = _sort_labels(plot_data, param.x_column, param.sort_x or param.sort) - y_range = _sort_labels( - plot_data, param.y_column, param.sort_y or param.sort, invert=True - ) + y_range = _sort_labels(plot_data, param.y_column, param.sort_y or param.sort, invert=True) # Rescale the size so that it matches the graph max_size = plot_data["size"].max() @@ -278,14 +276,14 @@ def _size_scale(value_series, log_size, invert): if param.value_col is None: # calculate a count of rows in each group - other_cols = list(set(data.columns) - set([param.x_column, param.y_column])) + other_cols = list(set(data.columns) - {param.x_column, param.y_column}) if other_cols: count_col = other_cols[0] else: count_col = data.index.name or "index" data = data.reset_index() count_rows_df = ( - data[[param.x_column, param.y_column, count_col]] # type: ignore + data[[param.x_column, param.y_column, count_col]] .groupby([param.x_column, param.y_column]) .count() .rename(columns={count_col: "row_count"}) @@ -299,20 +297,18 @@ def _size_scale(value_series, log_size, invert): if param.dist_count: # If distinct count of values required, get nunique tmp_df = ( - data[[param.x_column, param.y_column, param.value_col]] # type: ignore + data[[param.x_column, param.y_column, param.value_col]] .groupby([param.x_column, param.y_column]) .nunique() .reset_index() ) else: tmp_df = ( - data[[param.x_column, param.y_column, param.value_col]] # type: ignore + data[[param.x_column, param.y_column, param.value_col]] .groupby([param.x_column, param.y_column]) .sum() .reset_index() ) return tmp_df.assign( - size=lambda x: _size_scale( - tmp_df[param.value_col], param.log_size, param.invert - ) + size=lambda x: _size_scale(tmp_df[param.value_col], param.log_size, param.invert) ) diff --git a/msticpy/vis/mordor_browser.py b/msticpy/vis/mordor_browser.py index 
62520b5cb..6c9b4afe4 100644 --- a/msticpy/vis/mordor_browser.py +++ b/msticpy/vis/mordor_browser.py @@ -4,8 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Mordor dataset browser.""" + +from collections.abc import Iterable from pprint import pformat -from typing import Any, Dict, Iterable, Optional +from typing import Any import ipywidgets as widgets import pandas as pd @@ -28,7 +30,7 @@ class MordorBrowser: """Mordor browser widget.""" - def __init__(self, save_folder: Optional[str] = None, use_cached: bool = True): + def __init__(self, save_folder: str | None = None, use_cached: bool = True): """ Initialize MordorBrowser control. @@ -57,7 +59,7 @@ def __init__(self, save_folder: Optional[str] = None, use_cached: bool = True): "font_family": "arial, sans-serif", } - self.widgets: Dict[str, Any] = {} + self.widgets: dict[str, Any] = {} self._init_field_ctls() self._init_select_dataset() self._init_filter_ctrls() @@ -74,8 +76,8 @@ def __init__(self, save_folder: Optional[str] = None, use_cached: bool = True): list(self.fields.values()), layout=self.layouts["box_layout"] ) - self.datasets: Dict[str, pd.DataFrame] = {} - self.current_dataset: pd.DataFrame = None # type: ignore + self.datasets: dict[str, pd.DataFrame] = {} + self.current_dataset: pd.DataFrame = None display(widgets.VBox([browse_ctrls, fields_ctrls])) self._df_disp = display(HTML("

"), display_id=True) @@ -113,9 +115,7 @@ def _init_filter_ctrls(self): ) self.widgets["filter_text"].continuous_update = False self.widgets["filter_text"].observe(self._update_select_list, "value") - self.widgets["filter_help"] = widgets.Label( - value=" comma ORs values, '+' ANDs values" - ) + self.widgets["filter_help"] = widgets.Label(value=" comma ORs values, '+' ANDs values") # Mitre filters self.widgets["sel_techniques"] = widgets.SelectMultiple( @@ -143,9 +143,7 @@ def _init_filter_ctrls(self): self.widgets["filter_reset"].on_click(self._reset_filters) wgt_filter_grp = widgets.VBox( [ - widgets.HBox( - [self.widgets["filter_text"], self.widgets["filter_help"]] - ), + widgets.HBox([self.widgets["filter_text"], self.widgets["filter_help"]]), widgets.HBox( [ self.widgets["sel_techniques"], @@ -215,7 +213,7 @@ def _clear_fields(self): self.fields[field].value = "" self._clear_df_display() - def _select_ds_item(self, change): # noqa: MC0001 + def _select_ds_item(self, change): """Handle change of dataset selection.""" item_id = change.get("new") mdr_item = self.mdr_metadata.get(item_id) @@ -308,7 +306,7 @@ def _download_file(self, event): self._df_disp.update(self.datasets[selection]) @staticmethod - def _get_mitre_filter_options(mordor_index: Dict[str, MordorEntry], mitre_data): + def _get_mitre_filter_options(mordor_index: dict[str, MordorEntry], mitre_data): return [ (f"{m_id} - {mitre_data.loc[m_id].Name}", m_id) for m_id in mordor_index diff --git a/msticpy/vis/morph_charts.py b/msticpy/vis/morph_charts.py deleted file mode 100644 index d1a205072..000000000 --- a/msticpy/vis/morph_charts.py +++ /dev/null @@ -1,162 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""Morph Charts class.""" -import json -from pathlib import Path - -import pandas as pd -import yaml -from deprecated.sphinx import deprecated -from IPython.display import IFrame - -from .._version import VERSION -from ..common.exceptions import MsticpyException - -__version__ = VERSION -__author__ = "Pete Bryan" - -_CHART_FOLDER = "morph_charts" - - -class MorphCharts: - """Create Morph Charts package data and render Morph Charts site.""" - - @deprecated("Morphchart functionality has been deprecated.", version="2.2.0") - def __init__(self): - """Create object and populate charts container.""" - self.charts = _get_charts(_CHART_FOLDER) - - def display(self, data: pd.DataFrame, chart_name: str) -> IFrame: - """ - Prepare package data and display MorphChart in an IFrame. - - Parameters - ---------- - data: pd.DataFrame: - A DataFrame of data for the morphchart to plot. - - chart_name: str: - The name of the Morph Chart to plot. - - """ - # Check input data is correct format and that the chart being requested exists - if not isinstance(data, pd.DataFrame): - raise MsticpyException("Data provided must be in pandas.DataFrame format") - - if chart_name not in self.charts: - raise MsticpyException( - f"{chart_name} is not a vaid chart. 
Run list_charts() to see avaliable charts" # pylint: disable=line-too-long - ) - - # Create description file with length of our data set - description_dict = self.charts[chart_name]["DescriptionFile"] - description_dict["tables"][0]["rows"] = len(data) - # Create output folder for package files - out_path = Path.cwd().joinpath(*["morphchart_package", "description.json"]) - Path.mkdir(Path.cwd().joinpath("morphchart_package"), exist_ok=True) - # Write description file - with open(out_path, "w", encoding="utf-8") as morph_file: - json.dump(description_dict, morph_file) - # Write dataset to query_data csv - data_out_path = out_path = Path.cwd().joinpath( - *["morphchart_package", "query_data.csv"] - ) - data.to_csv(data_out_path, index=False) - # Display Morph Charts in IFrame with instructions - print( - f"Navigate to {Path.cwd().joinpath('morphchart_package')} and upload the files below" - ) - print("Charts provided by http://morphcharts.com/") - return IFrame("http://morphcharts.com/designer.html", "100%", "600px") - - def list_charts(self): - """Get a list of avaliable charts.""" - for key, _ in self.charts.items(): - print(key) - - def get_chart_details(self, chart_name): - """ - Get description for a chart. - - Parameters - ---------- - chart_name: str: - The name of the chart you get description for. - - """ - try: - print( - chart_name, - ":", - "\n", - self.charts[chart_name]["Description"], - "\n", - "Query: ", - self.charts[chart_name]["Query"], - ) - except KeyError as key_err: - raise KeyError(f"Unknown chart {chart_name}") from key_err - - def search_charts(self, keyword): - """ - Search for charts that match a keyword. - - Parameters - ---------- - keyword: str: - The keyword to search charts for. - - """ - for key, value in self.charts.items(): - if keyword.casefold() in [tag.casefold() for tag in value["Tags"]]: - print(key, ":", "\n", value["Description"]) - elif keyword.casefold() in [ - word.casefold() for word in value["Description"].split() - ]: - print(key, ":", "\n", value["Description"]) - else: - print("No matching charts found") - - -def _get_charts(path: str = "morph_charts") -> dict: - """ - Return dictionary of yaml files found in the Morph Charts folder. - - Parameters - ---------- - path : str - The source path to search in. - - Returns - ------- - Dict - Details of the chart files - - """ - full_path = Path(__file__).parent.parent.joinpath("resources").joinpath(path) - file_glob = Path(full_path).glob("*.yaml") - chart_files = [file_path for file_path in file_glob if file_path.is_file()] - chart_details = {} - for chart in chart_files: - with open(chart, "r", encoding="utf-8") as chart_data: - details = yaml.safe_load(chart_data) - try: - chart_details.update( - { - details["Name"]: { - "Description": details["Description"], - "Query": details["Query"], - "Tags": details["Tags"], - "DescriptionFile": details["DescriptionFile"], - } - } - ) - except KeyError as key_err: - raise LookupError( - f"{chart} description does not appear to be in the correct format." - ) from key_err - - return chart_details diff --git a/msticpy/vis/mp_pandas_plot.py b/msticpy/vis/mp_pandas_plot.py index f3586744a..60fc7f3c2 100644 --- a/msticpy/vis/mp_pandas_plot.py +++ b/msticpy/vis/mp_pandas_plot.py @@ -4,7 +4,8 @@ # license information. 
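# The methods below (timeline_values, timeline_duration, process_tree,
# network, ...) are surfaced on DataFrames through pandas' accessor hook,
# the same mechanism as the deprecated mp_process_tree accessor removed
# later in this diff. A toy registration to show the hook; the accessor
# name "demo" is made up:
import pandas as pd

@pd.api.extensions.register_dataframe_accessor("demo")
class DemoAccessor:
    def __init__(self, pandas_obj: pd.DataFrame) -> None:
        self._df = pandas_obj

    def row_count(self) -> int:
        return len(self._df)

assert pd.DataFrame({"a": [1, 2]}).demo.row_count() == 2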
# -------------------------------------------------------------------------- """Module docstring.""" -from typing import Iterable, Optional, Tuple, Union + +from collections.abc import Iterable import pandas as pd from bokeh.models import LayoutDOM @@ -178,17 +179,15 @@ def timeline_values(self, value_column: str = None, **kwargs) -> LayoutDOM: The bokeh plot figure. """ - return display_timeline_values( - data=self._df, value_column=value_column, **kwargs - ) + return display_timeline_values(data=self._df, value_column=value_column, **kwargs) def timeline_duration( self, - group_by: Union[Iterable[str], str], + group_by: Iterable[str] | str, time_column: str = "TimeGenerated", - end_time_column: Optional[str] = None, + end_time_column: str | None = None, **kwargs, - ) -> LayoutDOM: # noqa: C901, MC0001 + ) -> LayoutDOM: # noqa: C901 """ Display a duration timeline of events grouped by one or more columns. @@ -248,7 +247,7 @@ def timeline_duration( **kwargs, ) - def process_tree(self, **kwargs) -> Tuple[figure, LayoutDOM]: + def process_tree(self, **kwargs) -> tuple[figure, LayoutDOM]: """ Build and plot a process tree. @@ -541,9 +540,9 @@ def network( source_col: str, target_col: str, title: str = "Data Graph", - source_attrs: Optional[Iterable[str]] = None, - target_attrs: Optional[Iterable[str]] = None, - edge_attrs: Optional[Iterable[str]] = None, + source_attrs: Iterable[str] | None = None, + target_attrs: Iterable[str] | None = None, + edge_attrs: Iterable[str] | None = None, graph_type: GraphType = "graph", **kwargs, ): diff --git a/msticpy/vis/nbdisplay.py b/msticpy/vis/nbdisplay.py index 1fea96af8..a2a63b9ec 100644 --- a/msticpy/vis/nbdisplay.py +++ b/msticpy/vis/nbdisplay.py @@ -4,7 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Module for common display functions.""" -from typing import Any, List, Mapping, Tuple, Union + +from collections.abc import Mapping +from typing import Any import IPython import networkx as nx @@ -28,9 +30,7 @@ @export -def display_alert( - alert: Union[Mapping[str, Any], SecurityAlert], show_entities: bool = False -): +def display_alert(alert: Mapping[str, Any] | SecurityAlert, show_entities: bool = False): """ Display a Security Alert. @@ -52,8 +52,8 @@ def display_alert( @export def format_alert( - alert: Union[Mapping[str, Any], SecurityAlert], show_entities: bool = False -) -> Union[IPython.display.HTML, Tuple[IPython.display.HTML, pd.DataFrame]]: + alert: Mapping[str, Any] | SecurityAlert, show_entities: bool = False +) -> IPython.display.HTML | tuple[IPython.display.HTML, pd.DataFrame]: """ Get IPython displayable Security Alert. @@ -190,7 +190,7 @@ def display_logon_data( @export def format_logon( - logon_event: Union[pd.DataFrame, pd.Series], + logon_event: pd.DataFrame | pd.Series, alert: SecurityAlert = None, os_family: str = None, ) -> IPython.display.HTML: @@ -243,7 +243,7 @@ def format_logon( return HTML(f"{t_style}{''.join(logon_output)}
") -def _fmt_single_row(logon_row: pd.Series, os_family: str) -> List[str]: +def _fmt_single_row(logon_row: pd.Series, os_family: str) -> list[str]: """Format a pandas series logon record.""" logon_record = [ f"Account: {logon_row['TargetUserName']}", @@ -257,8 +257,7 @@ def _fmt_single_row(logon_row: pd.Series, os_family: str) -> List[str]: if logon_type not in _WIN_LOGON_TYPE_MAP: logon_desc_idx = 0 logon_record.append( - f"Logon type: {logon_type}" - + f"({_WIN_LOGON_TYPE_MAP[logon_desc_idx]})" + f"Logon type: {logon_type}" + f"({_WIN_LOGON_TYPE_MAP[logon_desc_idx]})" ) account_id = logon_row.TargetUserSid @@ -282,9 +281,7 @@ def _fmt_single_row(logon_row: pd.Series, os_family: str) -> List[str]: logon_record.append(f"Subject (source) account: {subj_account}") logon_record.append(f"Logon process: {logon_row['LogonProcessName']}") - logon_record.append( - f"Authentication: {logon_row['AuthenticationPackageName']}" - ) + logon_record.append(f"Authentication: {logon_row['AuthenticationPackageName']}") logon_record.append(f"Source IpAddress: {logon_row['IpAddress']}") logon_record.append(f"Source Host: {logon_row['WorkstationName']}") logon_record.append(f"Logon status: {logon_row['Status']}") diff --git a/msticpy/vis/network_plot.py b/msticpy/vis/network_plot.py index 23f301b6c..e9ae3f332 100644 --- a/msticpy/vis/network_plot.py +++ b/msticpy/vis/network_plot.py @@ -5,8 +5,9 @@ # -------------------------------------------------------------------------- """Module for common display functions.""" +from collections.abc import Callable, Iterable from importlib.metadata import version -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Literal import networkx as nx from bokeh.io import output_notebook @@ -24,8 +25,7 @@ ) from bokeh.palettes import Spectral4 from bokeh.plotting import figure, from_networkx, show -from packaging.version import Version, parse -from typing_extensions import Literal +from packaging.version import Version, parse # pylint: disable=no-name-in-module from .._version import VERSION from .figure_dimension import bokeh_figure @@ -42,9 +42,9 @@ figure = bokeh_figure(figure) # type: ignore[assignment, misc] -GraphLayout = Union[ - Callable[[Any], Dict[str, Tuple[float, float]]], - Literal[ +GraphLayout = ( + Callable[[Any], dict[str, tuple[float, float]]] + | Literal[ "spring", "bipartite", "circular", @@ -55,9 +55,9 @@ "spectral", "spiral", "multi_partite", - ], - Dict[str, Tuple[float, float]], -] + ] + | dict[str, tuple[float, float]] +) # pylint: disable=too-many-arguments, too-many-locals @@ -65,14 +65,14 @@ def plot_nx_graph( nx_graph: nx.Graph, title: str = "Data Graph", node_size: int = 25, - font_size: Union[int, str] = 10, + font_size: int | str = 10, height: int = 800, width: int = 800, scale: int = 2, hide: bool = False, - source_attrs: Optional[Iterable[str]] = None, - target_attrs: Optional[Iterable[str]] = None, - edge_attrs: Optional[Iterable[str]] = None, + source_attrs: Iterable[str] | None = None, + target_attrs: Iterable[str] | None = None, + edge_attrs: Iterable[str] | None = None, layout: GraphLayout = "spring", **kwargs, ) -> figure: @@ -137,11 +137,7 @@ def plot_nx_graph( node_attrs = { node: attrs.get( "color", - ( - source_color - if attrs.get("node_role", "source") == "source" - else target_color - ), + (source_color if attrs.get("node_role", "source") == "source" else target_color), ) for node, attrs in nx_graph.nodes(data=True) } @@ -190,9 +186,7 @@ def plot_nx_graph( 
_create_node_hover(source_attrs, target_attrs, [graph_renderer.node_renderer]) ] if edge_attrs: - hover_tools.append( - _create_edge_hover(edge_attrs, [graph_renderer.edge_renderer]) - ) + hover_tools.append(_create_edge_hover(edge_attrs, [graph_renderer.edge_renderer])) plot.add_tools(*hover_tools, WheelZoomTool(), TapTool(), BoxSelectTool()) # Create labels @@ -225,9 +219,9 @@ def _get_graph_layout(nx_graph: nx.Graph, layout: GraphLayout, **kwargs): def _create_node_hover( - source_attrs: Optional[Iterable[str]], - target_attrs: Optional[Iterable[str]], - renderers: List[Renderer], + source_attrs: Iterable[str] | None, + target_attrs: Iterable[str] | None, + renderers: list[Renderer], ) -> HoverTool: """Create a hover tool for nodes.""" node_attr_cols = set((list(source_attrs or [])) + (list(target_attrs or []))) @@ -238,9 +232,7 @@ def _create_node_hover( return HoverTool(tooltips=node_tooltips, renderers=renderers) -def _create_edge_hover( - edge_attrs: Iterable[str], renderers: List[Renderer] -) -> HoverTool: +def _create_edge_hover(edge_attrs: Iterable[str], renderers: list[Renderer]) -> HoverTool: """Create a hover tool for nodes.""" edge_attr_cols = edge_attrs or [] edge_tooltips = [ @@ -255,9 +247,7 @@ def _create_node_renderer(graph_renderer: Renderer, node_size: int, fill_color: circle_size_param = {"radius": node_size // 2} else: circle_size_param = {"size": node_size // 2} - graph_renderer.node_renderer.glyph = Circle( - **circle_size_param, fill_color=fill_color - ) + graph_renderer.node_renderer.glyph = Circle(**circle_size_param, fill_color=fill_color) graph_renderer.node_renderer.hover_glyph = Circle( **circle_size_param, fill_color=Spectral4[1] ) @@ -268,9 +258,7 @@ def _create_node_renderer(graph_renderer: Renderer, node_size: int, fill_color: def _create_edge_renderer(graph_renderer: Renderer, edge_color: str): """Create graph render for edges.""" - graph_renderer.edge_renderer.hover_glyph = MultiLine( - line_color=Spectral4[1], line_width=5 - ) + graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=5) graph_renderer.edge_renderer.glyph = MultiLine( line_alpha=0.8, line_color=edge_color, line_width=1 ) @@ -282,7 +270,7 @@ def _create_edge_renderer(graph_renderer: Renderer, edge_color: str): def plot_entity_graph( entity_graph: nx.Graph, node_size: int = 25, - font_size: Union[int, str] = 10, + font_size: int | str = 10, height: int = 800, width: int = 800, scale: int = 2, @@ -319,8 +307,7 @@ def plot_entity_graph( output_notebook() font_pnt = f"{font_size}pt" if isinstance(font_size, int) else font_size node_attrs = { - node: attrs.get("color", "green") - for node, attrs in entity_graph.nodes(data=True) + node: attrs.get("color", "green") for node, attrs in entity_graph.nodes(data=True) } nx.set_node_attributes(entity_graph, node_attrs, "node_color") @@ -343,9 +330,7 @@ def plot_entity_graph( ) ) - graph_renderer = from_networkx( - entity_graph, nx.spring_layout, scale=scale, center=(0, 0) - ) + graph_renderer = from_networkx(entity_graph, nx.spring_layout, scale=scale, center=(0, 0)) if _BOKEH_VERSION > Version("3.2.0"): circle_size_param = {"radius": node_size // 2} else: diff --git a/msticpy/vis/process_tree.py b/msticpy/vis/process_tree.py index 88dc79190..764257ea8 100644 --- a/msticpy/vis/process_tree.py +++ b/msticpy/vis/process_tree.py @@ -4,7 +4,7 @@ # license information. 
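# network_plot.py above gates Circle sizing on the installed Bokeh
# version (newer Bokeh takes `radius`, older takes `size`). A sketch of
# that probe using the same imports the module declares:
from importlib.metadata import version

from packaging.version import Version, parse

_BOKEH_VERSION: Version = parse(version("bokeh"))
node_size = 25
circle_size_param = (
    {"radius": node_size // 2}
    if _BOKEH_VERSION > Version("3.2.0")
    else {"size": node_size // 2}
)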
# -------------------------------------------------------------------------- """Process Tree Visualization.""" -import warnings + from typing import Any, Dict, NamedTuple, Optional, Tuple, Union import numpy as np @@ -17,7 +17,7 @@ except ImportError: Field = dict # type: ignore from bokeh.layouts import column, row -from bokeh.models import ( # type: ignore[attr-defined] +from bokeh.models import ( BoxSelectTool, ColorBar, ColumnDataSource, @@ -33,7 +33,6 @@ from bokeh.palettes import viridis from bokeh.plotting import figure from bokeh.transform import dodge, factor_cmap, linear_cmap -from deprecated.sphinx import deprecated from .._version import VERSION from ..common.utility import check_kwargs, export @@ -159,7 +158,7 @@ def build_and_show_process_tree( # pylint: disable=too-many-locals, too-many-statements @export -def plot_process_tree( # noqa: MC0001 +def plot_process_tree( data: pd.DataFrame, schema: Union[ProcSchema, Dict[str, Any]] = None, output_var: str = None, @@ -264,7 +263,7 @@ def plot_process_tree( # noqa: MC0001 b_plot.add_tools(hover) # dodge to align rectangle with grid - rect_x = dodge("Level", 1.75, range=b_plot.x_range) # type: ignore + rect_x = dodge("Level", 1.75, range=b_plot.x_range) rect_plot_params = { "width": 3.5, "height": 0.95, @@ -338,11 +337,11 @@ def y_dodge(y_offset): y_col="Row", fill_map=fill_map, ) - plot_elems: LayoutDOM = row(b_plot, range_tool) # type: ignore + plot_elems: LayoutDOM = row(b_plot, range_tool) if show_table: data_table = _create_data_table(source, schema, legend_col) plot_elems = column(plot_elems, data_table) - show(plot_elems) # type: ignore + show(plot_elems) return b_plot, plot_elems @@ -393,13 +392,11 @@ def _pre_process_tree( levels = proc_tree["Level"].unique() proc_tree[schema.process_name] = proc_tree[schema.process_name].fillna("unknown") - proc_tree["__proc_name$$"] = proc_tree.apply( # type: ignore + proc_tree["__proc_name$$"] = proc_tree.apply( lambda x: x[schema.process_name].split(schema.path_separator)[-1], axis=1 ) proc_tree[schema.process_id] = proc_tree[schema.process_id].fillna("unknown") - proc_tree["__proc_id$$"] = proc_tree[schema.process_id].apply( - _pid_fmt, args=(pid_fmt,) - ) + proc_tree["__proc_id$$"] = proc_tree[schema.process_id].apply(_pid_fmt, args=(pid_fmt,)) # Command line processing if not schema.cmd_line: @@ -414,9 +411,9 @@ def _pre_process_tree( proc_tree[long_cmd][schema.cmd_line].str[:max_cmd_len] + "..." 
) # replace missing cmd lines - proc_tree.loc[~long_cmd, "__cmd_line$$"] = proc_tree[~long_cmd][ - schema.cmd_line - ].fillna("cmdline unknown") + proc_tree.loc[~long_cmd, "__cmd_line$$"] = proc_tree[~long_cmd][schema.cmd_line].fillna( + "cmdline unknown" + ) return TreeResult(proc_tree=proc_tree, schema=schema, levels=levels, n_rows=n_rows) @@ -428,14 +425,18 @@ def _pid_fmt(pid, pid_fmt): return ( f"PID: {pid}" if str(pid).startswith("0x") - else f"PID: 0x{int(pid):x}" if isinstance(pid, int) else "NA" + else f"PID: 0x{int(pid):x}" + if isinstance(pid, int) + else "NA" ) if pid_fmt == "guid": return f"GUID: {pid}" return ( f"PID: {pid}" if not str(pid).startswith("0x") - else f"PID: {int(pid, base=16)}" if isinstance(pid, int) else "NA" + else f"PID: {int(pid, base=16)}" + if isinstance(pid, int) + else "NA" ) @@ -518,7 +519,7 @@ def _create_fill_map( key_column, palette=viridis(max(3, len(values))), factors=values ) elif col_kind in ["i", "u", "f", "M"]: - values = [val for val in source.data[key_column] if not np.isnan(val)] # type: ignore + values = [val for val in source.data[key_column] if not np.isnan(val)] fill_map = linear_cmap( field_name=key_column, palette=viridis(256), @@ -528,7 +529,9 @@ def _create_fill_map( if source_column is not None: # If user hasn't specified a legend column - don't create a bar color_bar = ColorBar( - color_mapper=fill_map["transform"], width=8, location=(0, 0) # type: ignore + color_mapper=fill_map["transform"], + width=8, + location=(0, 0), ) return fill_map, color_bar @@ -573,9 +576,7 @@ def _create_vert_range_tool( # pylint: enable=too-many-arguments -def _create_data_table( - source: ColumnDataSource, schema: ProcSchema, legend_col: str = None -): +def _create_data_table(source: ColumnDataSource, schema: ProcSchema, legend_col: str = None): """Return DataTable widget for source.""" column_names = [ schema.user_name, @@ -601,9 +602,7 @@ def _create_data_table( ) ] columns2 = [ - TableColumn(field=col, title=col) - for col in column_names - if col in source.column_names + TableColumn(field=col, title=col) for col in column_names if col in source.column_names ] return DataTable(source=source, columns=columns + columns2, width=950, height=150) @@ -615,103 +614,3 @@ def _check_proc_tree_schema(data): return {Col.proc_key} expected_cols = {Col.parent_key, "IsRoot", "IsLeaf", "IsBranch", "path"} return expected_cols - set(data.columns) - - -# pylint: disable=too-few-public-methods -@deprecated("Will be removed in version 2.0.0", version="1.7.0") -@pd.api.extensions.register_dataframe_accessor("mp_process_tree") -class ProcessTreeAccessor: - """Pandas api extension for Process Tree.""" - - def __init__(self, pandas_obj): - """Instantiate pandas extension class.""" - self._df = pandas_obj - - def plot(self, **kwargs) -> Tuple[figure, LayoutDOM]: - """ - Build and plot a process tree. - - Parameters - ---------- - schema : ProcSchema, optional - The data schema to use for the data set, by default None - (if None the schema is inferred) - output_var : str, optional - Output variable for selected items in the tree, - by default None - legend_col : str, optional - The column used to color the tree items, by default None - show_table: bool - Set to True to show a data table, by default False. 
- - Other Parameters - ---------------- - height : int, optional - The height of the plot figure - (the default is 700) - width : int, optional - The width of the plot figure (the default is 900) - title : str, optional - Title to display (the default is None) - hide_legend : bool, optional - Hide the legend box, even if legend_col is specified. - pid_fmt : str, optional - Display Process ID as 'dec' (decimal), 'hex' (hexadecimal), - or 'guid' (string), default is 'hex'. - - Returns - ------- - Tuple[figure, LayoutDOM]: - figure - The main bokeh.plotting.figure - Layout - Bokeh layout structure. - - """ - warn_message = ( - "This accessor method has been deprecated.\n" - "Please use df.mp_plot.process_tree() method instead." - "This will be removed in MSTICPy v2.2.0" - ) - warnings.warn(warn_message, category=DeprecationWarning) - return build_and_show_process_tree(data=self._df, **kwargs) - - def build(self, schema: ProcSchema = None, **kwargs) -> pd.DataFrame: - """ - Build process trees from the process events. - - Parameters - ---------- - procs : pd.DataFrame - Process events (Windows 4688 or Linux Auditd) - schema : ProcSchema, optional - The column schema to use, by default None - If None, then the schema is inferred - show_summary : bool - Shows summary of the built tree, default is False. : bool - debug : bool - If True produces extra debugging output, - by default False - - Returns - ------- - pd.DataFrame - Process tree dataframe. - - Notes - ----- - It is not necessary to call this before `plot`. The process - tree is built automatically. This is only needed if you want - to return the processed tree data as a DataFrame - - """ - warn_message = ( - "This accessor method has been deprecated.\n" - "Please use df.mp.build_process_tree() method instead." - "This will be removed in MSTICPy v2.2.0" - ) - warnings.warn(warn_message, category=DeprecationWarning) - return build_process_tree( - procs=self._df, - schema=schema, - show_summary=kwargs.get("show_summary", kwargs.get("show_progress", False)), - debug=kwargs.get("debug", False), - ) diff --git a/msticpy/vis/query_browser.py b/msticpy/vis/query_browser.py index 9e0cccbac..ce64f3d81 100644 --- a/msticpy/vis/query_browser.py +++ b/msticpy/vis/query_browser.py @@ -4,8 +4,10 @@ # license information. # -------------------------------------------------------------------------- """QueryProvider Query Browser.""" + import textwrap -from typing import Any, Generator +from collections.abc import Generator +from typing import Any from IPython.display import HTML diff --git a/msticpy/vis/ti_browser.py b/msticpy/vis/ti_browser.py index 697c5089f..a5aadef97 100644 --- a/msticpy/vis/ti_browser.py +++ b/msticpy/vis/ti_browser.py @@ -4,8 +4,8 @@ # license information. 
# -------------------------------------------------------------------------- """Threat Intel Results Browser.""" + import pprint -from typing import List, Union import pandas as pd from IPython.display import HTML @@ -19,7 +19,7 @@ def browse_results( data: pd.DataFrame, - severities: Union[List[str], str, None] = None, + severities: list[str] | str | None = None, *, height: str = "300px", ) -> SelectItem: @@ -55,9 +55,7 @@ def browse_results( return SelectItem(item_dict=opts, action=disp_func, height=height) -def get_ti_select_options( - ti_data: pd.DataFrame, severities: Union[List[str], str, None] = None -): +def get_ti_select_options(ti_data: pd.DataFrame, severities: list[str] | str | None = None): """Get SelectItem options for TI data.""" ti_agg_df = _create_ti_agg_list(ti_data, severities) return dict( @@ -74,9 +72,7 @@ def get_ti_select_options( ) -def _create_ti_agg_list( - ti_data: pd.DataFrame, severities: Union[List[str], str, None] = None -): +def _create_ti_agg_list(ti_data: pd.DataFrame, severities: list[str] | str | None = None): """Aggregate ti results on IoC for multiple providers.""" if not severities: severities = ["warning", "high"] @@ -88,14 +84,10 @@ def _create_ti_agg_list( ti_data[ti_data["Severity"].isin(severities)] .groupby(["Ioc", "IocType", "Severity"]) .agg( - Providers=pd.NamedAgg( - column="Provider", aggfunc=lambda x: x.unique().tolist() - ), + Providers=pd.NamedAgg(column="Provider", aggfunc=lambda x: x.unique().tolist()), Details=pd.NamedAgg(column="Details", aggfunc=lambda x: x.tolist()), Responses=pd.NamedAgg(column="RawResult", aggfunc=lambda x: x.tolist()), - References=pd.NamedAgg( - column="Reference", aggfunc=lambda x: x.unique().tolist() - ), + References=pd.NamedAgg(column="Reference", aggfunc=lambda x: x.unique().tolist()), ) .reset_index() ) @@ -105,9 +97,7 @@ def _label_col_dict(row: pd.Series, column: str): """Add label from the Provider column to the details.""" if not isinstance(row[column], dict): return row[column] - return ( - {row.Provider: row[column]} if row.Provider not in row[column] else row[column] - ) + return {row.Provider: row[column]} if row.Provider not in row[column] else row[column] def ti_details_display(ti_data): @@ -120,9 +110,9 @@ def get_ti_details(ioc_prov): h3_style = "background-color: SteelBlue; padding: 6px" results = [f"
<h3 style='{h3_style}'>{ioc}</h3>"]
         for prov in provs:
-            ioc_match = ti_data[
-                (ti_data["Ioc"] == ioc) & (ti_data["Provider"] == prov)
-            ].iloc[0]
+            ioc_match = ti_data[(ti_data["Ioc"] == ioc) & (ti_data["Provider"] == prov)].iloc[
+                0
+            ]
             results.extend(
                 (
                     f"
Type: '{ioc_match.IocType}', Provider: {prov}, "
@@ -149,9 +139,7 @@ def get_ti_details(ioc_prov):

 def raw_results(raw_result: str) -> str:
     """Create pre-formatted details for raw results."""
-    fmt_details = (
-        pprint.pformat(raw_result).replace("\n", "<br>").replace(" ", "&nbsp;")
-    )
+    fmt_details = pprint.pformat(raw_result).replace("\n", "<br>").replace(" ", "&nbsp;")
     return f"""
Raw results from provider... diff --git a/msticpy/vis/timeline.py b/msticpy/vis/timeline.py index 7412e2463..3a16228dd 100644 --- a/msticpy/vis/timeline.py +++ b/msticpy/vis/timeline.py @@ -4,8 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Timeline base plot.""" + +from collections.abc import Iterable from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any import attr import pandas as pd @@ -35,7 +37,6 @@ # pylint: disable=unused-import # Importing to activate pandas accessors -from .timeline_pd_accessor import TimeLineAccessor # noqa F401 from .timeline_values import display_timeline_values # noqa F401 # pylint: enable=unused-import @@ -55,34 +56,34 @@ class PlotParams: """Plot params for time_duration.""" time_column: str = "TimeGenerated" - height: Optional[int] = None + height: int | None = None width: int = 900 title: str = "Events" yaxis: bool = True range_tool: bool = True - group_by: Optional[str] = None - legend: Optional[str] = None + group_by: str | None = None + legend: str | None = None xgrid: bool = True ygrid: bool = False hide: bool = False color: str = "navy" size: int = 10 ylabel_cols: Iterable[str] = attr.Factory(list) - ref_event: Optional[Any] = None - ref_time: Optional[datetime] = None - ref_events: Optional[pd.DataFrame] = None - ref_col: Optional[str] = None - ref_time_col: Optional[str] = None - ref_times: Optional[List[Tuple[datetime, str]]] = None + ref_event: Any | None = None + ref_time: datetime | None = None + ref_events: pd.DataFrame | None = None + ref_col: str | None = None + ref_time_col: str | None = None + ref_times: list[tuple[datetime, str]] | None = None ref_label: str = "Ref time" - source_columns: List[str] = [] + source_columns: list[str] = [] alert: Any = None - overlay_color: Optional[str] = None - overlay_data: Optional[pd.DataFrame] = None + overlay_color: str | None = None + overlay_data: pd.DataFrame | None = None overlay_columns: Iterable[str] = attr.Factory(list) @classmethod - def field_list(cls) -> List[str]: + def field_list(cls) -> list[str]: """Return field names as a list.""" return list(attr.fields_dict(cls).keys()) @@ -94,9 +95,9 @@ def fmt_title(self): @export def display_timeline( - data: Union[pd.DataFrame, dict], + data: pd.DataFrame | dict, time_column: str = "TimeGenerated", - source_columns: Optional[List[str]] = None, + source_columns: list[str] | None = None, **kwargs, ) -> LayoutDOM: """ @@ -194,9 +195,7 @@ def display_timeline( """ # Get args check_kwargs(kwargs, PlotParams.field_list()) - param = PlotParams( - time_column=time_column, source_columns=source_columns or [], **kwargs - ) + param = PlotParams(time_column=time_column, source_columns=source_columns or [], **kwargs) param.ref_time, param.ref_label = get_ref_event_time(**kwargs) if isinstance(data, pd.DataFrame): @@ -238,9 +237,7 @@ def display_timeline( ) -def _display_timeline_dict( - data: dict, param: PlotParams -) -> LayoutDOM: # noqa: C901, MC0001 +def _display_timeline_dict(data: dict, param: PlotParams) -> LayoutDOM: # noqa: C901 """ Display a timeline of events. 
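# --- Sketch: the annotation rewrite applied throughout these vis modules. ---
# With `target-version = "py310"` (see pyproject.toml below), ruff's pyupgrade
# (UP) rules replace typing.Optional/Union/List/Dict with PEP 604 unions and
# builtin generics, exactly as the PlotParams hunks above show. The `plot`
# function here is hypothetical, included only to illustrate the pattern.

import pandas as pd

# Before (Python 3.8-compatible):
#     from typing import Dict, List, Optional, Union
#     def plot(data: Union[pd.DataFrame, Dict[str, pd.DataFrame]],
#              source_columns: Optional[List[str]] = None) -> None: ...

# After (Python 3.10+):
def plot(
    data: pd.DataFrame | dict[str, pd.DataFrame],
    source_columns: list[str] | None = None,
) -> None:
    """Accept a single frame or a dict of frames (sketch only)."""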
@@ -296,14 +293,14 @@ def _display_timeline_dict( width=param.width, ) - set_axes_and_grids(data, plot, param.yaxis, param.ygrid, param.xgrid) # type: ignore + set_axes_and_grids(data, plot, param.yaxis, param.ygrid, param.xgrid) # Create plot bar to act as as range selector rng_select = create_range_tool( - data=data, # type: ignore + data=data, min_time=min_time, max_time=max_time, - plot_range=plot.x_range, # type: ignore[arg-type] + plot_range=plot.x_range, width=param.width, height=height, ) @@ -314,7 +311,7 @@ def _display_timeline_dict( _plot_series(data, plot, param.legend) if param.ref_time is not None: - plot_ref_line(plot, param.ref_time, param.ref_label, len(data)) # type: ignore + plot_ref_line(plot, param.ref_time, param.ref_label, len(data)) elif param.ref_events is not None or param.ref_times is not None: plot_ref_events( plot=plot, @@ -342,7 +339,7 @@ def _plot_series(data, plot, legend_pos): legend_items = [] for ser_name, series_def in data.items(): size_param = series_def.get("size", 10) - glyph_size: Union[pd.Series, int] + glyph_size: pd.Series | int if isinstance(size_param, str): if size_param in series_def["data"].columns: glyph_size = series_def["data"][size_param] @@ -402,7 +399,7 @@ def _unpack_data_series_dict(data, param: PlotParams): """Unpack each series from the data series dictionary.""" # Process the input dictionary # Take each item that is passed and fill in blanks and add a y_index - tool_tip_columns: Set[str] = set() + tool_tip_columns: set[str] = set() min_time = None max_time = None y_index = 0 @@ -413,7 +410,7 @@ def _unpack_data_series_dict(data, param: PlotParams): colors, palette_size = get_color_palette(series_count) for ser_name, series_def in data.items(): - data_columns: Set[str] = set() + data_columns: set[str] = set() series_data = series_def["data"] if ( @@ -467,15 +464,13 @@ def _unpack_data_series_dict(data, param: PlotParams): # pylint: enable=too-many-locals -def _create_dict_from_grouping( - data, source_columns, time_column, group_by, color, size=10 -): +def _create_dict_from_grouping(data, source_columns, time_column, group_by, color, size=10): """Return data groupings as a dictionary.""" data_columns = get_def_source_cols(data, source_columns) # If the time column not explicitly specified in source_columns, add it data_columns.add(time_column) - series_dict: Dict[str, Dict] = {} + series_dict: dict[str, dict] = {} # create group frame so that we can color each group separately if group_by: data_columns.add(group_by) diff --git a/msticpy/vis/timeline_common.py b/msticpy/vis/timeline_common.py index e2d5307ea..885bd6bdc 100644 --- a/msticpy/vis/timeline_common.py +++ b/msticpy/vis/timeline_common.py @@ -4,11 +4,13 @@ # license information. 
# -------------------------------------------------------------------------- """Module for common timeline functions.""" + +from collections.abc import Iterable from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any import pandas as pd -from bokeh.models import ( # type: ignore[attr-defined] +from bokeh.models import ( ColumnDataSource, DatetimeTickFormatter, GestureTool, @@ -26,7 +28,7 @@ from bokeh.plotting import figure try: - from bokeh.plotting import Figure # type: ignore + from bokeh.plotting import Figure except ImportError: Figure = LayoutDOM from pandas.api.types import is_datetime64_any_dtype @@ -56,7 +58,7 @@ @export def check_df_columns( - data: pd.DataFrame, req_columns: List[str], help_uri: str, plot_type: str + data: pd.DataFrame, req_columns: list[str], help_uri: str, plot_type: str ): """ Check that specified columns are in the DataFrame. @@ -78,7 +80,7 @@ def check_df_columns( If one or more columns not found in `data` """ - missing_cols = set(req_columns) - set(data.columns) # type: ignore + missing_cols = set(req_columns) - set(data.columns) if missing_cols: raise MsticpyParameterError( title="Columns not found in DataFrame", @@ -89,11 +91,11 @@ def check_df_columns( def create_data_grouping( data: pd.DataFrame, - source_columns: List[str], + source_columns: list[str], time_column: str, - group_by: Optional[str], + group_by: str | None, color: str, -) -> Tuple[pd.DataFrame, pd.DataFrame, Set[str], int]: +) -> tuple[pd.DataFrame, pd.DataFrame, set[str], int]: """ Group input data and add indexes and tooltips. @@ -148,25 +150,22 @@ def create_data_grouping( graph_df["y_index"] = 1 series_count = 1 group_count_df = None - return graph_df, group_count_df, tool_tip_columns, series_count # type: ignore + return graph_df, group_count_df, tool_tip_columns, series_count -def get_def_source_cols(data: pd.DataFrame, source_columns: Iterable[str]) -> Set[str]: +def get_def_source_cols(data: pd.DataFrame, source_columns: Iterable[str]) -> set[str]: """Get default set of columns (backward compat).""" if not source_columns: return ( {"NewProcessName", "EventID", "CommandLine"} - if all( - col in data.columns - for col in ["NewProcessName", "EventID", "CommandLine"] - ) + if all(col in data.columns for col in ["NewProcessName", "EventID", "CommandLine"]) else set() ) return set(source_columns) -def get_color_palette(series_count: int) -> Tuple[Palette, int]: +def get_color_palette(series_count: int) -> tuple[Palette, int]: """Return palette based on series size.""" palette_size = min(256, series_count + series_count // 5) return viridis(palette_size), palette_size @@ -196,32 +195,32 @@ def set_axes_and_grids( def get_time_bounds( min_time: pd.Timestamp, max_time: pd.Timestamp -) -> Tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp]: +) -> tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp]: """Return start and end range, coping with out-of-bounds error.""" try: start_range = min_time - ((max_time - min_time) * 0.1) end_range = max_time + ((max_time - min_time) * 0.1) except OutOfBoundsDatetime: - min_time = min_time.to_pydatetime() # type: ignore - max_time = max_time.to_pydatetime() # type: ignore + min_time = min_time.to_pydatetime() + max_time = max_time.to_pydatetime() start_range = min_time - ((max_time - min_time) * 0.1) end_range = max_time + ((max_time - min_time) * 0.1) return start_range, end_range, min_time, max_time def create_tool_tips( - data: Union[pd.DataFrame, Dict[str, 
pd.DataFrame]], columns: Iterable[str] -) -> Dict[str, Any]: + data: pd.DataFrame | dict[str, pd.DataFrame], columns: Iterable[str] +) -> dict[str, Any]: """Create formatting for tool tip columns.""" - formatters: Dict[str, str] = {} + formatters: dict[str, str] = {} # if this is a dict we need to unpack each dataframe and process # the tooltip columns for all of the data sets. if isinstance(data, dict): tool_tip_dict = {} for data_set in data.values(): - data_df = data_set.get("data", {}) # type: ignore + data_df = data_set.get("data", {}) for col in columns: - disp_col, col_tooltip, col_fmt = _get_datetime_tooltip(col, data_df) # type: ignore + disp_col, col_tooltip, col_fmt = _get_datetime_tooltip(col, data_df) tool_tip_dict[disp_col] = col_tooltip formatters.update(col_fmt) return {"tooltips": list(tool_tip_dict.items()), "formatters": formatters} @@ -236,9 +235,7 @@ def create_tool_tips( return {"tooltips": tool_tip_items, "formatters": formatters} -def _get_datetime_tooltip( - col: str, dataset: pd.DataFrame -) -> Tuple[str, str, Dict[str, str]]: +def _get_datetime_tooltip(col: str, dataset: pd.DataFrame) -> tuple[str, str, dict[str, str]]: """Return tooltip and formatter entries for column.""" if " " in col: disp_col = col.replace(" ", "_") @@ -247,7 +244,7 @@ def _get_datetime_tooltip( disp_col = tt_col = col if col in dataset and is_datetime64_any_dtype(dataset[col]): col_tooltip = f"@{tt_col}{{%F %T.%3N}}" - col_fmt: Dict[Any, Any] = {f"@{tt_col}": "datetime"} + col_fmt: dict[Any, Any] = {f"@{tt_col}": "datetime"} else: col_tooltip = f"@{tt_col}" col_fmt = {} @@ -290,9 +287,7 @@ def create_range_tool( "Drag the middle or edges of the selection box to change " + "the range in the main chart" ) - rng_select.add_layout( - Title(text=help_str, align="right", text_font_size="10px"), "below" - ) + rng_select.add_layout(Title(text=help_str, align="right", text_font_size="10px"), "below") rng_select.xaxis[0].formatter = get_tick_formatter() if isinstance(data, dict): for _, series_def in data.items(): @@ -309,12 +304,12 @@ def create_range_tool( ) range_tool = RangeTool(x_range=plot_range) - range_tool.overlay.fill_color = "navy" # type: ignore - range_tool.overlay.fill_alpha = 0.2 # type: ignore + range_tool.overlay.fill_color = "navy" + range_tool.overlay.fill_alpha = 0.2 rng_select.ygrid.grid_line_color = None rng_select.add_tools(range_tool) if isinstance(range_tool, GestureTool): - rng_select.toolbar.active_multi = range_tool # type: ignore + rng_select.toolbar.active_multi = range_tool return rng_select @@ -361,9 +356,9 @@ def plot_ref_events( plot: Figure, time_col: str, group_count: int, - ref_events: Optional[pd.DataFrame] = None, - ref_col: Optional[str] = None, - ref_times: Optional[List[Tuple[datetime, str]]] = None, + ref_events: pd.DataFrame | None = None, + ref_col: str | None = None, + ref_times: list[tuple[datetime, str]] | None = None, ): """Plot reference lines/labels.""" if ref_events is not None: @@ -371,9 +366,7 @@ def plot_ref_events( ref_events = pd.DataFrame(ref_events) for idx, event in enumerate(ref_events.itertuples()): evt_time = event._asdict()[time_col] - evt_label = ( - event._asdict()[ref_col] if ref_col else f"reference {event.Index}" - ) + evt_label = event._asdict()[ref_col] if ref_col else f"reference {event.Index}" plot_ref_line( plot=plot, ref_time=evt_time, @@ -393,7 +386,7 @@ def plot_ref_events( ) -def get_ref_event_time(**kwargs) -> Tuple[Optional[Any], Union[Any, str]]: +def get_ref_event_time(**kwargs) -> tuple[Any | None, Any | str]: 
"""Extract the reference time from kwargs.""" ref_alert = kwargs.get("alert", None) if ref_alert is not None: @@ -412,16 +405,16 @@ def get_ref_event_time(**kwargs) -> Tuple[Optional[Any], Union[Any, str]]: else: ref_time = kwargs.get("ref_time", None) ref_label = "Ref time" - return ref_time, kwargs.get("ref_label", ref_label) # type: ignore + return ref_time, kwargs.get("ref_label", ref_label) def get_tick_formatter() -> DatetimeTickFormatter: """Return tick formatting for different zoom levels.""" # '%H:%M:%S.%3Nms tick_format = DatetimeTickFormatter() - tick_format.days = "%m-%d %H:%M" # type: ignore - tick_format.hours = "%H:%M:%S" # type: ignore - tick_format.minutes = "%H:%M:%S" # type: ignore - tick_format.seconds = "%H:%M:%S" # type: ignore - tick_format.milliseconds = "%H:%M:%S.%3N" # type: ignore + tick_format.days = "%m-%d %H:%M" # type: ignore[assignment] + tick_format.hours = "%H:%M:%S" # type: ignore[assignment] + tick_format.minutes = "%H:%M:%S" # type: ignore[assignment] + tick_format.seconds = "%H:%M:%S" # type: ignore[assignment] + tick_format.milliseconds = "%H:%M:%S.%3N" # type: ignore[assignment] return tick_format diff --git a/msticpy/vis/timeline_duration.py b/msticpy/vis/timeline_duration.py index fd9cb141f..024500fba 100644 --- a/msticpy/vis/timeline_duration.py +++ b/msticpy/vis/timeline_duration.py @@ -4,8 +4,9 @@ # license information. # -------------------------------------------------------------------------- """Timeline duration plot.""" + +from collections.abc import Iterable from datetime import datetime -from typing import Iterable, List, Optional, Tuple, Union import attr import pandas as pd @@ -32,12 +33,6 @@ set_axes_and_grids, ) -# pylint: disable=unused-import -# Importing to activate pandas accessors -from .timeline_pd_accessor import TimeLineAccessor # noqa F401 - -# pylint: enable=unused-import - __version__ = VERSION __author__ = "Ian Hellen" @@ -57,9 +52,9 @@ class PlotParams: """Plot params for time_duration.""" - height: Optional[int] = None + height: int | None = None width: int = 900 - title: Optional[str] = None + title: str | None = None yaxis: bool = True range_tool: bool = True xgrid: bool = True @@ -67,13 +62,13 @@ class PlotParams: hide: bool = False color: str = "navy" ylabel_cols: Iterable[str] = attr.Factory(list) - ref_events: Optional[pd.DataFrame] = None - ref_col: Optional[str] = None - ref_times: Optional[List[Tuple[datetime, str]]] = None - source_columns: List = [] + ref_events: pd.DataFrame | None = None + ref_col: str | None = None + ref_times: list[tuple[datetime, str]] | None = None + source_columns: list = [] @classmethod - def field_list(cls) -> List[str]: + def field_list(cls) -> list[str]: """Return field names as a list.""" return list(attr.fields_dict(cls).keys()) @@ -84,11 +79,11 @@ def field_list(cls) -> List[str]: @export def display_timeline_duration( data: pd.DataFrame, - group_by: Union[Iterable[str], str], + group_by: Iterable[str] | str, time_column: str = "TimeGenerated", - end_time_column: Optional[str] = None, + end_time_column: str | None = None, **kwargs, -) -> LayoutDOM: # noqa: C901, MC0001 +) -> LayoutDOM: # noqa: C901 """ Display a duration timeline of events grouped by one or more columns. 
@@ -150,7 +145,7 @@ def display_timeline_duration( group_by = [group_by] if isinstance(group_by, str) else list(group_by) end_time_column = end_time_column or time_column - data = ensure_df_datetimes(data, columns=list(set([time_column, end_time_column]))) + data = ensure_df_datetimes(data, columns=list({time_column, end_time_column})) check_df_columns( data, group_by + [end_time_column, time_column], @@ -183,9 +178,7 @@ def display_timeline_duration( height = param.height or calc_auto_plot_height(len(grouped_data)) # Concatenate ylabel columns to display on y-axis if len(group_by) > 1: - y_range = grouped_data[group_by[0]].str.cat( - grouped_data[group_by[1:]], sep=" / " - ) + y_range = grouped_data[group_by[0]].str.cat(grouped_data[group_by[1:]], sep=" / ") else: y_range = grouped_data[group_by[0]] @@ -224,14 +217,14 @@ def display_timeline_duration( ) # Set grid parameters - set_axes_and_grids(None, plot, param.yaxis, param.ygrid, param.xgrid) # type: ignore + set_axes_and_grids(None, plot, param.yaxis, param.ygrid, param.xgrid) # Create plot bar to act as as range selector rng_select = create_range_tool( data=all_data, min_time=min_time, max_time=max_time, - plot_range=plot.x_range, # type: ignore[arg-type] + plot_range=plot.x_range, width=param.width, height=height, time_column=time_column, @@ -257,7 +250,7 @@ def display_timeline_duration( def _group_durations( - data: pd.DataFrame, group_by: List[str], time_column: str, end_time_column: str + data: pd.DataFrame, group_by: list[str], time_column: str, end_time_column: str ): """Group the data and calculate start and end times.""" grouped_data = data.groupby(group_by).agg( diff --git a/msticpy/vis/timeline_pd_accessor.py b/msticpy/vis/timeline_pd_accessor.py deleted file mode 100644 index 5769b39d0..000000000 --- a/msticpy/vis/timeline_pd_accessor.py +++ /dev/null @@ -1,266 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""Pandas accessor class for timeline functions.""" -import warnings -from typing import Iterable, Optional, Union - -import pandas as pd -from bokeh.models import LayoutDOM -from deprecated.sphinx import deprecated - -from .._version import VERSION - -__version__ = VERSION -__author__ = "Ian Hellen" - -# pylint: disable=import-outside-toplevel, cyclic-import - - -@deprecated("Will be removed in version 2.2.0", version="1.7.0") -@pd.api.extensions.register_dataframe_accessor("mp_timeline") -class TimeLineAccessor: - """Pandas api extension for Timeline.""" - - def __init__(self, pandas_obj): - """Instantiate pandas extension class.""" - from .timeline import display_timeline, display_timeline_values - from .timeline_duration import display_timeline_duration - - self._display_timeline = display_timeline - self._display_timeline_values = display_timeline_values - self._display_timeline_duration = display_timeline_duration - self._df = pandas_obj - - def plot(self, **kwargs) -> LayoutDOM: - """ - Display a timeline of events. 
- - Parameters - ---------- - time_column : str, optional - Name of the timestamp column - (the default is 'TimeGenerated') - source_columns : list, optional - List of default source columns to use in tooltips - (the default is None) - - Other Parameters - ---------------- - title : str, optional - Title to display (the default is None) - alert : SecurityAlert, optional - Add a reference line/label using the alert time (the default is None) - ref_event : Any, optional - Add a reference line/label using the alert time (the default is None) - ref_time : datetime, optional - Add a reference line/label using `ref_time` (the default is None) - group_by : str - The column to group timelines on. - legend: str, optional - "left", "right", "inline" or "none" - (the default is to show a legend when plotting multiple series - and not to show one when plotting a single series) - yaxis : bool, optional - Whether to show the yaxis and labels (default is False) - ygrid : bool, optional - Whether to show the yaxis grid (default is False) - xgrid : bool, optional - Whether to show the xaxis grid (default is True) - range_tool : bool, optional - Show the the range slider tool (default is True) - height : int, optional - The height of the plot figure - (the default is auto-calculated height) - width : int, optional - The width of the plot figure (the default is 900) - color : str - Default series color (default is "navy") - overlay_data : pd.DataFrame: - A second dataframe to plot as a different series. - overlay_color : str - Overlay series color (default is "green") - ref_events : pd.DataFrame, optional - Add references line/label using the event times in the dataframe. - (the default is None) - ref_time_col : str, optional - Add references line/label using the this column in `ref_events` - for the time value (x-axis). - (this defaults the value of the `time_column` parameter or 'TimeGenerated' - `time_column` is None) - ref_col : str, optional - The column name to use for the label from `ref_events` - (the default is None) - ref_times : List[Tuple[datetime, str]], optional - Add one or more reference line/label using (the default is None) - - Returns - ------- - LayoutDOM - The bokeh plot figure. - - """ - warn_message = ( - "This accessor method has been deprecated.\n" - "Please use df.mp_plot.timeline() method instead." - "This will be removed in MSTICPy v2.2.0" - ) - warnings.warn(warn_message, category=DeprecationWarning) - return self._display_timeline(data=self._df, **kwargs) - - def plot_values(self, value_column: str = None, **kwargs) -> LayoutDOM: - """ - Display a timeline of events. 
- - Parameters - ---------- - time_column : str, optional - Name of the timestamp column - (the default is 'TimeGenerated') - value_column : str - The column name holding the value to plot vertically - source_columns : list, optional - List of default source columns to use in tooltips - (the default is None) - - Other Parameters - ---------------- - x : str, optional - alias of `time_column` - y : str, optional - alias of `value_column` - value_col : str, optional - alias of `value_column` - title : str, optional - Title to display (the default is None) - ref_event : Any, optional - Add a reference line/label using the alert time (the default is None) - ref_time : datetime, optional - Add a reference line/label using `ref_time` (the default is None) - ref_label : str, optional - A label for the `ref_event` or `ref_time` reference item - group_by : str - (where `data` is a DataFrame) - The column to group timelines on - legend: str, optional - "left", "right", "inline" or "none" - (the default is to show a legend when plotting multiple series - and not to show one when plotting a single series) - yaxis : bool, optional - Whether to show the yaxis and labels - range_tool : bool, optional - Show the the range slider tool (default is True) - height : int, optional - The height of the plot figure - (the default is auto-calculated height) - width : int, optional - The width of the plot figure (the default is 900) - color : str - Default series color (default is "navy"). This is overridden by - automatic color assignments if plotting a grouped chart - kind : Union[str, List[str]] - one or more glyph types to plot., optional - Supported types are "circle", "line" and "vbar" (default is "vbar") - ref_events : pd.DataFrame, optional - Add references line/label using the event times in the dataframe. - (the default is None) - ref_time_col : str, optional - Add references line/label using the this column in `ref_events` - for the time value (x-axis). - (this defaults the value of the `time_column` parameter or 'TimeGenerated' - `time_column` is None) - ref_col : str, optional - The column name to use for the label from `ref_events` - (the default is None) - ref_times : List[Tuple[datetime, str]], optional - Add one or more reference line/label using (the default is None) - - Returns - ------- - LayoutDOM - The bokeh plot figure. - - """ - warn_message = ( - "This accessor method has been deprecated.\n" - "Please use df.mp_plot.timeline_values() method instead." - "This will be removed in MSTICPy v2.2.0" - ) - warnings.warn(warn_message, category=DeprecationWarning) - return self._display_timeline_values( - data=self._df, value_column=value_column, **kwargs - ) - - def plot_duration( - self, - group_by: Union[Iterable[str], str], - time_column: str = "TimeGenerated", - end_time_column: Optional[str] = None, - **kwargs, - ) -> LayoutDOM: # noqa: C901, MC0001 - """ - Display a duration timeline of events grouped by one or more columns. - - Parameters - ---------- - group_by : Union[Iterable[str], str] - The column name or iterable of column names to group the data by. - time_column : str - Primary time column - will be used to calculate the - start time of the duration for each group. - If `end_time_column` is not specified it will also be used to - calculate the end time. - end_time_column : Optional[str] - If supplied, it will be used to calculate the end time - of the duration for each group. 
- - Other Parameters - ---------------- - title : str, optional - Title to display (the default is None) - ylabel_cols : Optional[Iterable[str]], optional - The subset of the group columns to use for the y-axis labels. - yaxis : bool, optional - Whether to show the yaxis and labels - range_tool : bool, optional - Show the the range slider tool (default is True) - source_columns : list, optional - List of default source columns to use in tooltips - (the default is None) - height : int, optional - The height of the plot figure - (the default is auto-calculated height) - width : int, optional - The width of the plot figure (the default is 900) - color : str - Default series color (default is "navy") - ref_events : pd.DataFrame, optional - Add references line/label using the event times in the dataframe. - (the default is None) - ref_col : str, optional - The column name to use for the label from `ref_events` - (the default is None) - ref_times : List[Tuple[datetime, str]], optional - Add one or more reference line/label using (the default is None) - - Returns - ------- - LayoutDOM - The bokeh plot figure. - - """ - warn_message = ( - "This accessor method has been deprecated.\n" - "Please use df.mp_plot.timeline_duration() method instead." - "This will be removed in MSTICPy v2.2.0" - ) - warnings.warn(warn_message, category=DeprecationWarning) - return self._display_timeline_duration( - data=self._df, - group_by=group_by, - time_column=time_column, - end_time_column=end_time_column, - **kwargs, - ) diff --git a/msticpy/vis/timeline_values.py b/msticpy/vis/timeline_values.py index f401fb460..897e42392 100644 --- a/msticpy/vis/timeline_values.py +++ b/msticpy/vis/timeline_values.py @@ -4,14 +4,16 @@ # license information. # -------------------------------------------------------------------------- """Timeline values Bokeh plot.""" + +from collections.abc import Iterable from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any import attr import pandas as pd from bokeh.io import output_notebook, show from bokeh.layouts import column -from bokeh.models import ( # type: ignore[attr-defined] +from bokeh.models import ( ColumnDataSource, HoverTool, LayoutDOM, @@ -50,37 +52,37 @@ class PlotParams: """Plot params for time_duration.""" - time_column: Optional[str] = None - height: Optional[int] = None + time_column: str | None = None + height: int | None = None width: int = 900 - title: Optional[str] = None + title: str | None = None yaxis: bool = True range_tool: bool = True - group_by: Optional[str] = None - legend: Optional[str] = None + group_by: str | None = None + legend: str | None = None xgrid: bool = True ygrid: bool = False hide: bool = False color: str = "navy" - kind: Union[str, List[str]] = "vbar" + kind: str | list[str] = "vbar" ylabel_cols: Iterable[str] = attr.Factory(list) - ref_event: Optional[Any] = None - ref_time: Optional[datetime] = None - ref_events: Optional[pd.DataFrame] = None - ref_col: Optional[str] = None - ref_time_col: Optional[str] = None - ref_times: Optional[List[Tuple[datetime, str]]] = None - source_columns: List = [] + ref_event: Any | None = None + ref_time: datetime | None = None + ref_events: pd.DataFrame | None = None + ref_col: str | None = None + ref_time_col: str | None = None + ref_times: list[tuple[datetime, str]] | None = None + source_columns: list = [] @classmethod - def field_list(cls) -> List[str]: + def field_list(cls) -> list[str]: """Return field names as a list.""" return 
list(attr.fields_dict(cls).keys()) # pylint: disable=invalid-name, too-many-locals, too-many-statements, too-many-branches -@export # noqa: C901, MC0001 -def display_timeline_values( # noqa: C901, MC0001 +@export # noqa: C901 +def display_timeline_values( # noqa: C901, PLR0912, PLR0915 data: pd.DataFrame, value_column: str = None, time_column: str = "TimeGenerated", @@ -251,7 +253,7 @@ def display_timeline_values( # noqa: C901, MC0001 click_policy="hide", label_text_font_size="8pt", ) - plot.add_layout(ext_legend, param.legend) # type: ignore[arg-type] + plot.add_layout(ext_legend, param.legend) else: plot_args = { "x": time_column, @@ -285,7 +287,7 @@ def display_timeline_values( # noqa: C901, MC0001 data=graph_df, min_time=min_time, max_time=max_time, - plot_range=plot.x_range, # type: ignore[arg-type] + plot_range=plot.x_range, width=param.width, height=height, time_column=time_column, @@ -319,9 +321,9 @@ def _plot_param_group( graph_df, group_count_df, plot, -) -> List[Tuple[str, Any]]: +) -> list[tuple[str, Any]]: """Plot series groups.""" - legend_items: List[Tuple[str, Any]] = [] + legend_items: list[tuple[str, Any]] = [] for _, group_id in group_count_df[param.group_by].items(): first_group_item = graph_df[graph_df[param.group_by] == group_id].iloc[0] legend_label = str(first_group_item[param.group_by]) @@ -330,7 +332,7 @@ def _plot_param_group( row_source = ColumnDataSource(graph_df[graph_df[param.group_by] == group_id]) p_series = [] # create default plot args - plot_args: Dict[str, Any] = { + plot_args: dict[str, Any] = { "x": time_column, "alpha": 0.7, "source": row_source, @@ -339,18 +341,12 @@ def _plot_param_group( plot_args["legend_label"] = inline_legend if "vbar" in plot_kinds: - p_series.append( - plot.vbar(top=value_col, width=4, color="color", **plot_args) - ) + p_series.append(plot.vbar(top=value_col, width=4, color="color", **plot_args)) if "circle" in plot_kinds: - p_series.append( - plot.circle(y=value_col, radius=2, color="color", **plot_args) - ) + p_series.append(plot.circle(y=value_col, radius=2, color="color", **plot_args)) if "line" in plot_kinds: p_series.append( - plot.line( - y=value_col, line_width=2, line_color=group_color, **plot_args - ) + plot.line(y=value_col, line_width=2, line_color=group_color, **plot_args) ) if not inline_legend: legend_items.append((legend_label, p_series)) diff --git a/msticpy/vis/timeseries.py b/msticpy/vis/timeseries.py index 07b6dc046..2fc66088f 100644 --- a/msticpy/vis/timeseries.py +++ b/msticpy/vis/timeseries.py @@ -4,8 +4,9 @@ # license information. 
# -------------------------------------------------------------------------- """Module for common display functions.""" + from itertools import zip_longest -from typing import Any, Dict +from typing import Any import pandas as pd from bokeh.io import output_notebook, show @@ -63,8 +64,8 @@ # pylint: disable=invalid-name, too-many-locals, too-many-statements # pylint: disable=too-many-branches, too-many-function-args, too-many-arguments -@export # noqa: C901, MC0001 -def display_timeseries_anomalies( +@export # noqa: C901 +def display_timeseries_anomalies( # noqa: PLR0915 data: pd.DataFrame, y: str = "Total", time_column: str = "TimeGenerated", @@ -142,8 +143,7 @@ def display_timeseries_anomalies( show_range: bool = kwargs.pop("range_tool", True) color: list = kwargs.get("color", ["navy", "green", "firebrick"]) color = [ - col1 or col2 - for col1, col2 in zip_longest(color[:3], ["navy", "green", "firebrick"]) + col1 or col2 for col1, col2 in zip_longest(color[:3], ["navy", "green", "firebrick"]) ] legend_pos: str = kwargs.pop("legend", "top_left") xgrid: bool = kwargs.pop("xgrid", False) @@ -161,7 +161,7 @@ def display_timeseries_anomalies( source_columns = [col for col in data.columns if col not in [anomalies_column]] data_anomaly = data[data[anomalies_column] == 1][source_columns].reset_index() - hover = HoverTool(**(create_tool_tips(data, source_columns))) # type: ignore + hover = HoverTool(**(create_tool_tips(data, source_columns))) # Create the Plot figure title = title or "Time Series Anomalies Visualization" @@ -222,7 +222,7 @@ def display_timeseries_anomalies( # create default plot args # pylint: disable=use-dict-literal - arg_dict: Dict[str, Any] = { + arg_dict: dict[str, Any] = { "x": time_column, "y": value_column, "size": 12, @@ -252,7 +252,7 @@ def display_timeseries_anomalies( y="score", min_time=min_time, max_time=max_time, - plot_range=plot.x_range, # type: ignore[arg-type] + plot_range=plot.x_range, width=width, height=height, time_column=time_column, diff --git a/msticpy/vis/vtobject_browser.py b/msticpy/vis/vtobject_browser.py index 6e1ff60b9..a56337fb9 100644 --- a/msticpy/vis/vtobject_browser.py +++ b/msticpy/vis/vtobject_browser.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- """VirusTotal Object browser.""" + import pprint -from typing import Dict, Optional import ipywidgets as widgets import pandas as pd @@ -36,7 +36,7 @@ class VTObjectBrowser(IPyDisplayMixin): _BASIC_TITLE = "VirusTotal File hash lookup" - def __init__(self, file_id: Optional[str] = None): + def __init__(self, file_id: str | None = None): """ Initialize the VT Browser. 
@@ -77,9 +77,7 @@ def __init__(self, file_id: Optional[str] = None): self.hb_vt_attribs = widgets.HBox( [self.data_sel, self.data_view], layout=_BORDER_LAYOUT ) - self.layout = widgets.VBox( - [self.html_header, self.hb_file_lookup, self.hb_vt_attribs] - ) + self.layout = widgets.VBox([self.html_header, self.hb_file_lookup, self.hb_vt_attribs]) if file_id: self.btn_lookup.click() @@ -110,7 +108,7 @@ def _lookup_file_id(self, btn): self.data_sel.options = self._current_data.columns -def _extract_summary(data: Optional[pd.DataFrame] = None) -> Dict[str, str]: +def _extract_summary(data: pd.DataFrame | None = None) -> dict[str, str]: """Return summary of item.""" def_dict = {"sha256": "", "meaningful_name": "", "names": "", "magic": ""} if data is None: @@ -124,19 +122,19 @@ def _extract_summary(data: Optional[pd.DataFrame] = None) -> Dict[str, str]: return data[["sha256", "meaningful_name", "names", "magic"]].iloc[0].to_dict() -def _summary_html(title: str, summary: Dict[str, str]) -> str: +def _summary_html(title: str, summary: dict[str, str]) -> str: """Return HTML formatted summary.""" return f"""
     <h3>{title}</h3>
     <table>
         <tr>
-            <td>ID</td><td>{summary.get('sha256')}</td>
+            <td>ID</td><td>{summary.get("sha256")}</td>
         </tr>
         <tr>
-            <td>Names</td><td>{summary.get('names')}</td>
+            <td>Names</td><td>{summary.get("names")}</td>
         </tr>
         <tr>
-            <td>File Type</td><td>{summary.get('magic')}</td>
+            <td>File Type</td><td>{summary.get("magic")}</td>
         </tr>
     </table>
""" diff --git a/pyproject.toml b/pyproject.toml index 4d89708a3..703063d72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,19 +5,73 @@ requires = [ ] build-backend = "setuptools.build_meta" -[tool.isort] -profile = "black" -src_paths = ["msticpy", "tests"] +[tool.ruff] +line-length = 95 +target-version = "py310" -[tool.pydocstyle] -convention = "numpy" +# Exclude directories +exclude = [ + ".git", + "__pycache__", + "docs/source/conf.py", + "build", + "dist", + "tests", + "test*.py", +] [tool.ruff.lint] -# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. -# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or -# McCabe complexity (`C901`) by default. -select = ["E4", "E7", "E9", "F", "W", "D"] -ignore = ["D212", "D203", "D417"] +# Enable the rule sets you want +select = [ + "E", # pycodestyle errors (flake8) + "W", # pycodestyle warnings (flake8) + "F", # pyflakes (flake8) + "I", # isort + "D", # pydocstyle + "UP", # pyupgrade + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "PL", # pylint (subset) +] + +# Ignore specific rules to match your current config +ignore = [ + "E401", # Multiple imports on one line (from flake8 config) + "E501", # Line too long (handled by formatter) + #"W503", # Line break before binary operator (from flake8 config) + "D212", # Multi-line docstring summary should start at the first line (from pydocstyle) + "D203", # 1 blank line required before class docstring (from pydocstyle) + "D417", # Missing argument descriptions (from pydocstyle) + "PLR0913", # Too many arguments + "PLR2004", # Magic value comparison + "PLW2901", # Loop variable overwritten +] + +[tool.ruff.lint.per-file-ignores] +# Ignore certain rules in test files +"tests/**" = ["D", "PL"] +"test_*.py" = ["D", "PL"] [tool.ruff.lint.pydocstyle] -convention = "numpy" \ No newline at end of file +# Use numpy convention (from your config) +convention = "numpy" + +[tool.ruff.lint.isort] +# Match your isort config +# profile = "black" +known-first-party = ["msticpy"] + +[tool.ruff.lint.pylint] +# Match your pylint config +max-args = 10 +max-branches = 15 + +[tool.ruff.format] +# Use double quotes (Black style) +quote-style = "double" +# Indent with spaces +indent-style = "space" +# Like Black, respect magic trailing comma +skip-magic-trailing-comma = false +# Like Black, automatically detect line endings +line-ending = "auto" \ No newline at end of file diff --git a/requirements-all.txt b/requirements-all.txt index b70c03176..d8c3323ed 100644 --- a/requirements-all.txt +++ b/requirements-all.txt @@ -23,13 +23,10 @@ folium>=0.9.0 geoip2>=2.9.0 httpx>=0.23.0, <1.0.0 html5lib -importlib-resources >= 6.4.0; python_version <= "3.8" -ipython >= 7.1.1; python_version < "3.8" -ipython >= 7.23.1; python_version >= "3.8" +ipython>=7.23.1 ipywidgets>=7.4.2, <9.0.0 jinja2>=3.1.5 # (sec vuln) transitive dependency via multiple packages keyring>=13.2.1 -KqlmagicCustom[jupyter-extended]>=0.1.114.post22 lxml>=4.6.5 matplotlib>=3.0.0 mo-sql-parsing>=11, <12.0.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index bd4d31252..85889d573 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,27 +2,19 @@ aiohttp>=3.7.4 async-cache>=1.1.1 bandit>=1.7.0 beautifulsoup4>=4.0.0 -black>=20.8b1, <25.0.0 coverage>=5.5 docutils<0.22.0 filelock>=3.0.0 -flake8>=3.8.4 httpx>=0.23.0, <0.28.0 -isort>=5.10.1 jsonschema>=4.17.3 markdown>=3.3.4 -mccabe>=0.6.1 mypy>=0.812 nbdime>=2.1.0 nbconvert>=6.1.0 pandas>=1.4.0, <3.0.0 -pep8-naming>=0.10.0 -pep8>=1.7.1 
pipreqs>=0.4.9 pre-commit>=2.7.1 -pycodestyle>=2.6.0 pydocstyle>=6.0.0 -pyflakes>=2.2.0 pygeohash>=1.2.0 pylint>=2.5.3 pyroma>=3.1 diff --git a/requirements.txt b/requirements.txt index 0e8bc25f5..6a2b3f1d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,9 +16,7 @@ folium>=0.9.0 geoip2>=2.9.0 httpx>=0.23.0, <1.0.0 html5lib -importlib-resources >= 6.4.0; python_version <= "3.8" -ipython >= 7.1.1; python_version < "3.8" -ipython >= 7.23.1; python_version >= "3.8" +ipython>=7.23.1 ipywidgets>=7.4.2, <9.0.0 jinja2>=3.1.5 # (sec vuln) transitive dependency via multiple packages keyring>=13.2.1 diff --git a/setup.cfg b/setup.cfg index 98676d605..b997a1606 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,11 +25,10 @@ classifiers = Programming Language :: Python Programming Language :: Python :: 3 Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 License :: OSI Approved :: MIT License Operating System :: OS Independent Development Status :: 5 - Production/Stable @@ -52,7 +51,7 @@ include_package_data = True package_dir = msticpy = msticpy packages = find: -python_requires = >=3.8 +python_requires = >=3.10 [options.packages.find] include = msticpy* diff --git a/setup.py b/setup.py index 43c82e4f3..eebfca367 100644 --- a/setup.py +++ b/setup.py @@ -21,9 +21,7 @@ def _combine_extras(extras: list) -> list: - return list( - {pkg for name, pkgs in EXTRAS.items() for pkg in pkgs if name in extras} - ) + return list({pkg for name, pkgs in EXTRAS.items() for pkg in pkgs if name in extras}) # Extras definitions @@ -32,7 +30,6 @@ def _combine_extras(extras: list) -> list: "vt3": ["vt-py>=0.18.0", "vt-graph-api>=2.0"], "splunk": ["splunk-sdk>=1.6.0,!=2.0.0"], "sumologic": ["sumologic-sdk>=0.1.11", "openpyxl>=3.0"], - "kql": ["KqlmagicCustom[jupyter-extended]>=0.1.114.post22"], "azure": [ "azure-mgmt-compute>=4.6.2", "azure-mgmt-core>=1.2.1", @@ -62,9 +59,7 @@ def _combine_extras(extras: list) -> list: EXTRAS["all"] = extras_all # Create combination extras -EXTRAS["all"] = sorted( - _combine_extras(list({name for name in EXTRAS if name != "dev"})) -) +EXTRAS["all"] = sorted(_combine_extras(list({name for name in EXTRAS if name != "dev"}))) EXTRAS["test"] = sorted(_combine_extras(["all", "dev"])) EXTRAS["sentinel"] = EXTRAS["azure"] diff --git a/tests/config/test_item_editors.py b/tests/config/test_item_editors.py index 17446399b..26bead166 100644 --- a/tests/config/test_item_editors.py +++ b/tests/config/test_item_editors.py @@ -290,6 +290,7 @@ def test_tiproviders_editor(kv_sec, mp_conf_ctrl): @respx.mock +@pytest.mark.filterwarnings("ignore:Use list") @patch("msticpy.config.ce_common.get_token") def test_get_tenant_id(get_token): """Test get tenantID function.""" diff --git a/tests/context/test_geoip.py b/tests/context/test_geoip.py index 81c5e91f4..c6f9a3d8d 100644 --- a/tests/context/test_geoip.py +++ b/tests/context/test_geoip.py @@ -83,6 +83,7 @@ def test_geoiplite_download(tmp_path): tgt_folder.rmdir() +@pytest.mark.filterwarnings("ignore:GeoIpLookup") def test_geoiplite_lookup(): """Test GeoLite lookups.""" ips = ["151.101.128.223", "151.101.0.223", "151.101.64.223", "151.101.192.223"] diff --git a/tests/context/test_ip_utils.py b/tests/context/test_ip_utils.py index 2bcafad53..ed5ccfc6d 100644 --- a/tests/context/test_ip_utils.py +++ 
b/tests/context/test_ip_utils.py @@ -488,9 +488,11 @@ def test_whois_pdext(mock_asn_whois_query, net_df): net_df = net_df.head(25) mock_asn_whois_query.return_value = ASN_RESPONSE respx.get(re.compile(r"http://rdap\.arin\.net/.*")).respond(200, json=RDAP_RESPONSE) - results = net_df.mp_whois.lookup(ip_column="AllExtIPs") - check.equal(len(results), len(net_df)) - check.is_in("ASNDescription", results.columns) + + # Use mp.whois() instead of deprecated mp_whois accessor + results = net_df.mp.whois(ip_column="AllExtIPs") + # Results are merged back into original dataframe + check.is_in("AsnDescription", results.columns) results2 = net_df.mp.whois(ip_column="AllExtIPs", asn_col="asn", whois_col="whois") check.equal(len(results2), len(net_df)) diff --git a/tests/data/drivers/test_kql_driver.py b/tests/data/drivers/test_kql_driver.py deleted file mode 100644 index 631f1ec8a..000000000 --- a/tests/data/drivers/test_kql_driver.py +++ /dev/null @@ -1,372 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""KQL driver query test class.""" -import io -from contextlib import redirect_stdout -from unittest.mock import patch - -import pandas as pd -import pytest -import pytest_check as check -from adal.adal_error import AdalError -from Kqlmagic.kql_engine import KqlEngineError -from Kqlmagic.kql_response import KqlError -from Kqlmagic.my_aad_helper import AuthenticationError - -from msticpy.common.exceptions import ( - MsticpyDataQueryError, - MsticpyKqlConnectionError, - MsticpyNoDataSourceError, - MsticpyNotConnectedError, -) -from msticpy.data.core.query_defns import DataEnvironment -from msticpy.data.drivers import import_driver, kql_driver - -# from Kqlmagic import kql as kql_exec - - -KqlDriver = import_driver(DataEnvironment.MSSentinel_Legacy) - -# from msticpy.data.drivers.kql_driver import KqlDriver -GET_IPYTHON_PATCH = KqlDriver.__module__ + ".get_ipython" - - -# pylint: disable=too-many-branches, too-many-return-statements -# pylint: disable=no-self-use, redefined-outer-name - - -class KqlResultTest: - """Test Kql result class.""" - - def __init__(self, code=0, partial=False, status="success"): - """Create instance.""" - self.completion_query_info = {"StatusCode": code, "StatusDescription": status} - self.is_partial_table = partial - - def to_dataframe(self): - """Convert dataframe.""" - return pd.DataFrame() - - -class _MockIPython: - """IPython get_ipython mock.""" - - def find_magic(self, magic): - """Return None if magic isn't == kql.""" - if magic == "kql": - return "Kqlmagic" - return None - - def run_line_magic(self, magic, line): - """Mock run line magic.""" - return self._run_magic(magic, line) - - def run_cell_magic(self, magic, line, cell): - """Mock run cell magic.""" - content = cell or line - return self._run_magic(magic, content) - - @staticmethod # noqa: MC0001 - def _run_magic(magic, content): - if magic == "reload_ext": - return None - if magic == "config": - if "=" in content: - return "dummy_setting" - return True - - check.equal(magic, "kql") - return kql_exec(content) - - -def kql_exec(content, options=None): - """Mock kql_exec function.""" - del options - if "--config" in content: - if "=" in content: - conf_item, conf_value = content.replace("--config", "").strip().split("=") - return {conf_item: 
conf_value} - _, conf_item = content.split() - return {conf_item: True} - - if "--conn" in content: - return [" * 1234"] - - if "KqlErrorUnk" in content: - resp = '{"error": {"code": "UnknownError"}}' - raise KqlError(http_response=resp, message=resp) - if "KqlErrorWS" in content: - resp = '{"error": {"code": "WorkspaceNotFoundError"}}' - raise KqlError(http_response=resp, message=resp) - if "KqlEngineError" in content: - raise KqlEngineError("Test Error") - if "AdalErrorUnk" in content: - resp = {"error_description": "unknown error"} - raise AdalError("Test Error", error_response=resp) - if "AdalErrorNR" in content: - raise AdalError("Test Error") - if "AdalErrorPoll" in content: - raise AdalError("Unexpected polling state code_expired") - if "AuthenticationError" in content: - raise AuthenticationError("Test Error") - - if content == "--schema": - return { - "table1": {"field1": int, "field2": str}, - "table2": {"field1": int, "field2": str}, - } - - if "query_partial" in content: - return KqlResultTest(code=0, partial=True, status="partial") - if "query_failed" in content: - return KqlResultTest(code=1, partial=False, status="failed") - - return KqlResultTest(code=0, partial=False, status="success") - - -KQL_EXEC_PATCH = (kql_driver, "kql_exec", kql_exec) - - -class AzCredentials: - """Mock credentials class.""" - - class ModernCred: - """Mock modern credentials class.""" - - class Token: - """Mocked token class.""" - - token = "Token" # nosec - - @classmethod - def get_token(cls, *args, **kwargs): - """Return the token.""" - del args, kwargs - return cls.Token() - - @property - def credentials(self): - """Return mocked credentials list.""" - return ["cred1", "cred2", "cred3"] - - @property - def modern(self): - """Return the modern credentials.""" - return self.ModernCred() - - -def az_connect(*args, **kwargs): - """Mock the az_connect function.""" - del args, kwargs - return AzCredentials() - - -AZ_CONNECT_PATH = (kql_driver, "az_connect", az_connect) - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_load(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - - kql_driver = KqlDriver() - check.is_true(kql_driver.loaded) - - kql_driver = KqlDriver(connection_str="la://connection") - check.is_true(kql_driver.loaded) - check.is_true(kql_driver.connected) - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_connect(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - kql_driver = KqlDriver() - check.is_true(kql_driver.loaded) - - kql_driver.connect(connection_str="la://connection") - check.is_true(kql_driver.connected) - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_connect_no_cs(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - kql_driver = KqlDriver() - check.is_true(kql_driver.loaded) - try: - kql_driver.connect() - check.is_in("loganalytics://code()", kql_driver.current_connection) - except KeyError: - # This is expected to fail occasionally because other tests - # may have changed the configuration. 
- pass - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_connect_kql_exceptions(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - kql_driver = KqlDriver() - - with pytest.raises(MsticpyKqlConnectionError) as mp_ex: - kql_driver.connect(connection_str="la://connection+KqlErrorUnk") - check.is_in("Kql response error", mp_ex.value.args) - check.is_false(kql_driver.connected) - - with pytest.raises(MsticpyKqlConnectionError) as mp_ex: - kql_driver.connect( - connection_str="la://connection.workspace('1234').tenant(KqlErrorWS)" - ) - check.is_in("unknown workspace", mp_ex.value.args) - check.is_false(kql_driver.connected) - - with pytest.raises(MsticpyKqlConnectionError) as mp_ex: - kql_driver.connect( - connection_str="la://connection.workspace('1234').tenant(KqlEngineError)" - ) - check.is_in("kql connection error", mp_ex.value.args) - check.is_false(kql_driver.connected) - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_connect_adal_exceptions(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - kql_driver = KqlDriver() - - with pytest.raises(MsticpyKqlConnectionError) as mp_ex: - kql_driver.connect(connection_str="la://connection+AdalErrorUnk") - check.is_in("could not authenticate to tenant", mp_ex.value.args) - check.is_false(kql_driver.connected) - - with pytest.raises(MsticpyKqlConnectionError) as mp_ex: - kql_driver.connect(connection_str="la://connection+AdalErrorNR") - check.is_in("could not authenticate to tenant", mp_ex.value.args) - check.is_in("Full error", str(mp_ex.value.args)) - check.is_false(kql_driver.connected) - - with pytest.raises(MsticpyKqlConnectionError) as mp_ex: - kql_driver.connect(connection_str="la://connection+AdalErrorPoll") - check.is_in("authentication timed out", mp_ex.value.args) - check.is_false(kql_driver.connected) - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_connect_authn_exceptions(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - kql_driver = KqlDriver() - - with pytest.raises(MsticpyKqlConnectionError) as mp_ex: - kql_driver.connect(connection_str="la://connection+AuthenticationError") - check.is_in("authentication failed", mp_ex.value.args) - check.is_false(kql_driver.connected) - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_schema(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - kql_driver = KqlDriver() - kql_driver.connect(connection_str="la://connection") - - check.is_in("table1", kql_driver.schema) - check.is_in("table2", kql_driver.schema) - check.is_in("field1", kql_driver.schema["table1"]) - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_query_not_connected(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - kql_driver = KqlDriver() - - with pytest.raises(MsticpyNotConnectedError) as mp_ex: - kql_driver.query("test") - check.is_in("not connected to a Workspace", mp_ex.value.args) - check.is_false(kql_driver.connected) - - -@patch(GET_IPYTHON_PATCH) -@patch.object(*KQL_EXEC_PATCH) -@patch.object(*AZ_CONNECT_PATH) -def test_kql_query_failed(get_ipython): - """Check loaded true.""" - get_ipython.return_value = _MockIPython() - kql_driver = KqlDriver() - 
diff --git a/tests/data/drivers/test_kusto_driver.py b/tests/data/drivers/test_kusto_driver.py
deleted file mode 100644
index 6d2f82ecd..000000000
--- a/tests/data/drivers/test_kusto_driver.py
+++ /dev/null
@@ -1,254 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License. See License.txt in the project root for
-# license information.
-# --------------------------------------------------------------------------
-"""Kusto driver unit tests."""
-from unittest.mock import Mock
-
-import pytest
-import pytest_check as check
-
-from msticpy.common.exceptions import MsticpyParameterError, MsticpyUserConfigError
-from msticpy.data.core.data_providers import QueryProvider
-from msticpy.data.drivers.kql_driver import KqlDriver
-from msticpy.data.drivers.kusto_driver import KustoDriver
-
-from ...unit_test_lib import custom_mp_config, get_test_data_path
-
-__author__ = "Ian Hellen"
-
-# pylint: disable=redefined-outer-name, protected-access
-pytestmark = [
-    pytest.mark.filterwarnings("ignore::UserWarning"),
-    pytest.mark.filterwarnings("ignore::DeprecationWarning"),
-]
-
-_KUSTO_SETTINGS = """
-DataProviders:
-    Kusto-MSTIC:
-        args:
-            Cluster: https://msticti.kusto.windows.net
-            ClientId: UUID
-            TenantId: UUID
-            ClientSecret: [PLACEHOLDER]
-
-    Kusto-AppAuthCluster:
-        args:
-            Cluster: https://msticapp.kusto.windows.net
-            ClientId: UUID
-            TenantId: UUID
-            ClientSecret: [PLACEHOLDER]
-
-"""
-
-
-@pytest.fixture
-def kusto_qry_prov():
-    """Return query provider with query paths."""
-    qry_path = str(get_test_data_path().joinpath("kusto_legacy"))
-    msticpy_config = get_test_data_path().joinpath("msticpyconfig.yaml")
-    with custom_mp_config(msticpy_config):
-        return QueryProvider("Kusto_Legacy", query_paths=[qry_path])
-
-
-_TEST_CON_STR = [
-    "azure_data-Explorer://",
-    "tenant='69d28fd7-42a5-48bc-a619-af56397b9f28';",
-    "clientid='69d28fd7-42a5-48bc-a619-af56397b1111';",
-    "clientsecret='[PLACEHOLDER]';",
-    "cluster='https://msticapp.kusto.windows.net';",
-    "database='scrubbeddata'",
-]
-_KUSTO_TESTS = [
-    ("no_params", {}),
-    ("cluster_uri", {"cluster": "https://msticapp.kusto.windows.net"}),
-    ("cluster", {"cluster": "msticapp"}),
-    ("database", {"database": "scrubbeddata"}),
-    (
-        "both",
-        {
-            "cluster": "https://msticapp.kusto.windows.net",
-            "database": "scrubbeddata",
-        },
-    ),
-    ("con_str", {"connection_str": "".join(_TEST_CON_STR)}),
-]
-
-
-def _mock_connect(self, *args, **kwargs):
-    """Mock connect for KqlDriver"""
-    print(args, kwargs)
-
-
-@pytest.mark.parametrize("inst, qry_args", _KUSTO_TESTS)
-def test_kusto_driver_connect(inst, qry_args, monkeypatch, kusto_qry_prov):
-    """Test class Kusto load and execute query driver."""
-    qry_prov = kusto_qry_prov
-    driver = qry_prov._query_provider
-    check.is_instance(driver, KustoDriver)
-    check.greater_equal(len(qry_prov.list_queries()), 4)
-
-    print(inst)
-    # set up mock
-    mock_driver = Mock(KqlDriver)
-    mock_driver.connect.return_value = None
-    monkeypatch.setattr(driver.__class__.__mro__[1], "connect", _mock_connect)
-
-    # Call connect
-    driver.connect(**qry_args)
-    if inst in ("both", "con_str"):
-        # We expect successful connection with either both cluster
-        # and database params or full connection string
-        check.is_not_none(driver.current_connection)
-        for expected in _TEST_CON_STR:
-            check.is_in(expected, driver.current_connection)
-    else:
-        check.is_none(driver.current_connection)
-
-
-@pytest.mark.parametrize("inst, qry_args", _KUSTO_TESTS)
-def test_kusto_driver_queries(inst, qry_args, monkeypatch, kusto_qry_prov):
-    """Test class Kusto load and execute query driver."""
-    qry_prov = kusto_qry_prov
-    driver = qry_prov._query_provider
-    check.is_instance(driver, KustoDriver)
-    check.greater_equal(len(qry_prov.list_queries()), 4)
-
-    print(inst)
-    # set up mock
-    mock_driver = Mock(KqlDriver)
-    mock_driver.query_with_results.return_value = "data", "success"
-    monkeypatch.setattr(driver, "query_with_results", mock_driver.query_with_results)
-
-    # Run query
-    result = qry_prov.AppAuthCluster.scrubbeddata.list_host_processes(
-        host_name="test", **qry_args
-    )
-    mock_qry_func = driver.query_with_results
-    mock_qry_func.assert_called_once()
-    check.equal(result, "data")
-    check.is_in('DeviceName has "test"', mock_qry_func.call_args[0][0])
-    check.is_in("where Timestamp >= datetime(2", mock_qry_func.call_args[0][0])
-    for expected in _TEST_CON_STR:
-        check.is_in(expected, driver.current_connection)
-
-
-_TEST_CON_STR_INTEG = [
-    "azure_data-Explorer://",
-    "code;",
-    "cluster='https://mstic.kusto.windows.net';",
-    "database='scrubbeddata'",
-]
-_KUSTO_TESTS_INTEG = [
-    ("no_params", {}),
-    ("cluster_uri", {"cluster": "https://mstic.kusto.windows.net"}),
-    ("cluster", {"cluster": "mstic"}),
-    ("database", {"database": "scrubbeddata"}),
-    (
-        "both",
-        {
-            "cluster": "https://mstic.kusto.windows.net",
-            "database": "scrubbeddata",
-        },
-    ),
-    ("con_str", {"connection_str": "".join(_TEST_CON_STR_INTEG)}),
-]
-
-
-@pytest.mark.parametrize("inst, qry_args", _KUSTO_TESTS_INTEG)
-def test_kusto_driver_integ_auth(inst, qry_args, monkeypatch, kusto_qry_prov):
-    """Test class Kusto load and execute query driver."""
-    qry_prov = kusto_qry_prov
-    driver = qry_prov._query_provider
-    check.is_instance(driver, KustoDriver)
-    check.greater_equal(len(qry_prov.list_queries()), 4)
-
-    print(inst)
-    # set up mock
-    mock_driver = Mock(KqlDriver)
-    mock_driver.query_with_results.return_value = "data", "success"
-    monkeypatch.setattr(driver, "query_with_results", mock_driver.query_with_results)
-
-    # Run query
-    result = qry_prov.IntegAuthCluster.scrubbeddata.list_host_processes(
-        host_name="test", **qry_args
-    )
-    mock_qry_func = driver.query_with_results
-    mock_qry_func.assert_called_once()
-    check.equal(result, "data")
-    check.is_in('DeviceName has "test"', mock_qry_func.call_args[0][0])
-    check.is_in("where Timestamp >= datetime(2", mock_qry_func.call_args[0][0])
-    for expected in _TEST_CON_STR_INTEG:
-        check.is_in(expected, driver.current_connection)
-
-
-@pytest.mark.parametrize("inst, qry_args", _KUSTO_TESTS)
-def test_kusto_driver_params_fail(inst, qry_args, monkeypatch):
-    """Test with parameters but missing config."""
-    qry_path = str(get_test_data_path().joinpath("kusto_legacy"))
-    msticpy_config = get_test_data_path().joinpath("msticpyconfig-nokusto.yaml")
-    with custom_mp_config(msticpy_config):
-        qry_prov = QueryProvider("Kusto_Legacy", query_paths=[qry_path])
-    driver = qry_prov._query_provider
-
-    print(inst)
-    # set up mock
-    mock_driver = Mock(KqlDriver)
-    mock_driver.query_with_results.return_value = "data", "success"
-    monkeypatch.setattr(driver, "query_with_results", mock_driver.query_with_results)
-
-    if inst == "con_str":
-        # No configuration so only supplying full connection string should work
-        result = qry_prov.AppAuthCluster.scrubbeddata.list_host_processes(
-            host_name="test", **qry_args
-        )
-        mock_qry_func = driver.query_with_results
-        mock_qry_func.assert_called_once()
-        check.equal(result, "data")
-        check.is_in('DeviceName has "test"', mock_qry_func.call_args[0][0])
-        check.is_in("where Timestamp >= datetime(2", mock_qry_func.call_args[0][0])
-        for expected in _TEST_CON_STR:
-            check.is_in(expected, driver.current_connection)
-    else:
-        # Everything else should throw a configuration error.
-        with pytest.raises(MsticpyUserConfigError):
-            result = qry_prov.AppAuthCluster.scrubbeddata.list_host_processes(
-                host_name="test", **qry_args
-            )
-
-
-@pytest.mark.parametrize("inst, qry_args", _KUSTO_TESTS)
-def test_kusto_driver_query_fail(inst, qry_args, monkeypatch, kusto_qry_prov):
-    """Test with queries + params with incomplete metadata."""
-    qry_prov = kusto_qry_prov
-    driver = qry_prov._query_provider
-    check.is_instance(driver, KustoDriver)
-    check.greater_equal(len(qry_prov.list_queries()), 4)
-
-    check.is_true(hasattr(qry_prov.AppAuthClustera.scrubbeddata, "query_new_alias"))
-    check.is_true(hasattr(qry_prov.scrubbeddata, "bad_query_fam_no_dot"))
-    print(inst)
-    # set up mock
-    mock_driver = Mock(KqlDriver)
-    mock_driver.query_with_results.return_value = "data", "success"
-    monkeypatch.setattr(driver, "query_with_results", mock_driver.query_with_results)
-
-    if inst in ("both", "cluster", "con_str", "cluster_uri"):
-        # run query
-        result = qry_prov.AppAuthCluster.scrubbeddata.bad_query_no_cluster(
-            cmd_line="test", **qry_args
-        )
-        mock_qry_func = driver.query_with_results
-        mock_qry_func.assert_called_once()
-        check.equal(result, "data")
-        check.is_in('ProcessCommandLine contains "test"', mock_qry_func.call_args[0][0])
-        check.is_in("where Timestamp >= datetime(2", mock_qry_func.call_args[0][0])
-        for expected in _TEST_CON_STR:
-            check.is_in(expected, driver.current_connection)
-    else:
-        # Everything else should throw a parameter error.
-        with pytest.raises(MsticpyParameterError):
-            qry_prov.AppAuthCluster.scrubbeddata.bad_query_no_cluster(
-                cmd_line="test", **qry_args
-            )
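One idiom from the deleted Kusto tests above is worth a note: patching `connect` on the driver's base class via `driver.__class__.__mro__[1]`, so the subclass inherits the stub. A runnable sketch of the idea with hypothetical classes (the real tests used pytest's monkeypatch fixture, which undoes the change automatically after each test):

    class BaseDriver:
        """Hypothetical base class that talks to a real service."""

        def connect(self, **kwargs):
            raise RuntimeError("would contact a real cluster")

    class KustoLikeDriver(BaseDriver):
        """Hypothetical subclass under test."""

    def _mock_connect(self, **kwargs):
        print("mock connect:", kwargs)

    driver = KustoLikeDriver()
    # __mro__[1] is the immediate base class (BaseDriver here), so the stub
    # replaces the inherited implementation for every subclass instance.
    setattr(driver.__class__.__mro__[1], "connect", _mock_connect)
    driver.connect(cluster="https://example.kusto.windows.net")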
diff --git a/tests/data/drivers/test_mdatp_driver.py b/tests/data/drivers/test_mdatp_driver.py
index 5bce85a7d..8e782a19d 100644
--- a/tests/data/drivers/test_mdatp_driver.py
+++ b/tests/data/drivers/test_mdatp_driver.py
@@ -48,6 +48,7 @@ def test_select_api_mde() -> None:
     assert cfg.oauth_v2 is True
 
 
+@pytest.mark.filterwarnings("ignore:M365 Defender")
 def test_select_api_m365d() -> None:
     """Test API selection for M365 Defender unified environment."""
     # Note this now reverts to MDE parameters
diff --git a/tests/data/drivers/test_odata_drivers.py b/tests/data/drivers/test_odata_drivers.py
index 61c4c5a90..9b0ad5323 100644
--- a/tests/data/drivers/test_odata_drivers.py
+++ b/tests/data/drivers/test_odata_drivers.py
@@ -4,6 +4,7 @@
 # license information.
 # --------------------------------------------------------------------------
 """Miscellaneous data provider driver tests."""
+import sys
 from unittest.mock import Mock, patch
 
 import pandas as pd
@@ -30,7 +31,10 @@
 MP_PATH = str(get_test_data_path().parent.joinpath("msticpyconfig-test.yaml"))
 # pylint: disable=protected-access
 
-pytestmark = pytest.mark.filterwarnings("ignore::UserWarning")
+pytestmark = [
+    pytest.mark.filterwarnings("ignore::UserWarning"),
+    pytest.mark.filterwarnings("ignore:M365 Defender"),
+]
 
 _JSON_RESP = {
     "token_type": "Bearer",
@@ -134,6 +138,9 @@ def test_mde_connect(httpx, env, api):
 
 @pytest.mark.parametrize("env, api, con_str", _MDE_CONNECT_STR)
 @patch("msticpy.data.drivers.odata_driver.httpx")
+@pytest.mark.skipif(
+    sys.platform.startswith("linux"), reason="File locking issue on Linux"
+)
 def test_mde_connect_str(httpx, env, api, con_str):
     """Test security graph driver."""
     driver_cls = import_driver(DataEnvironment.parse(env))
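The markers added in the hunks above follow two patterns used throughout the rest of this change: suppressing a known warning by message prefix, and skipping platform-specific tests. A self-contained sketch of both (the warning text and skip reason here are examples, not the real ones):

    import sys
    import warnings

    import pytest

    # "ignore:<prefix>" uses the warnings-filter syntax: the second field is
    # a regex matched against the *start* of the warning message.
    pytestmark = [
        pytest.mark.filterwarnings("ignore::UserWarning"),
        pytest.mark.filterwarnings("ignore:M365 Defender"),
    ]

    @pytest.mark.skipif(
        sys.platform.startswith("linux"), reason="example: Windows-only behavior"
    )
    def test_emits_known_warning():
        warnings.warn("M365 Defender environment deprecated", UserWarning)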
diff --git a/tests/data/drivers/test_sumologic_driver.py b/tests/data/drivers/test_sumologic_driver.py
index baa270675..27446000b 100644
--- a/tests/data/drivers/test_sumologic_driver.py
+++ b/tests/data/drivers/test_sumologic_driver.py
@@ -267,7 +267,7 @@ def sumologic_drv():
 @pytest.mark.parametrize(("query", "expected"), _QUERY_TESTS)
 def test_sumologic_query(sumologic_drv, query, expected):
     """Check queries with different outcomes."""
-    end = datetime.utcnow()
+    end = datetime.now(timezone.utc)
     start = end - timedelta(1)
     if query in ("MessageFail", "RecordFail", "Failjob", "RecordFail | count records"):
         with pytest.raises(MsticpyConnectionError) as mp_ex:
@@ -353,6 +353,7 @@ def test_sumologic_query_params(sumologic_drv, params, expected):
 
 
 @patch(SUMOLOGIC_SVC, SumologicService)
+@pytest.mark.filterwarnings("ignore:datetime.datetime.utcnow")
 @pytest.mark.parametrize("ext", ("xlsx", "csv"))
 def test_sumologic_query_export(sumologic_drv, tmpdir, ext):
     """Check queries with different parameters."""
diff --git a/tests/init/pivot/conftest.py b/tests/init/pivot/conftest.py
index 59de3c6cd..b108848bd 100644
--- a/tests/init/pivot/conftest.py
+++ b/tests/init/pivot/conftest.py
@@ -20,15 +20,15 @@
 
 __author__ = "Ian Hellen"
 
-# pylint: disable=redefined-outer-name, protected-access
+# pylint: disable=redefined-outer-name, protected-access, invalid-name
 
-_KQL_IMP_OK = False
+_AZURE_MONITOR_OK = False
 with contextlib.suppress(ImportError):
     # pylint: disable=unused-import
-    from msticpy.data.drivers import kql_driver
+    from msticpy.data.drivers import azure_monitor_driver
 
-    del kql_driver
-    _KQL_IMP_OK = True
+    del azure_monitor_driver
+    _AZURE_MONITOR_OK = True
 _SPLUNK_IMP_OK = False
 with contextlib.suppress(ImportError):
     from msticpy.data.drivers import splunk_driver
@@ -63,7 +63,7 @@ def create_data_providers():
     ):
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UserWarning)
-            if _KQL_IMP_OK:
+            if _AZURE_MONITOR_OK:
                 prov_dict["az_sent_prov"] = QueryProvider("MSSentinel")
             prov_dict["mdatp_prov"] = QueryProvider("MDE")
             if _SPLUNK_IMP_OK:
diff --git a/tests/init/pivot/test_pivot.py b/tests/init/pivot/test_pivot.py
index 36eafa3bf..d464ec00c 100644
--- a/tests/init/pivot/test_pivot.py
+++ b/tests/init/pivot/test_pivot.py
@@ -26,13 +26,6 @@
 pytestmark = pytest.mark.filterwarnings("ignore::UserWarning")
 # pylint: disable=redefined-outer-name, protected-access
 
-_KQL_IMP_OK = False
-with contextlib.suppress(ImportError):
-    # pylint: disable=unused-import
-    from msticpy.data.drivers import kql_driver
-
-    del kql_driver
-    _KQL_IMP_OK = True
 _SPLUNK_IMP_OK = False
 with contextlib.suppress(ImportError):
     from msticpy.data.drivers import splunk_driver
@@ -48,8 +41,6 @@
 except ImportError:
     ip_stack_cls = None
 
-pytestmark = pytest.mark.skipif(not _KQL_IMP_OK, reason="Partial msticpy install")
-
 
 def _test_create_pivot_namespace(data_providers):
     """Test instantiating Pivot with namespace arg."""
diff --git a/tests/init/pivot/test_pivot_browser.py b/tests/init/pivot/test_pivot_browser.py
index cb7cd4188..725451743 100644
--- a/tests/init/pivot/test_pivot_browser.py
+++ b/tests/init/pivot/test_pivot_browser.py
@@ -4,6 +4,9 @@
 # license information.
 # --------------------------------------------------------------------------
 """Pivot pipeline browser UI."""
+import sys
+
+import pytest
 import pytest_check as check
 
 try:
@@ -20,6 +23,9 @@
 
 __author__ = "Ian Hellen"
 
+@pytest.mark.skipif(
+    sys.platform.startswith("linux"), reason="Requires powershell.exe (Windows only)"
+)
 def test_pivot_browser(create_pivot):
     """Test pivot browser."""
     browser = PivotBrowser()
diff --git a/tests/init/pivot/test_pivot_data_queries_create.py b/tests/init/pivot/test_pivot_data_queries_create.py
index 8d48e1c97..07ae8c046 100644
--- a/tests/init/pivot/test_pivot_data_queries_create.py
+++ b/tests/init/pivot/test_pivot_data_queries_create.py
@@ -24,14 +24,6 @@
     add_queries_to_entities,
 )
 
-_KQL_IMP_OK = False
-with contextlib.suppress(ImportError):
-    # pylint: disable=unused-import
-    from msticpy.data.drivers import kql_driver
-
-    del kql_driver
-    _KQL_IMP_OK = True
-
 __author__ = "Ian Hellen"
 
 # pylint: disable=redefined-outer-name
@@ -45,7 +37,6 @@ def azure_sentinel():
     return QueryProvider("AzureSentinel")
 
 
-@pytest.mark.skipif(not _KQL_IMP_OK, reason="Partial msticpy install")
 def test_create_query_functions(azure_sentinel):
     """Test basic creation of query functions class."""
     az_qry_funcs = PivotQueryFunctions(azure_sentinel)
@@ -54,7 +45,6 @@ def test_create_query_functions(azure_sentinel):
     check.greater_equal(len(az_qry_funcs.query_params), 70)
 
 
-@pytest.mark.skipif(not _KQL_IMP_OK, reason="Partial msticpy install")
 def test_query_functions_methods(azure_sentinel):
     """Test attributes of retrieved functions."""
     az_qry_funcs = PivotQueryFunctions(azure_sentinel)
@@ -244,7 +234,6 @@ def test_create_pivot_func_df(test_input, expected):
 ]
 
 
-@pytest.mark.skipif(not _KQL_IMP_OK, reason="Partial msticpy install")
 @pytest.mark.parametrize("entity, expected", _ENT_QUERY_FUNC)
 def test_add_queries_to_entities(entity, expected, azure_sentinel):
     """Test query functions successfully added to entities."""
diff --git a/tests/init/pivot/test_pivot_input_types.py b/tests/init/pivot/test_pivot_input_types.py
index 255dd7b62..a69fa2801 100644
--- a/tests/init/pivot/test_pivot_input_types.py
+++ b/tests/init/pivot/test_pivot_input_types.py
@@ -119,6 +119,7 @@ def data_providers():
 ]
 
 
+@pytest.mark.filterwarnings("ignore:GeoIpLookup")
 @pytest.mark.parametrize("test_case", _PIVOT_QUERIES)
 def test_pivot_funcs_value(create_pivot, test_case):
     """Test calling function with value."""
diff --git a/tests/init/pivot/test_pivot_register.py b/tests/init/pivot/test_pivot_register.py
index 94b777d41..fec56f4b0 100644
--- a/tests/init/pivot/test_pivot_register.py
+++ b/tests/init/pivot/test_pivot_register.py
@@ -249,6 +249,7 @@ def data_providers():
 
 
 @respx.mock
+@pytest.mark.filterwarnings("ignore:GeoIpLookup")
 @pytest.mark.parametrize("test_case", _ENTITY_QUERIES)
 @patch("msticpy.context.ip_utils._asn_whois_query")
 def test_entity_attr_funcs_entity(mock_asn_whois_query, create_pivot, test_case):
@@ -267,6 +268,7 @@ def test_entity_attr_funcs_entity(mock_asn_whois_query, create_pivot, test_case)
 
 
 @respx.mock
+@pytest.mark.filterwarnings("ignore:GeoIpLookup")
 @pytest.mark.parametrize("test_case", _ENTITY_QUERIES)
 @patch("msticpy.context.ip_utils._asn_whois_query")
 def test_entity_attr_funcs_value(mock_asn_whois_query, create_pivot, test_case):
@@ -284,6 +286,7 @@ def test_entity_attr_funcs_value(mock_asn_whois_query, create_pivot, test_case):
 
 
 @respx.mock
+@pytest.mark.filterwarnings("ignore:GeoIpLookup")
 @pytest.mark.parametrize("test_case", _ENTITY_QUERIES)
 @patch("msticpy.context.ip_utils._asn_whois_query")
 def test_entity_attr_funcs_itbl(mock_asn_whois_query, create_pivot, test_case):
@@ -301,6 +304,7 @@ def test_entity_attr_funcs_itbl(mock_asn_whois_query, create_pivot, test_case):
 
 
 @respx.mock
+@pytest.mark.filterwarnings("ignore:GeoIpLookup")
 @pytest.mark.parametrize("test_case", _ENTITY_QUERIES)
 @patch("msticpy.context.ip_utils._asn_whois_query")
 def test_entity_attr_funcs_df(mock_asn_whois_query, create_pivot, test_case):
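For reference, the datetime.utcnow() to datetime.now(timezone.utc) swap in the test_sumologic_driver.py hunk above is the standard fix for the utcnow deprecation: utcnow returns a naive timestamp, while now(timezone.utc) returns a timezone-aware one. A short illustration:

    from datetime import datetime, timedelta, timezone

    end = datetime.now(timezone.utc)   # timezone-aware; tzinfo is UTC
    start = end - timedelta(days=1)    # arithmetic is unchanged

    assert end.tzinfo is timezone.utc
    # datetime.utcnow() returns a naive datetime (tzinfo=None) and emits a
    # DeprecationWarning on Python 3.12+; mixing naive and aware values in
    # ordering comparisons raises TypeError.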
diff --git a/tests/init/test_azure_ml_tools.py b/tests/init/test_azure_ml_tools.py
index 9bbf083a4..5843647a7 100644
--- a/tests/init/test_azure_ml_tools.py
+++ b/tests/init/test_azure_ml_tools.py
@@ -49,12 +49,9 @@ def aml_file_sys(tmpdir_factory):
 _MP_FUT_VER = ".".join(f"{int(v) + 1}" for v in _CURR_VERSION)
 _MP_FUT_VER_T = tuple(int(v) + 1 for v in _CURR_VERSION)
 
-_EXP_ENV = {
-    "KQLMAGIC_EXTRAS_REQUIRE": "jupyter-basic",
-    "KQLMAGIC_AZUREML_COMPUTE": "myhost",
-}
+# Kqlmagic environment variables are no longer set
+_EXP_ENV = {}
 _EXP_ENV_JPX = _EXP_ENV.copy()
-_EXP_ENV_JPX["KQLMAGIC_EXTRAS_REQUIRE"] = "jupyter-extended"
 
 
 class _PyOs:
@@ -64,20 +61,6 @@ def __init__(self):
         self.environ: Dict[str, Any] = {}
 
 
-class _ipython:
-    """Emulation for IPython shell."""
-
-    pgo_installed = False
-
-    def run_line_magic(self, *args, **kwargs):
-        """Return package list."""
-        del kwargs
-        if "apt list" in args:
-            if self.pgo_installed:
-                return ["libgirepository1.0-dev", "gir1.2-secret-1"]
-            return []
-
-
 CheckVers = namedtuple("CheckVers", "py_req, mp_req, extras, is_aml, excep, env")
 
 CHECK_VERS = [
@@ -112,7 +95,8 @@ def test_check_versions(monkeypatch, aml_file_sys, check_vers):
     # monkeypatch for various test cases
     _os = _PyOs()
     monkeypatch.setattr(aml, "os", _os)
-    monkeypatch.setattr(aml, "get_ipython", _ipython)
+    # get_ipython is no longer used in azure_ml_tools after Kqlmagic removal
+    # monkeypatch.setattr(aml, "get_ipython", _ipython)
     monkeypatch.setattr(aml, "_get_vm_fqdn", lambda: "myhost")
     if sys.version_info[:3] < (3, 10):
         monkeypatch.setattr(sys, "version_info", VersionInfo(3, 10, 0, "final", 0))
diff --git a/tests/vis/test_entity_graph.py b/tests/vis/test_entity_graph.py
index ad71535ae..7ebd4da93 100644
--- a/tests/vis/test_entity_graph.py
+++ b/tests/vis/test_entity_graph.py
@@ -5,6 +5,7 @@
 # --------------------------------------------------------------------------
 """Test module for EntityGraph."""
 import pandas as pd
+import pytest
 from bokeh.models.layouts import Column
 
 try:
@@ -204,6 +205,7 @@ def test_plot():
     assert isinstance(tl_plot, Column)
 
 
+@pytest.mark.filterwarnings("ignore:no explicit representation of timezones")
 def test_df_plot():
     """Test plotting from DataFrame"""
     plot = sent_incidents.mp_plot.incident_graph()
diff --git a/tests/vis/test_folium.py b/tests/vis/test_folium.py
index 9c2a898ac..2bf3d2226 100644
--- a/tests/vis/test_folium.py
+++ b/tests/vis/test_folium.py
@@ -5,6 +5,7 @@
 # --------------------------------------------------------------------------
 """Unit tests for Folium wrapper."""
 import math
+import sys
 from pathlib import Path
 from typing import Any, Optional
 
@@ -109,6 +110,7 @@ def test_centering_algorithms(geo_loc_df):
     check.is_true(math.isclose(center[1], -87.36079411764706))
 
 
+@pytest.mark.filterwarnings("ignore:GeoIpLookup")
 def test_add_ips(geo_loc_df):
     """Test adding list of IPs."""
     ips = geo_loc_df.AllExtIPs
@@ -235,7 +237,7 @@ def create_geo_entity(row):
 
 
 def icon_map_func(key):
-    """Test function for plot_map"""
+    """Test function for plot_map."""
     return icon_map.get(key, icon_map.get("default"))
 
 
@@ -279,7 +281,11 @@ class PlotMapTest:
 _PM_IDS = [pmt.name for pmt in _PM_TEST_PARAMS]
 
 
+@pytest.mark.filterwarnings("ignore:GeoIpLookup")
 @pytest.mark.parametrize("plot_test", _PM_TEST_PARAMS, ids=_PM_IDS)
+@pytest.mark.skipif(
+    sys.platform.startswith("linux"), reason="GeoIP database not configured in Docker"
+)
 def test_plot_map(plot_test, geo_loc_df):
     """Test plot_map with different parameters."""
     plot_kwargs = attr.asdict(plot_test)
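The deleted test_morph_charts.py module below verified console output by patching builtins.print and asserting on the recorded calls. A minimal sketch of that capture pattern with hypothetical names (describe_chart is not a real msticpy function):

    from unittest.mock import call, patch

    def describe_chart(name, description):
        """Hypothetical function that reports a chart to the console."""
        print(name, ":", description)

    @patch("builtins.print")
    def check_describe(mocked_print):
        describe_chart("SigninsChart", "Example chart description")
        # Every print() call is recorded on the mock.
        assert mocked_print.mock_calls == [
            call("SigninsChart", ":", "Example chart description")
        ]

    check_describe()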
diff --git a/tests/vis/test_morph_charts.py b/tests/vis/test_morph_charts.py
deleted file mode 100644
index a5c28aa79..000000000
--- a/tests/vis/test_morph_charts.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License. See License.txt in the project root for
-# license information.
-# --------------------------------------------------------------------------
-"""morph_charts test class."""
-import os
-from pathlib import Path
-from unittest.mock import call, patch
-
-import IPython
-import pandas as pd
-import pytest
-
-from msticpy.common.exceptions import MsticpyException
-from msticpy.vis.morph_charts import MorphCharts
-
-from ..unit_test_lib import get_test_data_path
-
-_TEST_DATA = get_test_data_path()
-
-
-@pytest.fixture
-def test_morph():
-    """Create MorphCharts objcet."""
-    return MorphCharts()
-
-
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
-@patch("builtins.print")
-def test_chart_details(mocked_print, test_morph):
-    """Test case."""
-    with pytest.raises(KeyError):
-        assert test_morph.get_chart_details("xxx")
-    test_morph.get_chart_details("SigninsChart")
-    assert mocked_print.mock_calls == [
-        call(
-            "SigninsChart",
-            ":",
-            "\n",
-            "Charts for visualizing Azure AD Signin Logs.",
-            "\n",
-            "Query: ",
-            "Azure.list_all_signins_geo",
-        )
-    ]
-
-
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
-@patch("builtins.print")
-def test_list_charts(mocked_print, test_morph):
-    """Test case."""
-    test_morph.list_charts()
-    assert mocked_print.mock_calls == [call("SigninsChart")]
-
-
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
-@patch("builtins.print")
-def test_search_charts_f(mocked_print, test_morph):
-    """Test case."""
-    test_morph.search_charts("testing")
-    assert mocked_print.mock_calls == [call("No matching charts found")]
-
-
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
-@patch("builtins.print")
-def test_search_charts_s(mocked_print, test_morph):
-    """Test case."""
-    test_morph.search_charts("signinLogs")
-    assert mocked_print.mock_calls == [
-        call(
-            "SigninsChart",
-            ":",
-            "\n",
-            "Charts for visualizing Azure AD Signin Logs.",
-        )
-    ]
-
-
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
-def test_display(test_morph):
-    """Test case."""
-    test_file = Path(_TEST_DATA).joinpath("morph_test.csv")
-    test_data = pd.read_csv(test_file, index_col=0)
-    output = test_morph.display(data=test_data, chart_name="SigninsChart")
-    assert isinstance(output, IPython.lib.display.IFrame)
-    assert os.path.isdir(Path.cwd().joinpath("morphchart_package")) is True
-    assert (
-        os.path.isfile(Path.cwd().joinpath(*["morphchart_package", "description.json"]))
-        is True
-    )
-    assert (
-        os.path.isfile(Path.cwd().joinpath(*["morphchart_package", "query_data.csv"]))
-        is True
-    )
-    with pytest.raises(MsticpyException):
-        assert test_morph.display(data=test_data, chart_name="test")
-        assert test_morph.display(data="test_data", chart_name="SigninsChart")
diff --git a/tools/config2kv.py b/tools/config2kv.py
index 1740664f0..b73de46f5 100644
--- a/tools/config2kv.py
+++ b/tools/config2kv.py
@@ -67,7 +67,7 @@ def _read_config_settings(conf_file):
     if not conf_file:
         raise ValueError("Configuration file not found.")
     print(conf_file)
-    with open(conf_file, "r", encoding="utf-8") as conf_hdl:
+    with open(conf_file, encoding="utf-8") as conf_hdl:
         cur_settings = yaml.safe_load(conf_hdl)
 
     # temporarily set env var to point to conf_file
@@ -93,7 +93,7 @@ def _format_kv_name(setting_path):
     return re.sub("[^0-9a-zA-Z-]", "-", setting_path)
 
 
-def _get_config_secrets(cur_settings, section_name, sec_names):  # noqa: MC0001
+def _get_config_secrets(cur_settings, section_name, sec_names):
     kv_dict = {}
     sec_key_names = ["authkey", "apiid", "password", "clientsecret"]
     if sec_names:
diff --git a/tools/toollib/url_checker.py b/tools/toollib/url_checker.py
index b364c7b61..d06c43143 100644
--- a/tools/toollib/url_checker.py
+++ b/tools/toollib/url_checker.py
@@ -70,7 +70,7 @@ def check_url(url: str) -> UrlResult:
 
 
 # pylint: disable=too-many-locals, too-many-branches
-def check_site(  # noqa: MC0001
+def check_site(
     page_url: str,
     all_links: bool = False,
     top_root: Optional[str] = None,
@@ -273,7 +273,7 @@ def check_md_document(doc_path: str) -> Dict[str, UrlResult]:
     Dictionary of checked links
 
     """
-    with open(doc_path, "r", encoding="utf-8") as doc_file:
+    with open(doc_path, encoding="utf-8") as doc_file:
         body_markdown = doc_file.read()
     md_content = markdown.markdown(body_markdown)
     soup = BeautifulSoup(md_content, "html.parser")
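Finally, the tools/ hunks above drop the redundant "r" argument to open(): reading is the default mode, and ruff's pyupgrade rule UP015 flags the explicit form. The two calls are equivalent, as this small illustration shows (example.txt is an illustrative file created on the spot):

    from pathlib import Path

    sample = Path("example.txt")
    sample.write_text("hello\n", encoding="utf-8")

    with open(sample, "r", encoding="utf-8") as fh:   # old form
        old_style = fh.read()
    with open(sample, encoding="utf-8") as fh:        # new form
        new_style = fh.read()
    assert old_style == new_style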