diff --git a/.github/actions/prepare_container_test/action.yml b/.github/actions/prepare_container_test/action.yml index de8e2d00a..095c1d988 100644 --- a/.github/actions/prepare_container_test/action.yml +++ b/.github/actions/prepare_container_test/action.yml @@ -6,10 +6,12 @@ runs: with: python-version: 3.7 - - name: Install pipenv + - name: Install dependencies shell: bash - run: pip install pipenv - working-directory: ./scripts + run: | + pip install -U pip + pip install pytest + pip install ./packages/client/libclient-py - name: Set up Docker buildx id: buildx diff --git a/.github/workflows/test-server-all.yml b/.github/workflows/test-server-all.yml index eba03c9d3..41167af8c 100644 --- a/.github/workflows/test-server-all.yml +++ b/.github/workflows/test-server-all.yml @@ -246,10 +246,7 @@ jobs: - uses: ./.github/actions/prepare_container_test - name: Run container test - run: | - pipenv sync - pipenv install --skip-lock ../../packages/client/libclient-py - pipenv run pytest ./tests/test_up.py -s -v -log-cli-level=DEBUG + run: pytest ./tests/test_up.py -s -v -log-cli-level=DEBUG working-directory: ./scripts/container_test medium_test_container_restart: @@ -284,10 +281,7 @@ jobs: - uses: ./.github/actions/prepare_container_test - name: Run container test - run: | - pipenv sync - pipenv install --skip-lock ../../packages/client/libclient-py - pipenv run pytest ./tests/test_restart_request.py::test_success_${{ matrix.test_name }}_with_restart -s -v -log-cli-level=DEBUG + run: pytest ./tests/test_restart_request.py::test_success_${{ matrix.test_name }}_with_restart -s -v -log-cli-level=DEBUG working-directory: ./scripts/container_test medium_test_container_down: @@ -321,10 +315,7 @@ jobs: - uses: ./.github/actions/prepare_container_test - name: Run container test - run: | - pipenv sync - pipenv install --skip-lock ../../packages/client/libclient-py - pipenv run pytest ./tests/test_down_request.py::test_failed_${{ matrix.test_name }}_with_down -s -v -log-cli-level=DEBUG + run: pytest ./tests/test_down_request.py::test_failed_${{ matrix.test_name }}_with_down -s -v -log-cli-level=DEBUG working-directory: ./scripts/container_test medium_test_container: diff --git a/.gitignore b/.gitignore index aa54e4ff7..42474435b 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,7 @@ __pycache__ **/db/*/** !**/db/**/.gitkeep -scripts/.env \ No newline at end of file +scripts/.env + +docs/**/_build/** +docs/**/_generated/** diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..ffee8e34a --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,34 @@ +# Read the Docs configuration file for Sphinx projects +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.7" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/libclient-py/source/conf.py + # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs + # builder: "dirhtml" + # Fail on all warnings to avoid broken references + # fail_on_warning: true + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - method: pip + path: ./packages/client/libclient-py + extra_requirements: + - document diff --git a/docs/libclient-py/Makefile b/docs/libclient-py/Makefile new file mode 100644 index 000000000..82f26d63f --- /dev/null +++ b/docs/libclient-py/Makefile @@ -0,0 +1,34 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD = sphinx-build +SOURCEDIR = source +BUILDDIR = _build +APIDIR = ../../packages/client/libclient-py/quickmpc + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: autobuild +autobuild: + sphinx-autobuild $(SOURCEDIR) $(BUILDDIR)/html + +.PHONY: autogen +autogen: + sphinx-autogen source/**/*.rst + +.PHONY: clean +clean: + rm -rf $(BUILDDIR)/* + rm -rf $(SOURCEDIR)/reference/_generated/ diff --git a/docs/libclient-py/make.bat b/docs/libclient-py/make.bat new file mode 100644 index 000000000..747ffb7b3 --- /dev/null +++ b/docs/libclient-py/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/libclient-py/source/conf.py b/docs/libclient-py/source/conf.py new file mode 100644 index 000000000..58a645e07 --- /dev/null +++ b/docs/libclient-py/source/conf.py @@ -0,0 +1,45 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +import quickmpc # noqa + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'quickmpc-libclient-py' +copyright = '2023, Acompany Co., Ltd.' +author = 'Acompany Co., Ltd.' +version = quickmpc.__version__ +release = quickmpc.__version__ + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + 'sphinx.ext.autodoc', # docstringから自動で取り込めるように + "sphinx.ext.autosummary", # summaryのリストを生成する + 'sphinx.ext.napoleon', # numpy styleを取り込めるように + 'sphinx_rtd_theme', # pageのテーマ +] + +templates_path = ['_templates'] +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'sphinx_rtd_theme' +html_static_path = ['_static'] + + +# -- Extension configuration ------------------------------------------------- +autosummary_generate = True +autodoc_typehints = "description" +autodoc_default_options = { + "members": True, + "inherited-members": True, + "exclude-members": "with_traceback", +} diff --git a/docs/libclient-py/source/index.rst b/docs/libclient-py/source/index.rst new file mode 100644 index 000000000..c37f62d9e --- /dev/null +++ b/docs/libclient-py/source/index.rst @@ -0,0 +1,28 @@ +.. quickmpc-libclient-py documentation master file, created by + sphinx-quickstart on Mon Jul 24 19:08:23 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +quickmpc-libclient-py's documentation +===================================== + +.. _QuickMPC: https://github.com/acompany-develop/QuickMPC + +quickmpc-libclient-py はMPC(Multi Party Computation)を行う `QuickMPC`_ を簡単に操作するためのPythonライブラリです. +pandas likeなinterfaceでMPCについて知らないユーザでも簡単に秘匿計算を行うことができます. + +.. toctree:: + :maxdepth: 1 + :caption: Contents: + + installation + quickstart + reference/index + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/libclient-py/source/installation.rst b/docs/libclient-py/source/installation.rst new file mode 100644 index 000000000..b9269fa57 --- /dev/null +++ b/docs/libclient-py/source/installation.rst @@ -0,0 +1,10 @@ +.. _installation: + +Installation +============ +quickmpc requires Python 3.7 or later to run. You can install quickmpc using pip: + +.. code-block:: bash + + $ pip install quickmpc + diff --git a/docs/libclient-py/source/quickstart.rst b/docs/libclient-py/source/quickstart.rst new file mode 100644 index 000000000..c2800ad95 --- /dev/null +++ b/docs/libclient-py/source/quickstart.rst @@ -0,0 +1,72 @@ +.. _quickstart: + +.. _quickmpc-libclient-py: https://github.com/acompany-develop/QuickMPC/tree/main/packages/client/libclient-py +.. _pandas: https://pandas.pydata.org/ + +Quickstart +========== +本ページでは `quickmpc-libclient-py`_ を用いたMPC(Multi Party Computation)の始め方を説明します. +まだquickmpc-libclient-pyをinstallしていない方は :ref:`Installation ` を参考にinstallをしてください. + +csvデータの操作 +--------------- +quickmpc-libclient-pyは `pandas`_ likeなinterfaceを提供します. +例えばcsvファイルの読み取りはpandasと同じように :py:func:`read_csv関数 ` により実現できます. +また, 返り値はpandas.DataFrameになっており,pandasと同じように加工することができます. + +.. code-block:: python3 + + import quickmpc.pandas as qpd + + df = qpd.read_csv("data.csv", index_col="ID") + df = df.applymap(lambda x: x%2) + + +MPCの開始 +--------- +:py:class:`QMPCクラス ` を使用して,あらかじめ用意した2つ以上のQuickMPCサーバに接続します. + +.. code-block:: python3 + + import quickmpc + + qmpc = quickmpc.QMPC([ + "http://~~~:~~~", + "http://~~~:~~~", + "http://~~~:~~~", + ]) + +読み込んだcsvデータを次のようにしてサーバに送信します. + +.. code-block:: python3 + + sdf = qmpc.send_to(df) + +自分以外のユーザが送信したデータは次のようにして取得します. + +.. code-block:: python3 + + data_id = "~~~~" # 他者が sdf.get_id() により出力したID + sdf = qmpc.load_from(data_id) + +使用するデータをすべてサーバに送信したら各種MPCを実行します. + +.. code-block:: python3 + + sdf_res = sdf.sum() + sdf_res = sdf.join(sdf_other) + +計算結果の取得 +-------------- +MPCの結果はファイルに出力するか,pandas.DataFrameに出力して引き続き加工することができます. + +.. code-block:: python3 + + df = sdf_res.to_csv("filename.csv") # ファイルに出力 + df = sdf_res.to_data_frame() # DataFrameに出力 + +一部の実行時間の長い計算は `progress` オプションを設定して計算ステータスをログに出力できます. + +.. code-block:: python3 + + df = sdf_res.to_data_frame(progress=True) diff --git a/docs/libclient-py/source/reference/index.rst b/docs/libclient-py/source/reference/index.rst new file mode 100644 index 000000000..b9635ab1f --- /dev/null +++ b/docs/libclient-py/source/reference/index.rst @@ -0,0 +1,14 @@ +.. _api_reference: + +API Reference +============= + +.. toctree:: + :maxdepth: 1 + + quickmpc + quickmpc.pandas + quickmpc.proto + quickmpc.request + quickmpc.share + quickmpc.utils diff --git a/docs/libclient-py/source/reference/quickmpc.pandas.rst b/docs/libclient-py/source/reference/quickmpc.pandas.rst new file mode 100644 index 000000000..c60a3adc8 --- /dev/null +++ b/docs/libclient-py/source/reference/quickmpc.pandas.rst @@ -0,0 +1,16 @@ +.. _quickmpc.pandas: + +.. module:: quickmpc.pandas + +quickmpc.pandas +=============== + +quickmpc.pandas is quickmpc.pandas. + +.. autosummary:: + :toctree: _generated/ + :nosignatures: + + quickmpc.pandas.ShareDataFrame + quickmpc.pandas.parse + quickmpc.pandas.read_csv diff --git a/docs/libclient-py/source/reference/quickmpc.proto.rst b/docs/libclient-py/source/reference/quickmpc.proto.rst new file mode 100644 index 000000000..d2c3d3b1a --- /dev/null +++ b/docs/libclient-py/source/reference/quickmpc.proto.rst @@ -0,0 +1,14 @@ +.. module:: quickmpc.proto + +quickmpc.proto +============== + +quickmpc.proto is quickmpc.proto + +.. autosummary:: + :toctree: _generated/ + :nosignatures: + + quickmpc.proto.common_types.common_types_pb2.JobErrorInfo + quickmpc.proto.common_types.common_types_pb2.JobStatus + quickmpc.proto.common_types.common_types_pb2.JobProgress diff --git a/docs/libclient-py/source/reference/quickmpc.request.rst b/docs/libclient-py/source/reference/quickmpc.request.rst new file mode 100644 index 000000000..65075efa7 --- /dev/null +++ b/docs/libclient-py/source/reference/quickmpc.request.rst @@ -0,0 +1,13 @@ +.. module:: quickmpc.request + +quickmpc.request +================ + +quickmpc.request is quickmpc.request. + +.. autosummary:: + :toctree: _generated/ + :nosignatures: + + quickmpc.request.QMPCRequest + quickmpc.request.QMPCRequestInterface diff --git a/docs/libclient-py/source/reference/quickmpc.rst b/docs/libclient-py/source/reference/quickmpc.rst new file mode 100644 index 000000000..1d3d3d062 --- /dev/null +++ b/docs/libclient-py/source/reference/quickmpc.rst @@ -0,0 +1,18 @@ +.. _quickmpc: + +.. module:: quickmpc + +quickmpc +======== + +quickmpc is quickmpc. + +.. autosummary:: + :toctree: _generated/ + :nosignatures: + + quickmpc.ArgumentError + quickmpc.QMPC + quickmpc.QMPCJobError + quickmpc.QMPCServerError + quickmpc.get_logger diff --git a/docs/libclient-py/source/reference/quickmpc.share.rst b/docs/libclient-py/source/reference/quickmpc.share.rst new file mode 100644 index 000000000..2fb0efa6e --- /dev/null +++ b/docs/libclient-py/source/reference/quickmpc.share.rst @@ -0,0 +1,13 @@ +.. module:: quickmpc.share + +quickmpc.share +=============== + +quickmpc.share is quickmpc.share + +.. autosummary:: + :toctree: _generated/ + :nosignatures: + + quickmpc.share.Share + quickmpc.share.restore diff --git a/docs/libclient-py/source/reference/quickmpc.utils.rst b/docs/libclient-py/source/reference/quickmpc.utils.rst new file mode 100644 index 000000000..d550fab7d --- /dev/null +++ b/docs/libclient-py/source/reference/quickmpc.utils.rst @@ -0,0 +1,19 @@ +.. module:: quickmpc.utils + +quickmpc.utils +============== + +quickmpc.utils is quickmpc.utils + +.. autosummary:: + :toctree: _generated/ + :nosignatures: + + quickmpc.utils.Dim1 + quickmpc.utils.Dim2 + quickmpc.utils.Dim3 + quickmpc.utils.DictList + quickmpc.utils.DictList2 + quickmpc.utils.MakePiece + quickmpc.utils.if_present + quickmpc.utils.methoddispatch diff --git a/packages/client/libclient-py/pyproject.toml b/packages/client/libclient-py/pyproject.toml index beb044643..3cb5d013d 100644 --- a/packages/client/libclient-py/pyproject.toml +++ b/packages/client/libclient-py/pyproject.toml @@ -2,6 +2,32 @@ requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] build-backend = "setuptools.build_meta" +[project] +name = "quickmpc" +license = {file = "LICENSE"} +authors = [ {name = "Acompany"} ] +requires-python = ">=3.7" +dependencies = [ + "numpy", + "grpcio-tools", + "grpcio-status", + "pynacl", + "tqdm", + "natsort", + "pandas", +] +dynamic = ["version"] + +[project.optional-dependencies] +document = [ + "sphinx", + "sphinx-autobuild", + "sphinx-rtd-theme", +] + +[tool.setuptools] +package-dir = {"quickmpc" = "quickmpc"} + [tool.setuptools_scm] write_to = "packages/client/libclient-py/quickmpc/_version.py" -root = "../../../" \ No newline at end of file +root = "../../../" diff --git a/packages/client/libclient-py/quickmpc/README.md b/packages/client/libclient-py/quickmpc/README.md deleted file mode 100644 index 182ce428d..000000000 --- a/packages/client/libclient-py/quickmpc/README.md +++ /dev/null @@ -1,155 +0,0 @@ -# QuickMPC-libClient-pyが提供する機能 - -## QMPC.send_share_from_csv_file -csvファイルからテーブルデータを読み込みShare化して送信する -### Parameters -- file: `str` - - 読み込むファイル名 - -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["data_id"]: `str` - - テーブルデータのID - -## QMPC.send_share_from_csv_data -テーブルデータをShare化して送信する -### Parameters -- data: `List[List[str]]` - - パースするテーブルデータ - -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["data_id"]: `str` - - テーブルデータのID - -## QMPC.delete_share -エンジンに保存されたテーブルデータの削除する -### Parameters -- data_ids: `List[str]` - - 削除するdata_idのリスト -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - -## QMPC.mean -平均値を計算する -### Parameters -- data_ids: `List[str]` - - data_idのリスト -- src: `List[int]` - - 平均値を計算する列リスト -- debug_mode: `bool` - - keyword引数.`True`の場合はdebug用の違法高速マッチングを行う -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["job_uuid"]: `str` - - 計算スレッドのUUID - -## QMPC.variance -分散を計算する -### Parameters -- data_ids: `List[str]` - - data_idのリスト -- src: `List[int]` - - 分散を計算する列リスト -- debug_mode: `bool` - - keyword引数.`True`の場合はdebug用の違法高速マッチングを行う -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["job_uuid"]: `str` - - 計算スレッドのUUID - -## QMPC.sum -総和を計算する -### Parameters -- data_ids: `List[str]` - - data_idのリスト -- src: `List[int]` - - 総和を計算する列リスト -- debug_mode: `bool` - - keyword引数.`True`の場合はdebug用の違法高速マッチングを行う -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["job_uuid"]: `str` - - 計算スレッドのUUID - -## QMPC.correl -相関係数を計算する -### Parameters -- data_ids: `List[str]` - - data_idのリスト -- inp: `Tuple[List[int], List[int]]` - - inp[0]: 相関係数の左列リスト - - inp[1]: 相関係数の右列リスト -- debug_mode: `bool` - - keyword引数.`True`の場合はdebug用の違法高速マッチングを行う -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["job_uuid"]: `str` - - 計算スレッドのUUID - -## QMPC.meshcode -メッシュコードを計算する -### Parameters -- data_ids: `List[str]` - - data_idのリスト -- src: `List[int]` - - メッシュコードを計算する列リスト -- debug_mode: `bool` - - keyword引数.`True`の場合はdebug用の違法高速マッチングを行う -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["job_uuid"]: `str` - - 計算スレッドのUUID - -## QMPC.get_join_table -テーブルを結合する -### Parameters -- data_ids: `List[str]` - - data_idのリスト -- debug_mode: `bool` - - keyword引数.`True`の場合はdebug用の違法高速マッチングを行う -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["job_uuid"]: `str` - - 計算スレッドのUUID - -## QMPC.get_computation_result -計算結果を取得する -### Parameters -- job_uuid: `str` - - 計算スレッドのUUID -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["results"]: `List[float]` - - 計算結果 - -## QMPC.get_data_list -[deplecated] -エンジンに保存されているテーブルデータのIDを全て取得する -### Parameters -### Returns -- res: `Dict` - - res["is_ok"]: `bool` - - 送信が成功したかどうか - - res["results"]: `List[str]` - - テーブルデータのIDのリスト diff --git a/packages/client/libclient-py/quickmpc/__init__.py b/packages/client/libclient-py/quickmpc/__init__.py index 1f1d8d52f..993370100 100644 --- a/packages/client/libclient-py/quickmpc/__init__.py +++ b/packages/client/libclient-py/quickmpc/__init__.py @@ -9,6 +9,5 @@ "QMPC", "QMPCJobError", "QMPCServerError", - "ShareDataFrame", "get_logger", ] diff --git a/packages/client/libclient-py/quickmpc/exception.py b/packages/client/libclient-py/quickmpc/exception.py index e6dfa358f..dd3d2cdd8 100644 --- a/packages/client/libclient-py/quickmpc/exception.py +++ b/packages/client/libclient-py/quickmpc/exception.py @@ -6,13 +6,29 @@ @dataclass(frozen=True) class QMPCJobError(Exception): + """MPCの計算中にエラーが発生した場合のエラー + + Attributes + ---------- + err_info: Union[str, JobErrorInfo] + エラーの詳細 + """ err_info: Union[str, JobErrorInfo] @dataclass(frozen=True) class QMPCServerError(Exception): + """QMPCサーバで何らかの異常が発生した場合エラー + + Attributes + ---------- + err_info: Union[str, JobErrorInfo] + エラーの詳細 + """ err_info: Union[str, JobErrorInfo] class ArgumentError(Exception): + """想定していない引数が与えられた時のエラー""" + # TODO: NotImplementedErrorに置き換えられるので削除する pass diff --git a/packages/client/libclient-py/quickmpc/pandas/parser.py b/packages/client/libclient-py/quickmpc/pandas/parser.py index dad49fd6e..a54a4dc52 100644 --- a/packages/client/libclient-py/quickmpc/pandas/parser.py +++ b/packages/client/libclient-py/quickmpc/pandas/parser.py @@ -23,34 +23,90 @@ @dataclass(frozen=True) class FormatChecker: + """Formatが正しいかをチェックする + """ @methoddispatch(is_static_method=True) @staticmethod - def check_duplicate(_): + def check_duplicate(*args, **kw): raise ArgumentError("不正な引数が与えられています.") - @check_duplicate.register(Dim1) @staticmethod - def check_duplicate_dummy(schema: List[str]): + @check_duplicate.register(Dim1) + def check_duplicate_dummy(*args, **kw): raise ArgumentError("不正な引数が与えられています.") @check_duplicate.register((Dim1, str)) @staticmethod - def check_duplicate_strs(schema: List[str]) -> bool: - return len(schema) == len(set(schema)) + def check_duplicate_strs(lst: List[str]) -> bool: + """文字列配列に重複要素があるかチェックする + + Parameters + ---------- + lst: List[str] + チェックする配列 + + Returns + ------- + bool + 違反してないかどうか + """ + return len(lst) == len(set(lst)) @check_duplicate.register((Dim1, Schema)) @staticmethod def check_duplicate_typed(schema: List[Schema]) -> bool: + """スキーマ配列に重複要素があるかチェックする + + スキーマ種別による区別はなく,スキーマの名前だけで重複チェックを行う + + Parameters + ---------- + schema: List[str] + チェックする配列 + + Returns + ------- + bool + 違反してないかどうか + """ return len(schema) == len(set([sch.name for sch in schema])) @staticmethod def check_size(secrets: Sequence[Sequence[Union[str, ShareValueType]]], schema: Sequence[Union[str, Schema]]) -> np.bool_: + """Schemaサイズとテーブルの列サイズが等しいかをチェックする + + Parameters + ---------- + secrets: Sequence[Sequence[Union[str, ShareValueType]]] + チェックするテーブルデータ + schema: Sequence[Union[str, Schema]]) + チェックするスキーマ + + Returns + ------- + np.bool_ + 違反してないかどうか + """ return np.all([len(s) == len(schema) for s in secrets]) def format_check(secrets: List[List[ShareValueType]], schema: Sequence[Union[str, Schema]]) -> bool: + """Schemaとテーブルデータが要件を満たすかチェックする + + Parameters + ---------- + secrets: List[List[ShareValueType]], + チェックするテーブルデータ + schema: Sequence[Union[str, Schema]]) + チェックするスキーマ + + Returns + ------- + bool + 違反してないかどうか + """ # 存在チェック if not (schema and secrets): logger.error("Schema or secrets table are not exists.") @@ -67,7 +123,32 @@ def format_check(secrets: List[List[ShareValueType]], def to_float(val: str) -> float: - """ If val is a float, convert as is; if it is a string, hash it. """ + """stringをfloatに変換する + + stringが数値に変換可能な場合は数値に変換する. + 変換できない場合は512bitのhashに変換した上で上位k bitを切り取り, + 整数部(k-m) bit,小数部m bitの値をfloatに変換して返す. + ただし,k,mはMPCサーバ内部のLTZ(Less Than Zero)プロトコルで使用されるハイパーパラメータである. + + Parameters + ---------- + val: str + 変換する文字列 + + Returns + ------- + float + 変換された実数 + + Examples + --------- + >>> to_float("1") + 1.0 + >>> to_float("1.01") + 1.01 + >>> to_float("string") + 41254067.792910576 + """ try: return float(val) except ValueError: @@ -83,11 +164,49 @@ def to_float(val: str) -> float: def to_int(val: str, encoding='utf-8') -> int: + """stringをintに変換する + + 文字列を `encoding` の形式で数値に変換する. + 元々数値に変換できる文字列だったとしても全て `encoding` の形式に変換する. + + Parameters + ---------- + val: str + 変換する文字列 + encoding: str, default="utf-8" + 変換形式 + + Returns + ------- + int + 変換された実数 + + Examples + --------- + >>> to_int("1") + 49 + >>> to_int("1.01") + 825110577 + >>> to_int("string") + 126943972912743 + """ encoded = val.encode(encoding) return int.from_bytes(encoded, byteorder='big') def check_float_data(val: str) -> bool: + """文字列がfloatに変換可能であるかチェックする + + Parameters + ---------- + val: str + チェックする文字列 + + Returns + ------- + bool + 変換可能かどうか + """ try: _ = float(val) return True @@ -98,6 +217,36 @@ def check_float_data(val: str) -> bool: def find_type(col_schema: str, col_data: List[str], is_matching_column: bool) \ -> ShareValueTypeEnum.ValueType: + """スキーマにタグを検索する + + スキーマ名に `:` という形式のタグがついている場合はこのタグを返す. + ついていない場合はテーブルデータの文字列から推測する. + + Parameters + ---------- + col_schema: str, + スキーマの名前 + col_data: List[str], + 列の全てのデータ + is_matching_column: bool) + その列がID列かどうか + + Returns + ------- + ShareValueTypeEnum.ValueType + スキーマのタグ + + Examples + -------- + >>> find_type("s:id", ["a"], True) + ShareValueTypeEnum.SHARE_VALUE_TYPE_FIXED_POINT + >>> find_type("s", ["a"], True) + ShareValueTypeEnum.SHARE_VALUE_TYPE_FIXED_POINT + >>> find_type("s", ["a"], False) + ShareValueTypeEnum.SHARE_VALUE_TYPE_UTF_8_INTEGER_REPRESENTATION + >>> find_type("s", ["100"], False) + ShareValueTypeEnum.SHARE_VALUE_TYPE_FIXED_POINT + """ # check tag type_str, *remains = col_schema.split(':') @@ -124,6 +273,22 @@ def find_types(schema: List[str], data: List[List[str]], matching_column: Optional[int] = None ) -> List[ShareValueTypeEnum.ValueType]: + """各スキーマのタグを検索する + + Parameters + ---------- + schema: List[str] + テーブルデータのスキーマ + data: List[List[str]] + テーブルデータ + matching_column: Optional[int], default=None + ID列のindex(1-index) + + Returns + ------- + List[ShareValueTypeEnum.ValueType] + 各スキーマのタグ + """ # transpose to get column oriented list transposed: Iterable[List[str]] = map(list, zip(*data)) return [find_type(sch, col, idx == matching_column) @@ -132,6 +297,20 @@ def find_types(schema: List[str], def convert(element: str, type_info: ShareValueTypeEnum.ValueType) -> ShareValueType: + """スキーマのタグに沿って文字列を数値に変換する + + Parameters + ---------- + element: str + 変換する文字列 + type_info: ShareValueTypeEnum.ValueType + スキーマタグ + + Returns + ------- + ShareValueType + 変換後の値 + """ if type_info == ShareValueTypeEnum.Value('SHARE_VALUE_TYPE_FIXED_POINT'): return to_float(element) if type_info == ShareValueTypeEnum.Value( @@ -142,6 +321,20 @@ def convert(element: str, def parse(data: List[List[str]], matching_column: Optional[int] = None) \ -> Tuple[List[List[ShareValueType]], List[Schema]]: + """テーブルデータをMPCで使用できる形式にparseする + + Parameters + ---------- + data: List[List[str]] + テーブルデータ + matching_column: Optional[int], default=None + ID列のindex(1-index) + + Returns + ------- + Tuple[List[List[ShareValueType]], List[Schema]]: + parseしたスキーマとテーブルデータ + """ schema_name: List[str] = data[0] types = find_types(schema_name, data[1:], matching_column) schema = [Schema(name=name, type=type) @@ -165,6 +358,20 @@ def parse(data: List[List[str]], matching_column: Optional[int] = None) \ def parse_csv( filename: str, matching_column: Optional[int] = None) \ -> Tuple[List[List[ShareValueType]], List[Schema]]: + """csvのテーブルデータをMPCで使用できる形式にparseする + + Parameters + ---------- + filename: str + 入力ファイルのpath + matching_column: Optional[int], default=None + ID列のindex(1-index) + + Returns + ------- + Tuple[List[List[ShareValueType]], List[Schema]]: + parseしたスキーマとテーブルデータ + """ with open(filename) as f: reader = csv.reader(f) text: List[List[str]] = [row for row in reader] diff --git a/packages/client/libclient-py/quickmpc/pandas/progress.py b/packages/client/libclient-py/quickmpc/pandas/progress.py index 293ef7454..2a66c433d 100644 --- a/packages/client/libclient-py/quickmpc/pandas/progress.py +++ b/packages/client/libclient-py/quickmpc/pandas/progress.py @@ -10,12 +10,54 @@ @dataclass class Progress: + """計算の進捗を管理するクラス + + Attributes + ---------- + pbars: OrderedDict[Tuple[int, int], + Tuple[Optional[tqdm.tqdm], float]] + 計算Statusを保持する辞書 + + Examples + -------- + .. code-block:: python3 + + progress = Progress() + while runinng: + status = get_status() + progress.update(status) + + output + + .. code-block:: console + + [0] status: 100%|███████████████| 6/6 [00:01<00:00, 5.88it/s, status=COMPLETED] + [1] status: 100%|███████████████| 6/6 [00:01<00:00, 5.88it/s, status=COMPLETED] + [2] status: 100%|███████████████| 6/6 [00:01<00:00, 5.88it/s, status=COMPLETED] + [0] hjoin: core: 100%|█████████| 100.0/100 [00:01<00:00, 97.96it/s, details=1/4] + [0] hjoin: binary search: 100%|█| 100.0/100 [00:01<00:00, 97.97it/s, details=0/2 + [1] hjoin: core: 100%|█████████| 100.0/100 [00:01<00:00, 97.98it/s, details=1/4] + [1] hjoin: binary search: 100%|█| 100.0/100 [00:01<00:00, 97.99it/s, details=0/2 + [2] hjoin: core: 100%|█████████| 100.0/100 [00:01<00:00, 98.00it/s, details=1/4] + [2] hjoin: binary search: 100%|█| 100.0/100 [00:01<00:00, 98.01it/s, details=0/2 + """ # noqa: E501 pbars: OrderedDict[Tuple[int, int], Tuple[Optional[tqdm.tqdm], float]] \ = field(default_factory=collections.OrderedDict) # TODO: statusとprogreassを受け取るようにする(疎にするため) def update(self, res: GetComputationStatusResponse): + """進捗Statusを更新する + + Parameters + ---------- + res: GetComputationStatusResponse + 計算Status + + Returns + ------- + None + """ if res.job_statuses is not None: for party_id, status in enumerate(res.job_statuses): key = (party_id, -1) diff --git a/packages/client/libclient-py/quickmpc/pandas/readers.py b/packages/client/libclient-py/quickmpc/pandas/readers.py index 07f408081..ce197dd65 100644 --- a/packages/client/libclient-py/quickmpc/pandas/readers.py +++ b/packages/client/libclient-py/quickmpc/pandas/readers.py @@ -1,26 +1,32 @@ import pandas as pd +from pandas.core.shared_docs import _shared_docs +from pandas.io.parsers.readers import _doc_read_csv_and_table +from pandas.util._decorators import Appender from quickmpc.pandas.parser import to_float +# pandas.read_csvからdocumentを持ってくる +@Appender( + _doc_read_csv_and_table.format( + func_name="read_csv", + summary="Read a comma-separated values (csv) file into DataFrame.", + _default_sep="','", + storage_options=_shared_docs["storage_options"], + decompression_options="", # TODO: 取得できなかったので原因を特定する + # decompression_options=_shared_docs["decompression_options"] + # % "filepath_or_buffer", + ) +) def read_csv(*args, index_col: str, **kwargs) -> pd.DataFrame: - """csvからテーブルデータを読み込む. + """csvからテーブルデータを読み込む - テーブル結合処理に用いる列がどの列かを`index_col`で指定する必要がある. - `index_col`以外の引数は全てpandasのread_csvと同じ. + テーブル結合処理に用いる列がどの列かを `index_col` で列名を指定する必要がある. + `index_col` 以外の引数は全てpandasのread_csvと同じ. + 以下のdocumentは `pandasのdocument `_ より. - Parameters - ---------- - filepath_or_buffer: FilePath | ReadCsvBuffer[bytes] | ReadCsvBuffer[str], - らしい - index_col: str - ID列としたいカラム名 - - Returns - ---------- - pd.DataFrame - 読み込んだテーブルデータ - """ + .. docにはpandasからコピーしたdocumentが表示される + """ # noqa: E501 df = pd.read_csv(*args, **kwargs) # ID列を数値化 df[index_col] = df[index_col].map(lambda x: to_float(x)) diff --git a/packages/client/libclient-py/quickmpc/pandas/share_data_frame.py b/packages/client/libclient-py/quickmpc/pandas/share_data_frame.py index fab6230dc..5de4061a9 100644 --- a/packages/client/libclient-py/quickmpc/pandas/share_data_frame.py +++ b/packages/client/libclient-py/quickmpc/pandas/share_data_frame.py @@ -18,6 +18,8 @@ class ShareDataFrameStatus(Enum): + """Jobの計算Status + """ OK = 1 EXECUTE = 2 ERROR = 3 @@ -49,16 +51,16 @@ def wrapper(self: "ShareDataFrame", *args, class ShareDataFrame: """テーブルデータを管理するクラス - Attributes - ---------- - __id: str + Args + ---- + _ShareDataFrame__id: str データのID - __qmpc_request: quickmpc.request.QMPCClientInterface + _ShareDataFrame__qmpc_request: quickmpc.request.QMPCClientInterface QuickMPCとの通信を担うClient - __is_result: bool + _ShareDataFrame__is_result: bool send由来のDataFrame(False)なのかexecute由来のDataFrameなのか(True) - __status: ShareDataFrameStatus - 現在の状態 + _ShareDataFrame__status: ShareDataFrameStatus + 現在のデータの状態 """ __id: str @@ -67,6 +69,27 @@ class ShareDataFrame: __status: ShareDataFrameStatus = ShareDataFrameStatus.OK def _wait_execute(self, progress: bool) -> None: + """計算が終了するまで待機する + + 管理しているIDがjob_uuid(execute由来のID)である場合,計算が終了するまで待機する. + send_shareしたdata_idを管理している場合は待機する必要がないためすぐにreturnする. + + 待機は計算終了まで永遠に行われ,1秒おきにMPCサーバに現在Statusの確認リクエストが送信される. + + Parameters + ---------- + progress: bool + 計算Statusをログに出力するかどうか + + Returns + ---------- + None + + Raises + ------ + QMPCJobError + 計算中に何らかのエラーが発生している時 + """ if self.__status == ShareDataFrameStatus.ERROR: raise QMPCJobError("ShareDataFrame's status is `ERROR`") if self.__status == ShareDataFrameStatus.EXECUTE: @@ -92,19 +115,19 @@ def _wait_execute(self, progress: bool) -> None: time.sleep(1) def __add__(self, other: "ShareDataFrame") -> "ShareDataFrame": - """テーブルを加算する. + """テーブルデータを加算する. - qmpc.send_toで送ったデータでかつ,行数,列数が一致している場合のみ正常に動作する. + :any:`quickmpc.send_to` で送ったデータでかつ,行数,列数が一致している場合のみ正常に動作する. Parameters ---------- other: ShareDataFrame - 結合したいDataFrame + 加算したいDataFrame Returns ---------- - Result - 加算して得られたDataFrameのResult + ShareDataFrame + 加算して得られた :class:`ShareDataFrame` """ res = self.__qmpc_request.add_share_data_frame(self.__id, other.__id) return ShareDataFrame(res.data_id, self.__qmpc_request) @@ -114,17 +137,19 @@ def join(self, other: "ShareDataFrame", *, debug_mode=False) \ -> "ShareDataFrame": """テーブルデータを結合する. - inner_joinのみ. + 内部ではinner_joinを行う. Parameters ---------- other: ShareDataFrame 結合したいDataFrame + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか Returns ---------- - Result - 結合したDataFrameのResult + result: ShareDataFrame + 結合して得られた :class:`ShareDataFrame` """ return self.join([other], debug_mode=debug_mode) @@ -133,45 +158,121 @@ def join_list(self, others: List["ShareDataFrame"], *, debug_mode=False)\ -> "ShareDataFrame": """テーブルデータを結合する. - inner_joinのみ. + 内部ではinner_joinを行う. Parameters ---------- others: List[ShareDataFrame] 結合したいDataFrameのリスト + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか Returns ---------- - Result - 結合したDataFrameのResult + ShareDataFrame + 結合して得られた :class:`ShareDataFrame` """ res = self.__qmpc_request.join([self.__id] + [o.__id for o in others], debug_mode=debug_mode) return ShareDataFrame(res.job_uuid, self.__qmpc_request, True, ShareDataFrameStatus.EXECUTE) - def sum(self, columns: list) -> "ShareDataFrame": + def sum(self, columns: List[int]) -> "ShareDataFrame": + """列の総和を取得する + + Parameters + ---------- + columns: List[int] + 計算に用いる列番号(1-index) + + Returns + ---------- + ShareDataFrame + 結果のDataFrame + """ res = self.__qmpc_request.sum([self.__id], columns) return ShareDataFrame(res.job_uuid, self.__qmpc_request, True, ShareDataFrameStatus.EXECUTE) - def mean(self, columns: list) -> "ShareDataFrame": + def mean(self, columns: List[int]) -> "ShareDataFrame": + """指定した列の平均を取得する + + 結果として得られる行列の `i` 行目には + 入力テーブルの `columns[i]` 列目の平均が入る. + + Parameters + ---------- + columns: List[int] + 計算に用いる列番号(1-index) + + Returns + ---------- + ShareDataFrame + `len(columns)` 行 1列行列の :class:`ShareDataFrame` + """ res = self.__qmpc_request.mean([self.__id], columns) return ShareDataFrame(res.job_uuid, self.__qmpc_request, True, ShareDataFrameStatus.EXECUTE) - def variance(self, columns: list) -> "ShareDataFrame": + def variance(self, columns: List[int]) -> "ShareDataFrame": + """指定した列の分散を取得する + + 結果として得られる行列の `i` 行目には + 入力テーブルの `columns[i]` 列目の分散が入る. + + Parameters + ---------- + columns: List[int] + 計算に用いる列番号(1-index) + + Returns + ---------- + ShareDataFrame + `len(columns)` 行 1列行列の :class:`ShareDataFrame` + """ res = self.__qmpc_request.variance([self.__id], columns) return ShareDataFrame(res.job_uuid, self.__qmpc_request, True, ShareDataFrameStatus.EXECUTE) - def correl(self, columns1: list, columns2: list) -> "ShareDataFrame": + def correl(self, columns1: List[int], columns2: List[int]) \ + -> "ShareDataFrame": + """指定した列同士の相関係数を取得する + + 結果として得られる行列の `i` 行 `j` 列目には + 入力テーブルの `columns1[i]` 列目と `columns2[j]` 列目の相関係数が入る. + + Parameters + ---------- + columns1: List[int] + 計算に用いる左項の列番号(1-index) + columns2: List[int] + 計算に用いる右項の列番号(1-index) + + Returns + ---------- + ShareDataFrame + `len(columns1)` 行 `len(columns2)` 列行列の :class:`ShareDataFrame` + """ res = self.__qmpc_request.correl([self.__id], columns1, columns2) return ShareDataFrame(res.job_uuid, self.__qmpc_request, True, ShareDataFrameStatus.EXECUTE) - def meshcode(self, columns: list) \ - -> "ShareDataFrame": + def meshcode(self, columns: List[int]) -> "ShareDataFrame": + """指定した列のmeshcodeを取得する + + 結果として得られる行列の `i` 行目には + 入力テーブルの `columns[i]` 列目のmeshcodeが入る. + + Parameters + ---------- + columns: List[int] + 計算に用いる列番号(1-index) + + Returns + ---------- + ShareDataFrame + `len(columns)` 行 1列行列の :class:`ShareDataFrame` + """ res = self.__qmpc_request.meshcode([self.__id], columns) return ShareDataFrame(res.job_uuid, self.__qmpc_request, True, ShareDataFrameStatus.EXECUTE) @@ -189,6 +290,12 @@ def to_csv(self, output_path: str) -> None: Returns ---------- + None + + Raises + ------ + RuntimeError + 送信したデータをそのまま保存しようとした場合 """ # 計算結果でないなら保存されないようにする if not self.__is_result: @@ -208,6 +315,11 @@ def to_data_frame(self) -> pd.DataFrame: ---------- pd.DataFrame 計算結果 + + Raises + ------ + RuntimeError + 送信したデータをそのまま取得しようとした場合 """ # 計算結果でないなら取得できないようにする if not self.__is_result: @@ -244,8 +356,13 @@ def get_elapsed_time(self) -> List[float]: Returns ---------- - float + List[float] 計算時間 + + Raises + ------ + RuntimeError + 計算していないデータを指定した場合 """ # 計算していない場合は計算時間が存在しない if not self.__is_result: diff --git a/packages/client/libclient-py/quickmpc/qmpc.py b/packages/client/libclient-py/quickmpc/qmpc.py index 8a581a9b1..061160345 100644 --- a/packages/client/libclient-py/quickmpc/qmpc.py +++ b/packages/client/libclient-py/quickmpc/qmpc.py @@ -17,11 +17,22 @@ class QMPC: Attributes ---------- - arg: Union[List[str], QMPCRequestInterface]] - parties: List[str] - serverのIP - qmpc_request: QMPCRequestMock - qmpc serverに対してrequestを送るinterface + __qmpc_request: QMPCRequestInterface + qmpc serverに対してrequestを送るinterface + + Examples + -------- + .. code-block:: python3 + + # 直接IPを指定する場合 + qmpc = QMPC([ + "http://localhost:50001", + "http://localhost:50002", + "http://localhost:50003" + ]) + # mockなどでrequestクラスを指定する場合 + request = QMPCRequest([~~~]) + qmpc = QMPC(request) """ arg: InitVar[Union[List[str], QMPCRequestInterface]] @@ -43,27 +54,47 @@ def __post_init__original(self, qmpc_request: QMPCRequest): object.__setattr__(self, "_QMPC__qmpc_request", qmpc_request) def send_to(self, df: pd.DataFrame) -> qpd.ShareDataFrame: - """QuickMPCサーバにデータを送信する. + """QuickMPCサーバにデータを送信する - `quickmpc.pandas.read_csv()`により読み込んだデータをQuickMPCサーバに送信する. - `quickmpc.pandas.read_csv()`では``__qmpc_sort_index__``と呼ばれるデータの順序を保持した列がcolumnsに追加されており,send_shareではこの列があることを要求している. - `quickmpc.pandas.read_csv()`を経由せずにsend_shareする場合は,あらかじめデータ順序を求めて``__qmpc_sort_index__``列を追加しておく必要がある. + :doc:`quickmpc.pandas.read_csv` により読み込んだデータをQuickMPCサーバに送信する. + :doc:`quickmpc.pandas.read_csv` では ``__qmpc_sort_index__`` と呼ばれるデータの順序を保持した列がcolumnsに追加されており, + このメソッドではこの列があることを要求している. + :doc:`quickmpc.pandas.read_csv` を経由せずにQuickMPCサーバにデータを送信する場合は, + あらかじめデータ順序を求めて ``__qmpc_sort_index__`` 列を追加しておく必要がある. Parameters ---------- - df: df.DataFrame + df: pandas.DataFrame 送信するデータ Returns ---------- quickmpc.pandas.ShareDataFrame QuickMPC形式のDataframe - """ + + Examples + -------- + .. code-block:: python3 + + # quickmpc.pandas.read_csv で読み込んだデータを送信する + df = qpd.read_csv("filepath") + qmpc.send_to(df) + + # 自分で定義したデータを送信する + qmpc.send_to(df) + df = pd.DataFrame([[1,2,0], [3,4,1]], + columns=["a", "b", "__qmpc_sort_index__"])), + + See Also + -------- + quickmpc.pandas.read_csv + quickmpc形式に則る専用のcsv読み取り関数 + """ # noqa #501 res = self.__qmpc_request.send_share(df, piece_size=1_000_000) return qpd.ShareDataFrame(res.data_id, self.__qmpc_request) def load_from(self, id_str: str) -> qpd.ShareDataFrame: - """既に送信してあるデータを参照する. + """QuickMPCサーバに送信してあるデータを参照する Parameters ---------- @@ -71,9 +102,19 @@ def load_from(self, id_str: str) -> qpd.ShareDataFrame: 既に送信してあるデータのIDか計算で発行したID Returns - ---------- + ------- quickmpc.pandas.ShareDataFrame QuickMPC形式のDataframe + + Raises + ------ + RuntimeError + 形式が異なるIDを指定した場合 + + See Also + -------- + quickmpc.pandas.ShareDataFrame.get_id + 参照するIDを取得するためのメソッド """ # data_id (256bit hash) if re.fullmatch(r'[a-z0-9]{64}', id_str): @@ -88,7 +129,7 @@ def load_from(self, id_str: str) -> qpd.ShareDataFrame: # TODO: job_uuidとparty_sizeは指定しなくても良いようにしたい def restore(self, job_uuid: str, filepath: str, party_size: int) \ -> qpd.ShareDataFrame: - """既に送信してあるデータを参照する. + """ファイルに保存したMPC結果を復元して取得する. Parameters ---------- @@ -96,11 +137,33 @@ def restore(self, job_uuid: str, filepath: str, party_size: int) \ データのID filepath: str dataが保存してあるディレクトリ + party_size: int + MPCのパーティ数 Returns - ---------- + ------- quickmpc.pandas.ShareDataFrame QuickMPC形式のDataframe + + Examples + -------- + .. code-block:: python3 + + # 計算結果をファイルに保存する + filepath: str # 保存するディレクトリ + sdf: qpd.ShareDataFrame # 計算結果 + sdf.to_csv(filepath) # 結果を保存 + + # 計算結果をファイルから復元する + job_uuid = sdf.get_id() + sdf = qmpc.restore(job_uuid, filepath, 3) + + See Also + -------- + quickmpc.pandas.ShareDataFrame.get_id + 参照するIDを取得するためのメソッド + quickmpc.pandas.ShareDataFrame.to_csv + 計算結果をファイルに保存するメソッド.restoreはこのメソッドで保存したファイル群に対して適用できる. """ # TODO: get_computation_resultと同じ処理なのでうまくまとめる res = restore(job_uuid, filepath, party_size) diff --git a/packages/client/libclient-py/quickmpc/request/qmpc_request.py b/packages/client/libclient-py/quickmpc/request/qmpc_request.py index 88112b8db..14a5587b7 100644 --- a/packages/client/libclient-py/quickmpc/request/qmpc_request.py +++ b/packages/client/libclient-py/quickmpc/request/qmpc_request.py @@ -44,6 +44,27 @@ def _create_grpc_channel(endpoint: str) -> grpc.Channel: + """gppcのchannnelを生成する + + 入力形式をパースして自動で適切なチャンネルを生成する. + 現在サポートしているのはhttpとhttpsの2つで, + いずれも `http://~~~` という文字列である必要がある. + + Parameters + ---------- + endpoint: str + endpointのIP + + Returns + -------- + grpc.channel + grpcのchannel + + Raises + ------- + ArgumentError + サポートしていないプロトコルを指定した場合 + """ channel: grpc.Channel = None o = urlparse(endpoint) if o.scheme == 'http': @@ -66,12 +87,29 @@ def _create_grpc_channel(endpoint: str) -> grpc.Channel: class QMPCRequest(QMPCRequestInterface): """QuickMPCサーバと通信を行う + Args + ---- + endpoints: List[str] + 各PartyのIP + _QMPCRequest__retry_num: int + 通信のretry回数 + _QMPCRequest__retry_wait_time : int + 通信のretry待機時間 + Attributes ---------- - __endpoints: List[url] - QuickMPCサーバのURL + __client_stubs: Tuple[LibcToManageStub] + 各Partyのgrpcのstub + __client_channels: Tuple[grpc.Channel] + 各Partyのgrpcのchannnel + __party_size: int + Partyの数 __token: str - QuickMPCサーバへの通信を担う + MPCサーバへの通信のtoken + __retry_num: int + 通信のretry回数 + __retry_wait_time: int + 通信のretry待機時間 """ endpoints: InitVar[List[str]] @@ -95,6 +133,27 @@ def __post_init__(self, endpoints: List[str]) -> None: object.__setattr__(self, "_QMPCRequest__token", token) def __retry(self, f: Callable, *request: Any) -> Any: + """リトライポリシーに従って特定の関数を実行する + + Parameters + ---------- + f: Callable + 実行するgrpc request + request: Any + fに渡す全ての引数 + + Returns + ------- + Any + f(*request)の実行結果 + + Raises + ------ + QMPCServerError + MPCサーバへの接続が確立できないなどサーバのエラーが起きた場合 + QMPCJobError + 計算中に何らかのエラーが発生していた場合 + """ for ch in self.__client_channels: # channelの接続チェック is_channel_ready = False @@ -143,7 +202,25 @@ def __retry(self, f: Callable, *request: Any) -> Any: @staticmethod def __futures_result( futures: Iterable, enable_progress_bar=True) -> List: - """ エラーチェックしてfutureのresultを得る """ + """進捗ログを出しながらfutureの結果を取得する + + Parameters + ---------- + futures: concurrent.futures.Executor + 非同期実行している実体 + enable_progress_bar: bool, default=True + 進捗ログを出すかどうか + + Returns + ------- + List + 各futuresの返り値のリスト + + Raises + ------ + Exception + futures実行中の例外 + """ try: if enable_progress_bar: futures = tqdm.tqdm(futures, desc='receive') @@ -154,7 +231,25 @@ def __futures_result( def send_share(self, df: pd.DataFrame, piece_size: int = 1_000_000) \ -> SendShareResponse: - """ Shareをコンテナに送信 """ + """ShareをMPCサーバに送信する + + Parameters + ---------- + df: pd.DataFrame + 送信するデータ + piece_size: int, default=1_000_000 + 分割pieceの各size + + Returns + ------- + SendShareResponse + SendSharesRequestの結果 + + Raises + ------ + RuntimeError + 規定されていない引数を与えた場合 + """ if piece_size < 1000 or piece_size > 1_000_000: raise RuntimeError( "piece_size must be in the range of 1000 to 1000000") @@ -203,10 +298,30 @@ def send_share(self, df: pd.DataFrame, piece_size: int = 1_000_000) \ def __execute_computation(self, method_id: ComputationMethod.ValueType, data_ids: List[str], - columns: Tuple[List, List], + columns: Tuple[List[int], List[int]], *, debug_mode: bool = False) \ -> ExecuteResponse: - """ 計算リクエストを送信 """ + """計算リクエストをMPCサーバに送信する + + MPCサーバ上でJobと呼ばれる計算リクエストを発行する. + `data_ids` で指定したIDはサーバ内で全てinner_joinされる. + + Parameters + ---------- + method_id: ComputationMethod.ValueType + 計算の種類を管理するID + data_ids: List[str] + 計算に用いるデータのID + columns: Tuple[List[int], List[int]] + 計算に用いるデータの列番号(1-index) + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか + + Returns + ------- + ExecuteResponse + 計算リクエストの結果 + """ req = ExecuteComputationRequest( method_id=method_id, token=self.__token, @@ -229,12 +344,44 @@ def __execute_computation(self, method_id: ComputationMethod.ValueType, def sum(self, data_ids: List[str], columns: List[int], *, debug_mode: bool = False) -> ExecuteResponse: + """総和計算リクエストをMPCサーバに送信する + + Parameters + ---------- + data_ids: List[str] + 計算に用いるデータのID + columns: List[int] + 計算に用いるデータの列番号(1-index) + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか + + Returns + ------- + ExecuteResponse + 計算リクエストの結果 + """ return self.__execute_computation( ComputationMethod.COMPUTATION_METHOD_SUM, data_ids, (columns, []), debug_mode=debug_mode) def mean(self, data_ids: List[str], columns: List[int], *, debug_mode: bool = False) -> ExecuteResponse: + """平均計算リクエストをMPCサーバに送信する + + Parameters + ---------- + data_ids: List[str] + 計算に用いるデータのID + columns: List[int] + 計算に用いるデータの列番号(1-index) + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか + + Returns + ------- + ExecuteResponse + 計算リクエストの結果 + """ return self.__execute_computation( ComputationMethod.COMPUTATION_METHOD_MEAN, data_ids, (columns, []), debug_mode=debug_mode) @@ -242,6 +389,22 @@ def mean(self, data_ids: List[str], columns: List[int], def variance(self, data_ids: List[str], columns: List[int], *, debug_mode: bool = False) \ -> ExecuteResponse: + """分散計算リクエストをMPCサーバに送信する + + Parameters + ---------- + data_ids: List[str] + 計算に用いるデータのID + columns: List[int] + 計算に用いるデータの列番号(1-index) + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか + + Returns + ------- + ExecuteResponse + 計算リクエストの結果 + """ return self.__execute_computation( ComputationMethod.COMPUTATION_METHOD_VARIANCE, data_ids, (columns, []), debug_mode=debug_mode) @@ -249,6 +412,24 @@ def variance(self, data_ids: List[str], columns: List[int], def correl(self, data_ids: List[str], inp1: List[int], inp2: List[int], *, debug_mode: bool = False) \ -> ExecuteResponse: + """相関係数計算リクエストをMPCサーバに送信する + + Parameters + ---------- + data_ids: List[str] + 計算に用いるデータのID + inp1: List[int] + 相関係数の左項に該当する列番号(1-index) + inp2: List[int] + 相関係数の右項に該当する列番号(1-index) + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか + + Returns + ------- + ExecuteResponse + 計算リクエストの結果 + """ return self.__execute_computation( ComputationMethod.COMPUTATION_METHOD_CORREL, data_ids, (inp1, inp2), debug_mode=debug_mode) @@ -256,12 +437,42 @@ def correl(self, data_ids: List[str], inp1: List[int], inp2: List[int], def meshcode(self, data_ids: List[str], inp1: List[int], *, debug_mode: bool = False) \ -> ExecuteResponse: + """meshcodeリクエストをMPCサーバに送信する + + Parameters + ---------- + data_ids: List[str] + 計算に用いるデータのID + inp1: List[int] + 計算に用いるデータの列番号(1-index) + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか + + Returns + ------- + ExecuteResponse + 計算リクエストの結果 + """ return self.__execute_computation( ComputationMethod.COMPUTATION_METHOD_MESH_CODE, data_ids, (inp1, []), debug_mode=debug_mode) def join(self, data_ids: List[str], *, debug_mode: bool = False) -> ExecuteResponse: + """inner_joinリクエストをMPCサーバに送信する + + Parameters + ---------- + data_ids: List[str] + 計算に用いるデータのID + debug_mode: bool, default=False + 違法な高速結合による高速化をするかどうか + + Returns + ------- + ExecuteResponse + 計算リクエストの結果 + """ return self.__execute_computation( ComputationMethod.COMPUTATION_METHOD_JOIN_TABLE, data_ids, ([], []), debug_mode=debug_mode) @@ -269,7 +480,27 @@ def join(self, data_ids: List[str], @staticmethod def __stream_result(stream: Iterable, job_uuid: str, party: int, output_path: Optional[str]) -> Dict: - """ エラーチェックしてstreamのresultを得る """ + """streamから計算結果を取得して結合する + + `output_path` を指定すると計算結果をファイルに書き込み,返り値にデータ本体は含まれない. + `output_path` がNoneの場合は計算結果はファイルには書き込まれず,返り値に全て含まれる. + + Parameters + ---------- + stream: Iterable + grpcのstreamインスタンス + job_uuid: str + jobのID + party: int + Partyの番号(1-index) + output_path: Optional[str] + 出力するファイルのパス + + Returns + ------- + Dict + streamの結果を結合した辞書 + """ is_ok: bool = True res_list = [] for res in stream: @@ -298,7 +529,23 @@ def __stream_result(stream: Iterable, job_uuid: str, party: int, def get_computation_result(self, job_uuid: str, output_path: Optional[str] = None) \ -> GetResultResponse: - """ コンテナから結果を取得 """ + """計算結果をMPCサーバから取得する + + `output_path` を指定すると計算結果をファイルに書き込み,返り値にデータ本体は含まれない. + `output_path` がNoneの場合は計算結果はファイルには書き込まれず,返り値に全て含まれる. + + Parameters + ---------- + job_uuid + jobのID + output_path: Optional[str] + 出力するファイルのパス + + Returns + ------- + GetResultResponse + 計算結果取得リクエストの結果 + """ # リクエストパラメータを設定 req = GetComputationRequest( job_uuid=job_uuid, @@ -361,6 +608,18 @@ def get_computation_result(self, job_uuid: str, def get_computation_status(self, job_uuid: str) \ -> GetComputationStatusResponse: + """計算ステータスをMPCサーバから取得する + + Parameters + ---------- + job_uuid: str + jobのID + + Returns + ------- + GetComputationStatusResponse + 計算ステータス + """ # リクエストパラメータを設定 req = GetComputationRequest( job_uuid=job_uuid, @@ -381,6 +640,18 @@ def get_computation_status(self, job_uuid: str) \ return GetComputationStatusResponse(statuses, progresses) def get_job_error_info(self, job_uuid: str) -> GetJobErrorInfoResponse: + """計算のエラー情報をMPCサーバから取得する + + Parameters + ---------- + job_uuid: str + jobのID + + Returns + ------- + GetJobErrorInfoResponse + 計算のエラー情報 + """ # リクエストパラメータを設定 req = GetComputationRequest( job_uuid=job_uuid, @@ -400,6 +671,18 @@ def get_job_error_info(self, job_uuid: str) -> GetJobErrorInfoResponse: return GetJobErrorInfoResponse(job_error_info) def get_elapsed_time(self, job_uuid: str) -> GetElapsedTimeResponse: + """計算にかかった時間をMPCサーバから取得する + + Parameters + ---------- + job_uuid: str + jobのID + + Returns + ------- + GetElapsedTimeResponse + 計算にかかった時間 + """ # リクエストパラメータを設定 req = GetElapsedTimeRequest( job_uuid=job_uuid, @@ -416,6 +699,17 @@ def get_elapsed_time(self, job_uuid: str) -> GetElapsedTimeResponse: return GetElapsedTimeResponse(elapsed_time) def delete_share(self, data_ids: List[str]) -> None: + """送信したShareをMPCサーバから削除する + + Parameters + ---------- + data_ids: List[str] + 削除したいShareのIDリスト + + Returns + ------- + None + """ req = DeleteSharesRequest(dataIds=data_ids, token=self.__token) # 非同期にリクエスト送信 with ThreadPoolExecutor() as executor: @@ -425,6 +719,20 @@ def delete_share(self, data_ids: List[str]) -> None: def add_share_data_frame(self, base_data_id: str, add_data_id: str) \ -> AddShareDataFrameResponse: + """テーブル加算リクエストをMPCサーバに送信する + + Parameters + ---------- + base_data_id: str + 左項のテーブルのID + add_data_id: str + 右項のテーブルのID + + Returns + ------- + AddShareDataFrameResponse + テーブル加算の結果 + """ req = AddShareDataFrameRequest(base_data_id=base_data_id, add_data_id=add_data_id, token=self.__token) diff --git a/packages/client/libclient-py/quickmpc/share/random.py b/packages/client/libclient-py/quickmpc/share/random.py index 73b923086..5eeec8d00 100644 --- a/packages/client/libclient-py/quickmpc/share/random.py +++ b/packages/client/libclient-py/quickmpc/share/random.py @@ -23,20 +23,42 @@ def get_list(self, a, b, size: int) -> List[int]: @dataclass(frozen=True) class ChaCha20(RandomInterface): + """ChaCha20による乱数生成器 + + args + ---- + mx: int + 乱数最大値 + mn: int + 乱数最小値 + """ # 128bit符号付き整数最大,最小値 mx: ClassVar[int] = (1 << 128)-1 mn: ClassVar[int] = -(1 << 128) @methoddispatch() - def get(self, a, b): - raise ArgumentError( - "乱数の閾値はどちらもintもしくはdecimalでなければなりません." - f"a is {type(a)}, b is {type(b)}") + def get(self, *args, **kw): + """乱数を生成する + + overloadして使用される. + """ + raise ArgumentError("乱数の閾値はintもしくはdecimalでなければなりません.") @get.register(int) def __get_int(self, a: int, b: int) -> int: - # TRNGで [a,b) の乱数生成 + """TRNGで整数の乱数を生成する + + Parameters + ---------- + a, b: int + [a, b)の範囲で生成する + + Returns + ------- + int + 整数乱数 + """ self.__exception_check(a, b) interval_byte = self.__get_byte_size(b-a) byte_val: bytes = random(interval_byte) @@ -45,19 +67,49 @@ def __get_int(self, a: int, b: int) -> int: @get.register(Decimal) def __get_decimal(self, a: Decimal, b: Decimal) -> Decimal: + """TRNGで実数の乱数を生成する + + Parameters + ---------- + a, b: Decimal + [a, b)の範囲で生成する + + Returns + ------- + Decimal + 整数乱数 + """ # 256bit整数を取り出して[a,b]に正規化する self.__exception_check(a, b) val: int = self.get(self.mn, self.mx) return Decimal(val-self.mn)/(self.mx-self.mn)*(b-a)+a @methoddispatch() - def get_list(self, a, b, size: int): - raise ArgumentError( - "乱数の閾値はどちらもintもしくはdecimalでなければなりません." - f"a is {type(a)}, b is {type(b)}") + def get_list(self, *args, **kw): + """乱数配列を生成する + + overloadして使用される. + """ + raise ArgumentError("乱数の閾値はintもしくはdecimalでなければなりません.") @get_list.register(int) def __get_list_int(self, a: int, b: int, size: int) -> List[int]: + """CSPRNGで整数の乱数配列を生成する + + seedをTRNGで生成して配列サイズ分の乱数はCSPRNGで生成する. + + Parameters + ---------- + a, b: int + [a, b)の範囲で生成する + size: int + 配列サイズ + + Returns + ------- + List[int] + 整数乱数の配列 + """ # TRNGの32byteをseedとしてCSPRNGでsize分生成 byte_size: int = self.__get_byte_size(b-a) self.__exception_check(a, b) @@ -70,6 +122,20 @@ def __get_list_int(self, a: int, b: int, size: int) -> List[int]: @get_list.register(Decimal) def __get_list_decimal(self, a: Decimal, b: Decimal, size: int) \ -> List[Decimal]: + """CSPRNGで実数の乱数配列を生成する + + Parameters + ---------- + a, b: Decimal + [a, b)の範囲で生成する + size: int + 配列サイズ + + Returns + ------- + List[Decimal] + 実数乱数の配列 + """ # 128bit整数を取り出して[a,b]に正規化する self.__exception_check(a, b) valList: List[int] = self.get_list(self.mn, self.mx, size) @@ -77,13 +143,46 @@ def __get_list_decimal(self, a: Decimal, b: Decimal, size: int) \ for val in valList] def __get_byte_size(self, x: int) -> int: - # 整数の byte サイズを取得 + """整数のbyteサイズを取得する + + Parameters + ---------- + x: int + 整数 + + Returns + ------- + int + 入力整数のbyteサイズ + """ return max(math.ceil(math.log2(x))//8 + 1, 32) def __get_32byte(self) -> bytes: + """32bit乱数を/dev/randomから取得する + + Parameters + ---------- + + Returns + ------- + bytes + 乱数byte列 + """ return random() def __exception_check(self, a, b) -> None: + """乱数生成関数のvalidation checkをする + + Parameters + ---------- + a, b: any + 生成したい乱数の範囲[a,b) + + + Returns + ------- + None + """ if a >= b: raise ArgumentError( "乱数の下限は上限より小さい必要があります." diff --git a/packages/client/libclient-py/quickmpc/share/restore.py b/packages/client/libclient-py/quickmpc/share/restore.py index 3dec1f23c..a5fdb27e5 100644 --- a/packages/client/libclient-py/quickmpc/share/restore.py +++ b/packages/client/libclient-py/quickmpc/share/restore.py @@ -11,7 +11,21 @@ from quickmpc.utils import if_present -def get_meta(job_uuid: str, path: str): +def get_meta(job_uuid: str, path: str) -> int: + """結果データからmeta情報を取り出す + + Parameters + ---------- + job_uuid: str + 計算結果のID + path: str + 計算結果を保存したpath + + Returns + ------- + int + metaデータ + """ file_name = glob.glob(f"{path}/dim?-{job_uuid}-*")[0] with open(file_name, 'r') as f: reader = csv.reader(f) @@ -21,6 +35,20 @@ def get_meta(job_uuid: str, path: str): def get_result(job_uuid: str, path: str, party: int): + """結果データからmeta情報を取り出す + + Parameters + ---------- + job_uuid: str + 計算結果のID + path: str + 計算結果を保存したpath + + Yields + ------ + List[str] + テーブルデータの行 + """ for file_name in natsorted(glob.glob(f"{path}-{job_uuid}-{party}-*")): with open(file_name, 'r') as f: reader = csv.reader(f) @@ -32,6 +60,22 @@ def get_result(job_uuid: str, path: str, party: int): def restore(job_uuid: str, path: str, party_size: int) -> Any: + """ファイルに保存された結果データを復元する + + Parameters + ---------- + job_uuid: str + 計算結果のID + path: str + 計算結果を保存したpath + party_size: int + MPCのパーティ数 + + Returns + ------- + Any + 復元した計算結果 + """ column_number = get_meta(job_uuid, path) schema: Any = [None]*column_number diff --git a/packages/client/libclient-py/quickmpc/share/share.py b/packages/client/libclient-py/quickmpc/share/share.py index 8e4bb7065..3bcd4d254 100644 --- a/packages/client/libclient-py/quickmpc/share/share.py +++ b/packages/client/libclient-py/quickmpc/share/share.py @@ -18,49 +18,105 @@ @dataclass(frozen=True) class Share: + """Share関連の処理をまとめているクラス + + Attributes + ---------- + _Share__share_random_range: Tuple[Decimal, Decimal] + share乱数の範囲 + """ __share_random_range: ClassVar[Tuple[Decimal, Decimal]] =\ (Decimal(-(1 << 64)), Decimal(1 << 64)) @methoddispatch(is_static_method=True) @staticmethod - def __to_str(_): - logger.error("Invalid argument on stringfy.") + def __to_str(*args, **kw): + """文字列に変換する + + overloadして使用される. + """ raise ArgumentError("不正な引数が与えられています.") @__to_str.register(Decimal) @staticmethod def __decimal_to_str(val: Decimal) -> str: + """decimalを文字列に変換する + + Parameters + ---------- + val: Decimal + 変換する値 + + Returns + ------- + str + 変換された文字列 + """ # InfinityをCCで読み込めるinfに変換 return 'inf' if Decimal.is_infinite(val) else str(val) @__to_str.register(int) @staticmethod def __int_to_str(val: int) -> str: + """intを文字列に変換する + + Parameters + ---------- + val: int + 変換する値 + + Returns + ------- + str + 変換された文字列 + """ return str(val) @methoddispatch(is_static_method=True) @staticmethod - def sharize(_, __): - logger.error("Invalid argument on sharize.") + def sharize(*args, **kw): + """sharizeする + + overloadして使用される. + """ raise ArgumentError("不正な引数が与えられています.") @methoddispatch(is_static_method=True) @staticmethod - def recons(_): - logger.error("Invalid argument on recons.") + def recons(*args, **kw): + """Shareを復元する + + overloadして使用される. + """ raise ArgumentError("不正な引数が与えられています.") @methoddispatch(is_static_method=True) @staticmethod - def convert_type(_, __): - logger.error("Invalid argument on convert_type.") + def convert_type(*args, **kw): + """値を変換する + + overloadして使用される. + """ raise ArgumentError("不正な引数が与えられています.") @sharize.register(int) @sharize.register(float) @staticmethod def __sharize_scalar(secrets: float, party_size: int = 3) -> List[str]: - """ スカラ値のシェア化 """ + """スカラ値をシェア化する + + Parameters + ---------- + secrets: float + スカラ値 + party_size: int, defulat=3 + MPCのパーティ数 + + Returns + ------- + List[str] + シェア + """ rnd: RandomInterface = ChaCha20() shares: List[int] = rnd.get_list( *Share.__share_random_range, party_size) @@ -73,7 +129,20 @@ def __sharize_scalar(secrets: float, party_size: int = 3) -> List[str]: def __sharize_1dimension_float(secrets: List[Union[float, Decimal]], party_size: int = 3) \ -> List[List[str]]: - """ 1次元リストのシェア化 """ + """実数の1次元リストをシェア化する + + Parameters + ---------- + secrets: List[Union[float, Decimal]] + 1次元リスト + party_size: int, defulat=3 + MPCのパーティ数 + + Returns + ------- + List[List[str]] + シェアのリスト + """ rnd: RandomInterface = ChaCha20() secrets_size: int = len(secrets) shares: np.ndarray = np.array([ @@ -90,13 +159,40 @@ def __sharize_1dimension_float(secrets: List[Union[float, Decimal]], def __sharize_1dimension_decimal(secrets: List[Decimal], party_size: int = 3) \ -> List[List[str]]: + """1次元リストをシェア化する + + Parameters + ---------- + secrets: List[Decimal] + 1次元リスト + party_size: int, defulat=3 + MPCのパーティ数 + + Returns + ------- + List[List[str]] + シェアのリスト + """ return Share.__sharize_1dimension_float(secrets, party_size) @sharize.register((Dim1, int)) @staticmethod def __sharize_1dimension_int(secrets: List[int], party_size: int = 3) \ -> List[List[str]]: - """ 1次元リストのシェア化 """ + """整数の1次元リストをシェア化する + + Parameters + ---------- + secrets: List[int] + 1次元リスト + party_size: int, defulat=3 + MPCのパーティ数 + + Returns + ------- + List[List[str]] + シェアのリスト + """ rnd: RandomInterface = ChaCha20() secrets_size: int = len(secrets) max_val = (max(secrets)+1) * 2 @@ -113,7 +209,20 @@ def __sharize_1dimension_int(secrets: List[int], party_size: int = 3) \ @staticmethod def __sharize_2dimension(secrets: List[List[Union[float, int]]], party_size: int = 3) -> List[List[List[str]]]: - """ 2次元リストのシェア化 """ + """2次元リストをシェア化する + + Parameters + ---------- + secrets: List[int] + 1次元リスト + party_size: int, defulat=3 + MPCのパーティ数 + + Returns + ------- + List[List[List[str]]] + シェアのリスト + """ transposed: List[Union[List[int], List[float]]] \ = np.array(secrets, dtype=object).transpose().tolist() dst: List[List[List[str]]] = [ @@ -126,7 +235,20 @@ def __sharize_2dimension(secrets: List[List[Union[float, int]]], @sharize.register(dict) @staticmethod def __sharize_dict(secrets: dict, party_size: int = 3) -> List[dict]: - """ 辞書型のシェア化 """ + """辞書型をシェア化する + + Parameters + ---------- + secrets: List[int] + 1次元リスト + party_size: int, defulat=3 + MPCのパーティ数 + + Returns + ------- + List[dict] + シェアのリスト + """ shares_str: List[dict] = [dict() for _ in range(party_size)] for key, val in secrets.items(): for i, share_val in enumerate(Share.sharize(val, party_size)): @@ -137,7 +259,20 @@ def __sharize_dict(secrets: dict, party_size: int = 3) -> List[dict]: @staticmethod def __sharize_dictlist(secrets: dict, party_size: int = 3) \ -> List[List[dict]]: - """ 辞書型配列のシェア化 """ + """辞書型配列をシェア化する + + Parameters + ---------- + secrets: List[int] + 1次元リスト + party_size: int, defulat=3 + MPCのパーティ数 + + Returns + ------- + List[List[dict]] + シェアのリスト + """ shares_str: List[List[dict]] = [[] for _ in range(party_size)] for secret_dict in secrets: share_dict: List[dict] = Share.sharize(secret_dict, party_size) @@ -148,14 +283,36 @@ def __sharize_dictlist(secrets: dict, party_size: int = 3) \ @recons.register(Dim1) @staticmethod def __recons_list1(shares: List[Union[int, Decimal]]): - """ 1次元リストのシェアを復元 """ + """1次元リストのシェアを復元 + + Parameters + ---------- + shares: List[Union[int, Decimal]] + シェア + + Returns + ------- + Union[int, Decimal] + 復元した値 + """ return sum(shares) @recons.register(Dim2) @recons.register(Dim3) @staticmethod def __recons_list(shares: List[List[Union[int, Decimal]]]) -> List: - """ リストのシェアを復元 """ + """リストのシェアを復元 + + Parameters + ---------- + shares: List[List[Union[int, Decimal]]] + シェア + + Returns + ------- + List + 復元した値 + """ secrets: List = [ Share.recons([shares_pi[i] for shares_pi in shares]) for i in range(len(shares[0])) @@ -165,7 +322,18 @@ def __recons_list(shares: List[List[Union[int, Decimal]]]) -> List: @recons.register(DictList) @staticmethod def __recons_dictlist(shares: List[dict]) -> dict: - """ 辞書型を復元 """ + """辞書型のシェアを復元 + + Parameters + ---------- + shares: List[dict] + シェア + + Returns + ------- + dict + 復元した値 + """ secrets: dict = dict() for key in shares[0].keys(): val = [] @@ -177,7 +345,18 @@ def __recons_dictlist(shares: List[dict]) -> dict: @recons.register(DictList2) @staticmethod def __recons_dictlist2(shares: List[List[dict]]) -> list: - """ 辞書型配列を復元 """ + """辞書型配列のシェアを復元 + + Parameters + ---------- + shares: List[List[dict]] + シェア + + Returns + ------- + List[dict] + 復元した値 + """ secrets: list = list() for i in range(len(shares[0])): val = [] @@ -189,7 +368,18 @@ def __recons_dictlist2(shares: List[List[dict]]) -> list: @staticmethod def get_pre_convert_func( schema: Optional[Schema]) -> Callable[[str], Any]: - """ スキーマに合った変換関数を返す """ + """スキーマに合った変換関数を返す + + Parameters + ---------- + schema: Optional[Schema]) + スキーマ + + Returns + ------- + Callable[[str], Any]: + 変換関数 + """ if schema is None: return Decimal type = schema.type @@ -205,6 +395,18 @@ def get_pre_convert_func( @staticmethod def convert_int_to_str(x: int): + """intをstrに変換する + + Parameters + ---------- + x: int + 整数 + + Returns + ------- + str + 変換した文字列 + """ bytes_repr: bytes = x.to_bytes((x.bit_length() + 7) // 8, byteorder='big') str_repr: str = bytes_repr.decode('utf-8') @@ -213,7 +415,18 @@ def convert_int_to_str(x: int): @staticmethod def get_convert_func( schema: Optional[Schema]) -> Callable[[Any], Any]: - """ スキーマに合った変換関数を返す """ + """スキーマに合った変換関数を返す + + Parameters + ---------- + schema: Optional[Schema]) + スキーマ + + Returns + ------- + Callable[[str], Any]: + 変換関数 + """ if schema is None: return float type = schema.type @@ -228,8 +441,22 @@ def get_convert_func( @convert_type.register(str) @staticmethod - def __pre_convert_type_str( - value: str, schema: Optional[Schema] = None) -> list: + def __pre_convert_type_str(value: str, schema: Optional[Schema] = None) \ + -> Union[int, float, Decimal]: + """文字列をスキーマタグに沿って数値に変換する + + Parameters + ---------- + value: str + 変換する文字列 + schema: Optional[Schema], default=None + スキーマ + + Returns + ------- + Union[int, float, Decimal]: + 変換した数値 + """ func = Share.get_pre_convert_func(schema) return func(value) @@ -238,6 +465,20 @@ def __pre_convert_type_str( def __convert_type_list( values: List[Any], schema: Optional[Sequence[Optional[Schema]]] = None) -> list: + """リストをスキーマタグに沿って数値に変換する + + Parameters + ---------- + value: List[Any] + 変換する文字列 + schema: Optional[Sequence[Optional[Schema]]], default=None + スキーマ + + Returns + ------- + list + 変換した数値 + """ if schema is None: schema = [None] * len(values) return [Share.convert_type(x, sch) @@ -249,6 +490,20 @@ def __convert_type_list( def __convert_type_elem( value: Union[Decimal, int], schema: Optional[Schema] = None) -> Union[float, str]: + """数値をスキーマタグに沿って数値に変換する + + Parameters + ---------- + value: Union[Decimal, int] + 変換する文字列 + schema: Optional[Schema], default=None + スキーマ + + Returns + ------- + Union[float, int] + 変換した数値 + """ func = Share.get_convert_func(schema) return func(value) @@ -257,4 +512,18 @@ def __convert_type_elem( def __convert_type_table( table: List[List], schema: Optional[List[Schema]] = None) -> list: + """テーブルデータをスキーマタグに沿って数値に変換する + + Parameters + ---------- + value: Union[Decimal, int] + 変換する文字列 + schema: Optional[List[Schema]], default=None + スキーマ + + Returns + ------- + List + 変換した数値 + """ return [Share.convert_type(row, schema) for row in table] diff --git a/packages/client/libclient-py/quickmpc/utils/if_present.py b/packages/client/libclient-py/quickmpc/utils/if_present.py index 8c31cb382..b6d3d459c 100644 --- a/packages/client/libclient-py/quickmpc/utils/if_present.py +++ b/packages/client/libclient-py/quickmpc/utils/if_present.py @@ -3,8 +3,24 @@ def if_present(optional: Optional[Any], func: Callable[[Any], Any], - *arg: Any + *args: Any ) -> Optional[Any]: + """optional変数を評価して関数を実行する + + Parameters + ---------- + optional: Optional[Any] + 何らかのoptional変数 + func: Callable[[Any], Any] + 実行する関数 + *args: Any + 実行する関数の引数列 + + Returns + ------- + Optional[Any] + 関数の実行結果 + """ if optional is None: return None - return func(optional, *arg) + return func(optional, *args) diff --git a/packages/client/libclient-py/quickmpc/utils/make_pieces.py b/packages/client/libclient-py/quickmpc/utils/make_pieces.py index 10dd17769..44cf19242 100644 --- a/packages/client/libclient-py/quickmpc/utils/make_pieces.py +++ b/packages/client/libclient-py/quickmpc/utils/make_pieces.py @@ -9,30 +9,81 @@ @dataclass(frozen=True) class MakePiece: + """値をいくつかのpieceに分割するクラス + """ @methoddispatch(is_static_method=True) @staticmethod - def __get_byte(p): ... + def __get_byte(*args, **kw): + """型のbyte数を取得する + + overloadして使用される + """ + raise NotImplementedError("not implemented.") @__get_byte.register(float) @__get_byte.register(int) @staticmethod def __get_byte_number(f): + """数値型のbyte数を取得する + + Parameters + ---------- + f + 数値 + + Returns + ------- + int + byte数 + """ # 扱う数値は64bitなので8byte return 8 @__get_byte.register(str) @staticmethod def __get_byte_str(s: str): + """文字列のbyte数を取得する + + Parameters + ---------- + s: str + 文字列 + + Returns + ------- + int + byte数 + """ return len(s) @methoddispatch(is_static_method=True) @staticmethod - def make_pieces(_, __): + def make_pieces(*args, **kw): + """値をいpieceに分割する + + overloadして使用される + """ raise ArgumentError("不正な引数が与えられています.") @staticmethod def check_max_size(max_size: int): + """piece_sizeのvalidator + + Parameters + ---------- + max_size: int + 分割サイズ + + Returns + ------- + None + + Raises + ------ + RuntimeError + 分割サイズがgrpcの要件を満たさないとき + """ # NOTE: grpcの送受信データサイズ上限:4MB lower_limit_size: int = 1 upper_limit_size: int = 1_000_000 @@ -50,6 +101,20 @@ def check_max_size(max_size: int): @make_pieces.register(Dim1) @staticmethod def __make_pieces_1d(src: List[str], max_size: int) -> List[List[str]]: + """1次元配列をpiece分割する + + Parameters + ---------- + src: List[str] + 分割する配列 + max_size: int + 分割1pieceごとの最大byte数 + + Returns + ------- + List[List[str]] + 分割した配列 + """ MakePiece.check_max_size(max_size) cur_size = 0 index = 0 @@ -76,6 +141,20 @@ def __make_pieces_1d(src: List[str], max_size: int) -> List[List[str]]: def __make_pieces_2d(src: List[List[str]], max_size: int ) -> List[List[List[str]]]: + """2次元配列をpiece分割する + + Parameters + ---------- + src: List[List[str]] + 分割する配列 + max_size: int + 分割1pieceごとの最大byte数 + + Returns + ------- + List[List[List[str]]] + 分割した配列 + """ MakePiece.check_max_size(max_size) cur_size = 0 index = 0 @@ -102,6 +181,20 @@ def __make_pieces_2d(src: List[List[str]], @make_pieces.register(str) @staticmethod def __make_pieces_str(src: str, max_size: int) -> List[str]: + """文字列をpiece分割する + + Parameters + ---------- + src: str + 分割する文字列 + max_size: int + 分割1pieceごとの最大byte数 + + Returns + ------- + [List[str] + 分割した配列 + """ MakePiece.check_max_size(max_size) dst: List[str] = [] current: str = "" diff --git a/packages/client/libclient-py/quickmpc/utils/overload_tools.py b/packages/client/libclient-py/quickmpc/utils/overload_tools.py index 7339bf936..51d80b7b2 100644 --- a/packages/client/libclient-py/quickmpc/utils/overload_tools.py +++ b/packages/client/libclient-py/quickmpc/utils/overload_tools.py @@ -6,26 +6,43 @@ class Dim1: + """1次元配列であることを示すクラス""" ... class Dim2: + """2次元配列であることを示すクラス""" ... class Dim3: + """3次元配列であることを示すクラス""" ... class DictList: + """1次元辞書型配列であることを示すクラス""" ... class DictList2: + """2次元辞書型配列であることを示すクラス""" ... -def d(lst): +def d(lst) -> int: + """次元数を取得する + + Parameters + ---------- + lst: any + 次元数を取得する変数 + + Returns + ------- + int + 次元数 + """ if not isinstance(lst, list): return 0 if not lst: @@ -34,6 +51,18 @@ def d(lst): def find_element_type(lst) -> type: + """配列の要素の型を取得する + + Parameters + ---------- + lst: any + 要素の型を取得する変数 + + Returns + ------- + type + 要素の型 + """ if not isinstance(lst, list): return lst.__class__ if not lst: @@ -42,6 +71,18 @@ def find_element_type(lst) -> type: def _get_dim_class(lst: list): + """次元数を加味したクラス分類を取得する + + Parameters + ---------- + lst: any + クラス分類を取得する変数 + + Returns + ------- + Any + クラス分類 + """ dim: int = d(lst) if dim == 1: if len(lst) and isinstance(lst[0], dict): @@ -57,6 +98,20 @@ def _get_dim_class(lst: list): def _convert_list_type(registry, type): + """次元数を加味したクラス分類を取得する + + Parameters + ---------- + registry: dict + 登録されている型集合 + type: any + 分類するクラス + + Returns + ------- + Any + クラス分類 + """ if type.__class__ is not list: return type for Dim in (Dim1, Dim2, Dim3, DictList): @@ -66,6 +121,18 @@ def _convert_list_type(registry, type): def methoddispatch(is_static_method: bool = False): + """overloadするための関数 + + Parameters + ---------- + is_static_method: bool + staticmethodかどうか + + Returns + ------- + Callable + dispatch関数 + """ def _dimdispatch(func): registry: dict = {} default_function = func diff --git a/packages/client/libclient-py/setup.cfg b/packages/client/libclient-py/setup.cfg deleted file mode 100644 index 15ae61518..000000000 --- a/packages/client/libclient-py/setup.cfg +++ /dev/null @@ -1,19 +0,0 @@ -[metadata] -name = quickmpc -author = Acompany -license=Apache 2.0 Licences -license_file=LICENSE - -[options] -setup_requires=setuptools>=45; setuptools_scm[toml]>=6.2 -python_requires=>=3.7 -install_requires = - numpy - grpcio-tools - grpcio-status - pynacl - tqdm - natsort - pandas -include_package_data = True -packages = find: diff --git a/packages/client/libclient-py/setup.py b/packages/client/libclient-py/setup.py deleted file mode 100644 index b024da80e..000000000 --- a/packages/client/libclient-py/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -from setuptools import setup - - -setup() diff --git a/scripts/container_test/Pipfile b/scripts/container_test/Pipfile index 39a013ef8..de3b78ff1 100644 --- a/scripts/container_test/Pipfile +++ b/scripts/container_test/Pipfile @@ -4,8 +4,6 @@ verify_ssl = true name = "pypi" [packages] -pytest = "*" -quickmpc = {path = "./../../packages/client/libclient-py"} [dev-packages] diff --git a/scripts/container_test/Pipfile.lock b/scripts/container_test/Pipfile.lock deleted file mode 100644 index 599aaa080..000000000 --- a/scripts/container_test/Pipfile.lock +++ /dev/null @@ -1,102 +0,0 @@ -{ - "_meta": { - "hash": { - "sha256": "ad8090164c262479972d022fd437f633973a2687fb75d9996397acd5647af680" - }, - "pipfile-spec": 6, - "requires": { - "python_full_version": "3.7.10", - "python_version": "3.7" - }, - "sources": [ - { - "name": "pypi", - "url": "https://pypi.org/simple", - "verify_ssl": true - } - ] - }, - "default": { - "attrs": { - "hashes": [ - "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836", - "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99" - ], - "markers": "python_version >= '3.6'", - "version": "==22.2.0" - }, - "exceptiongroup": { - "hashes": [ - "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e", - "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785" - ], - "markers": "python_version < '3.11'", - "version": "==1.1.1" - }, - "importlib-metadata": { - "hashes": [ - "sha256:43ce9281e097583d758c2c708c4376371261a02c34682491a8e98352365aad20", - "sha256:ff80f3b5394912eb1b108fcfd444dc78b7f1f3e16b16188054bd01cb9cb86f09" - ], - "markers": "python_version < '3.8'", - "version": "==6.1.0" - }, - "iniconfig": { - "hashes": [ - "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", - "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374" - ], - "markers": "python_version >= '3.7'", - "version": "==2.0.0" - }, - "packaging": { - "hashes": [ - "sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2", - "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97" - ], - "markers": "python_version >= '3.7'", - "version": "==23.0" - }, - "pluggy": { - "hashes": [ - "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159", - "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3" - ], - "markers": "python_version >= '3.6'", - "version": "==1.0.0" - }, - "pytest": { - "hashes": [ - "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e", - "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4" - ], - "index": "pypi", - "version": "==7.2.2" - }, - "tomli": { - "hashes": [ - "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc", - "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f" - ], - "markers": "python_version < '3.11'", - "version": "==2.0.1" - }, - "typing-extensions": { - "hashes": [ - "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb", - "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4" - ], - "markers": "python_version < '3.8'", - "version": "==4.5.0" - }, - "zipp": { - "hashes": [ - "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b", - "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556" - ], - "markers": "python_version >= '3.7'", - "version": "==3.15.0" - } - }, - "develop": {} -} diff --git a/scripts/container_test/run_medium_test.sh b/scripts/container_test/run_medium_test.sh index 0db48806b..6dd2f03d5 100755 --- a/scripts/container_test/run_medium_test.sh +++ b/scripts/container_test/run_medium_test.sh @@ -9,6 +9,7 @@ if [[ ! $PIP_LIST =~ "pipenv" ]]; then fi # 環境を構築してTestを走らせる -pipenv sync -pipenv install --skip-lock ../../packages/client/libclient-py +pipenv run pip install -U pip +pipenv run pip install pytest +pipenv run pip install ../../packages/client/libclient-py pipenv run pytest ./tests -s -v -log-cli-level=DEBUG