diff --git a/packages/server/computation_container/math/math.cpp b/packages/server/computation_container/math/math.cpp index acc7bcd3c..57256b824 100644 --- a/packages/server/computation_container/math/math.cpp +++ b/packages/server/computation_container/math/math.cpp @@ -5,137 +5,71 @@ namespace qmpc::Math { Share sum(const std::vector &v) { - Share ret; - for (const auto &a : v) - { - ret += a; - } - return ret; + Share e(FixedPoint(0)); + return std::accumulate(v.begin(), v.end(), e); } + Share smean(const std::vector &v) { - // Share avg(FixedPoint(0.0)); - // for (int i = 0; i < size; i++) - // { - // avg = avg + v[i]; - // } - Share ret{}; - for (auto a : v) - { - ret += a; - } - int size = std::size(v); - ret /= FixedPoint(size); - - return ret; + assert(std::size(v)); + return sum(v) / FixedPoint(std::size(v)); } -Share variance(std::vector &v) +std::vector deviation(std::vector v) { - Share avg; - - avg = smean(v); - std::vector var; - for (auto &a : v) + Share avg = smean(v); + for (auto &s : v) { - var.emplace_back(a - avg); + s -= avg; } - auto varVec = var * var; - Share ret{}; - for (auto &a : varVec) - { - ret += a; - } - // FPのresolutionの制約により、FPだと1000より大きい数で割れないためdoubleで割っている - // double var_d = var.getDoubleVal(); - int size = std::size(v); - ret /= FixedPoint(size); + return v; +} - return ret; +Share variance(const std::vector &v) +{ + auto dev = deviation(v); + auto var = dev * dev; + return smean(var); } -FixedPoint stdev(std::vector &v) +FixedPoint stdev(const std::vector &v) { - Share var; - var = variance(v); + Share var = variance(v); + FixedPoint var_val = open_and_recons(var); - FixedPoint stdev = open_and_recons(var); - auto value = stdev.getDoubleVal(); + auto value = var_val.getDoubleVal(); if (value < 0) { value = 0; } - auto r = sqrt(value); - FixedPoint ret{r}; - return ret; + return FixedPoint(sqrt(value)); } -Share correl(std::vector &x, std::vector &y) + +Share covariance(const std::vector &x, const std::vector &y) { - int sizex = (int)x.size(); - int sizey = (int)y.size(); + auto devX = deviation(x); + auto devY = deviation(y); + auto devXY = devX * devY; + return smean(devXY); +} - if (sizex != sizey) +Share correl(const std::vector &x, const std::vector &y) +{ + if (x.size() != y.size()) { qmpc::Log::throw_with_trace(std::runtime_error("input Size is not Equal")); } - Share aveX = smean(x); - Share aveY = smean(y); FixedPoint stdeX = stdev(x); FixedPoint stdeY = stdev(y); + // 0除算 if (stdeX == FixedPoint(0) || stdeY == FixedPoint(0)) { QMPC_LOG_ERROR("correl returns 0 when stdev is 0"); return Share(FixedPoint(0)); } - int n = sizex; - std::vector tmpX; - tmpX.reserve(n); - std::vector tmpY; - tmpY.reserve(n); - for (int i = 0; i < n; ++i) - { - tmpX.emplace_back(x[i] - aveX); - tmpY.emplace_back(y[i] - aveY); - } - auto tmpVec = tmpX * tmpY; - Share ret{}; - for (auto &r : tmpVec) - { - ret += r; - } - ret /= stdeX; - ret /= stdeY; - ret /= FixedPoint(n); - return ret; -} - -Share exp(const Share &x) -{ - // Nはマクローリン展開時の項数 - // 1+x+x^2 ... x^N-1 - constexpr int N = 100; - auto *conf = Config::getInstance(); - Share ret; - if (conf->sp_id == conf->party_id) - { - ret += 1; - } - std::vector px(N); - std::vector k(N); - px[0] = ret; - k[0] = 1; - for (int i = 1; i < N; ++i) - { - k[i] = k[i - 1] * i; - px[i] = px[i - 1] * x; - } - for (int i = 1; i < N; ++i) - { - px[i] /= k[i]; - } - return std::accumulate(px.begin(), px.end(), Share{0}); + return covariance(x, y) / (stdeX * stdeY); } } // namespace qmpc::Math \ No newline at end of file diff --git a/packages/server/computation_container/math/math.hpp b/packages/server/computation_container/math/math.hpp index d33b0da9c..4592fc048 100644 --- a/packages/server/computation_container/math/math.hpp +++ b/packages/server/computation_container/math/math.hpp @@ -1,8 +1,5 @@ #pragma once -#include -#include -#include -#include +#include #include #include "share/share.hpp" @@ -12,13 +9,9 @@ namespace qmpc::Math using Share = qmpc::Share::Share; Share sum(const std::vector &v); Share smean(const std::vector &v); -Share variance(std::vector &v); -// 標準偏差 -FixedPoint stdev(std::vector &v); -// 相関係数 -Share correl(std::vector &x, std::vector &y); -Share exp(const Share &x); -Share sigmoid(const Share &x, const FixedPoint &a = 1); -Share open_sigmoid(const Share &x_s, const FixedPoint &a = 1); -Share open_sigmoid_vector(const std::vector &v_s, const FixedPoint &a = 1); +std::vector deviation(std::vector v); +Share variance(const std::vector &v); +FixedPoint stdev(const std::vector &v); +Share covariance(const std::vector &x, const std::vector &y); +Share correl(const std::vector &x, const std::vector &y); } // namespace qmpc::Math \ No newline at end of file diff --git a/packages/server/computation_container/test/integration_test/math_test.hpp b/packages/server/computation_container/test/integration_test/math_test.hpp index b60edf22e..7e4787591 100644 --- a/packages/server/computation_container/test/integration_test/math_test.hpp +++ b/packages/server/computation_container/test/integration_test/math_test.hpp @@ -163,7 +163,7 @@ TEST(MathTest, Correl) TEST(MathTest, Correl_0div) { std::vector x = {FixedPoint("2.0"), FixedPoint("2.0"), FixedPoint("2.0")}; - std::vector y = {FixedPoint("10.0"), FixedPoint("10.0"), FixedPoint("10.0")}; + std::vector y = {FixedPoint("9.0"), FixedPoint("10.0"), FixedPoint("11.0")}; Share correl_rec = qmpc::Math::correl(x, y); FixedPoint target = open_and_recons(correl_rec); @@ -225,30 +225,6 @@ TEST(MathTest, Correl_large) QMPC_LOG_INFO(correl_rec.getStrVal()); } -TEST(MathTest, ExpTest) -{ - Config *conf = Config::getInstance(); - int n_parties = conf->n_parties; - - // x = n_parties; - Share x{1}; - auto start = std::chrono::system_clock::now(); - - auto exp_n = qmpc::Math::exp(x); - double expect = std::exp(n_parties); - auto exp_n_rec = open_and_recons(exp_n); - auto end = std::chrono::system_clock::now(); - auto dur = end - start; - - // 計算に要した時間をミリ秒(1/1000秒)に変換して表示 - auto msec = std::chrono::duration_cast(dur).count(); - - QMPC_LOG_INFO("share exp time is {}", msec); - QMPC_LOG_INFO("share exp_n is {}", exp_n_rec); - QMPC_LOG_INFO("expect exp_n is {}", expect); - - EXPECT_NEAR(expect, exp_n_rec.getDoubleVal(), 0.001); -} TEST(MathTest, correlVecExceptionTest) { constexpr int N = 28000; diff --git a/scripts/libclient/src/tests/test_sum.py b/scripts/libclient/src/tests/test_sum.py index c678d1088..ccf8f817d 100644 --- a/scripts/libclient/src/tests/test_sum.py +++ b/scripts/libclient/src/tests/test_sum.py @@ -22,9 +22,9 @@ pd.DataFrame([2.0*10**18])), # small data case - (data_frame([[10**-8], [10**-8]], columns=["s1"]), + (data_frame([[10**-7], [10**-7]], columns=["s1"]), [1], - pd.DataFrame([2.0*10**-8])), + pd.DataFrame([2.0*10**-7])), # duplicated src case (data_frame([[1, 2, 3], [4, 5, 6]], columns=["s1", "s2", "s3"]),