diff --git a/.gitignore b/.gitignore
index e833569..403e4e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ __pycache__/
 *.whl
 external/
 *.so
+.vscode/
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cb63d2f..60a0aab 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,22 +1,27 @@
 cmake_minimum_required(VERSION 3.13.4)
 project(_pyapsi)
+find_package(Threads REQUIRED)
+find_package(SEAL 4.1 REQUIRED)
+find_package(APSI 0.11.0 REQUIRED)
+#find_package(jsoncpp REQUIRED)
-include(FetchContent)
+# include(FetchContent)
-set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+find_package(pybind11 REQUIRED)
+#FetchContent_Declare(
+#    pybind11
+#    GIT_REPOSITORY https://github.com/pybind/pybind11.git
+#    GIT_TAG v2.9.2
+#)
+# FetchContent_MakeAvailable(pybind11)
-FetchContent_Declare(
-    pybind11
-    GIT_REPOSITORY https://github.com/pybind/pybind11.git
-    GIT_TAG v2.9.2
-)
-FetchContent_MakeAvailable(pybind11)
+# add_subdirectory(external/apsi/)
+set(MAIN_SOURCES src/sender.cpp src/common_utils.cpp src/csv_reader.cpp src/main.cpp)
+set(MAIN_HEADERS src/sender.h src/common_utils.h src/csv_reader.h src/base_clp.h)
-add_subdirectory(external/apsi/)
+pybind11_add_module(_pyapsi ${MAIN_SOURCES} ${MAIN_HEADERS})
-pybind11_add_module(_pyapsi src/main.cpp)
-
-target_link_libraries(_pyapsi PRIVATE pybind11::module apsi)
+target_link_libraries(_pyapsi PRIVATE pybind11::module APSI::apsi SEAL::seal)
 target_compile_definitions(_pyapsi PRIVATE)
diff --git a/apsi/servers.py b/apsi/servers.py
index 917e7a7..cb357b7 100644
--- a/apsi/servers.py
+++ b/apsi/servers.py
@@ -31,6 +31,17 @@ def load_db(self, db_file_path: str) -> None:
         self._load_db(db_file_path)
         self.db_initialized = True
+
+    def load_csv_db(self, csv_db_file_path: str, params_json: str,
+                    nonce_byte_count: int = 16,
+                    compressed: bool = False
+                    ) -> None:
+        """Load a database from a CSV file."""
+        p = Path(csv_db_file_path)
+        if not p.exists():
+            raise FileNotFoundError(f"DB file does not exist: {p}")
+        self._load_csv_db(csv_db_file_path, params_json, nonce_byte_count, compressed)
+        self.db_initialized = True
 
     def handle_oprf_request(self, oprf_request: bytes) -> bytes:
         """Handle an initial APSI Client OPRF request.
@@ -109,9 +120,7 @@ def add_items(self, items_with_label: Iterable[Tuple[str, str]]) -> None:
             value is the label.
         """
         self._requires_db()
-        # TODO: Expose batch add in C++ PyAPSI; and add length checks accordingly
-        for item, label in items_with_label:
-            self.add_item(item=item, label=label)
+        self._add_labeled_items(items_with_label)
 
 
 class UnlabeledServer(_BaseServer):
diff --git a/src/base_clp.h b/src/base_clp.h
new file mode 100644
index 0000000..e939015
--- /dev/null
+++ b/src/base_clp.h
@@ -0,0 +1,130 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+// STD
+#include <memory>
+#include <string>
+#include <vector>
+
+// TCLAP
+#ifdef _MSC_VER
+#pragma warning(push, 0)
+#endif
+#include <tclap/CmdLine.h>
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+// APSI
+#include <apsi/log.h>
+
+
+/**
+Command line processor based on TCLAP. This is a base class that contains common arguments for both
+parties.
+*/
+class BaseCLP : public TCLAP::CmdLine {
+public:
+    BaseCLP(const std::string &description, const std::string &version)
+        : TCLAP::CmdLine(description, /* delim */ ' ', version)
+    {
+        std::vector<std::string> log_levels = { "all", "debug", "info", "warning", "error", "off" };
+        log_level_constraint_ = std::make_unique<TCLAP::ValuesConstraint<std::string>>(log_levels);
+        log_level_arg_ = std::make_unique<TCLAP::ValueArg<std::string>>(
+            "l",
+            "logLevel",
+            "One of \"all\", \"debug\", \"info\" (default), \"warning\", \"error\", \"off\"",
+            false,
+            "info",
+            log_level_constraint_.get(),
+            *this);
+    }
+
+    virtual ~BaseCLP()
+    {}
+
+    /**
+    Add additional arguments to the Command Line Processor.
+    */
+    virtual void add_args() = 0;
+
+    /**
+    Get the value of the additional arguments.
+    */
+    virtual void get_args() = 0;
+
+    bool parse_args(int argc, char **argv)
+    {
+        TCLAP::ValueArg<std::size_t> threads_arg(
+            "t",
+            "threads",
+            "Number of threads to use",
+            /* req */ false,
+            /* value */ 0,
+            /* type desc */ "unsigned integer");
+        add(threads_arg);
+
+        TCLAP::ValueArg<std::string> logfile_arg(
+            "f", "logFile", "Log file path", false, "", "file path");
+        add(logfile_arg);
+
+        TCLAP::SwitchArg silent_arg("s", "silent", "Do not write output to console", false);
+        add(silent_arg);
+
+        // No need to add log_level_arg_, already added in constructor
+
+        // Additional arguments
+        add_args();
+
+        try {
+            parse(argc, argv);
+
+            silent_ = silent_arg.getValue();
+            log_file_ = logfile_arg.getValue();
+            threads_ = threads_arg.getValue();
+            log_level_ = log_level_arg_->getValue();
+
+            apsi::Log::SetConsoleDisabled(silent_);
+            apsi::Log::SetLogFile(log_file_);
+            apsi::Log::SetLogLevel(log_level_);
+
+            get_args();
+        } catch (...) {
+            return false;
+        }
+
+        return true;
+    }
+
+    std::size_t threads() const
+    {
+        return threads_;
+    }
+
+    const std::string &log_level() const
+    {
+        return log_level_;
+    }
+
+    const std::string &log_file() const
+    {
+        return log_file_;
+    }
+
+    bool silent() const
+    {
+        return silent_;
+    }
+
+private:
+    // Parameters from command line
+    std::size_t threads_;
+    std::string log_level_;
+    std::string log_file_;
+    bool silent_;
+
+    // Parameters with constraints
+    std::unique_ptr<TCLAP::ValueArg<std::string>> log_level_arg_;
+    std::unique_ptr<TCLAP::ValuesConstraint<std::string>> log_level_constraint_;
+};
diff --git a/src/common_utils.cpp b/src/common_utils.cpp
new file mode 100644
index 0000000..d8a36c9
--- /dev/null
+++ b/src/common_utils.cpp
@@ -0,0 +1,144 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "common_utils.h"
+
+// STD
+#include <iomanip>
+#include <sstream>
+#if defined(_MSC_VER)
+#include <windows.h>
+#endif
+#if defined(__GNUC__) && (__GNUC__ < 8) && !defined(__clang__)
+#include <experimental/filesystem>
+#else
+#include <filesystem>
+#endif
+
+// APSI
+#include <apsi/log.h>
+#include <apsi/psi_params.h>
+#include <apsi/util/stopwatch.h>
+#include "base_clp.h"
+
+using namespace std;
+#if defined(__GNUC__) && (__GNUC__ < 8) && !defined(__clang__)
+namespace fs = std::experimental::filesystem;
+#else
+namespace fs = std::filesystem;
+#endif
+using namespace seal;
+using namespace apsi;
+using namespace apsi::util;
+
+/**
+This only turns on showing colors for Windows.
+*/
+void prepare_console()
+{
+#ifndef _MSC_VER
+    return; // Nothing to do on Linux.
+#else
+    HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
+    if (hConsole == INVALID_HANDLE_VALUE)
+        return;
+
+    DWORD dwMode = 0;
+    if (!GetConsoleMode(hConsole, &dwMode))
+        return;
+
+    dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING;
+    SetConsoleMode(hConsole, dwMode);
+#endif
+}
+
+vector<string> generate_timespan_report(
+    const vector<Stopwatch::TimespanSummary> &timespans, int max_name_length)
+{
+    vector<string> report;
+
+    for (const auto &timespan : timespans) {
+        stringstream ss;
+        ss << setw(max_name_length) << left << timespan.event_name << ": " << setw(5) << right
+           << timespan.event_count << " instances. ";
+        if (timespan.event_count == 1) {
+            ss << "Duration: " << setw(6) << right << static_cast<int>(timespan.avg) << "ms";
+        } else {
+            ss << "Average: " << setw(6) << right << static_cast<int>(timespan.avg)
+               << "ms Minimum: " << setw(6) << right << timespan.min << "ms Maximum: " << setw(6)
+               << right << timespan.max << "ms";
+        }
+
+        report.push_back(ss.str());
+    }
+
+    return report;
+}
+
+vector<string> generate_event_report(
+    const vector<Stopwatch::Timepoint> &timepoints, int max_name_length)
+{
+    vector<string> report;
+
+    Stopwatch::time_unit last = Stopwatch::start_time;
+    for (const auto &timepoint : timepoints) {
+        stringstream ss;
+
+        int64_t since_start = chrono::duration_cast<chrono::milliseconds>(
+                                  timepoint.time_point - Stopwatch::start_time)
+                                  .count();
+        int64_t since_last =
+            chrono::duration_cast<chrono::milliseconds>(timepoint.time_point - last).count();
+
+        ss << setw(max_name_length) << left << timepoint.event_name << ": " << setw(6) << right
+           << since_start << "ms since start, " << setw(6) << right << since_last
+           << "ms since last single event.";
+        last = timepoint.time_point;
+        report.push_back(ss.str());
+    }
+
+    return report;
+}
+
+void print_timing_report(const Stopwatch &stopwatch)
+{
+    vector<string> timing_report;
+    vector<Stopwatch::TimespanSummary> timings;
+    stopwatch.get_timespans(timings);
+
+    if (timings.size() > 0) {
+        timing_report =
+            generate_timespan_report(timings, stopwatch.get_max_timespan_event_name_length());
+
+        APSI_LOG_INFO("Timespan event information");
+        for (const auto &timing : timing_report) {
+            APSI_LOG_INFO(timing.c_str());
+        }
+    }
+
+    vector<Stopwatch::Timepoint> timepoints;
+    stopwatch.get_events(timepoints);
+
+    if (timepoints.size() > 0) {
+        timing_report = generate_event_report(timepoints, stopwatch.get_max_event_name_length());
+
+        APSI_LOG_INFO("Single event information");
+        for (const auto &timing : timing_report) {
+            APSI_LOG_INFO(timing.c_str());
+        }
+    }
+}
+
+void throw_if_file_invalid(const string &file_name)
+{
+    fs::path file(file_name);
+
+    if (!fs::exists(file)) {
+        APSI_LOG_ERROR("File `" << file.string() << "` does not exist");
+        throw logic_error("file does not exist");
+    }
+    if (!fs::is_regular_file(file)) {
+        APSI_LOG_ERROR("File `" << file.string() << "` is not a regular file");
+        throw logic_error("invalid file");
+    }
+}
diff --git a/src/common_utils.h b/src/common_utils.h
new file mode 100644
index 0000000..f714f2f
--- /dev/null
+++ b/src/common_utils.h
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+// STD
+#include <string>
+#include <vector>
+
+// APSI
+#include <apsi/util/stopwatch.h>
+
+/**
+Prepare console for color output.
+*/
+void prepare_console();
+
+/**
+Generate timing report for timespans.
+*/
+std::vector<std::string> generate_timespan_report(
+    const std::vector<apsi::util::Stopwatch::TimespanSummary> &timespans, int max_name_length);
+
+/**
+Generate timing report for single events.
+*/
+std::vector<std::string> generate_event_report(
+    const std::vector<apsi::util::Stopwatch::Timepoint> &timepoints, int max_name_length);
+
+/**
+Print timings.
+*/
+void print_timing_report(const apsi::util::Stopwatch &stopwatch);
+
+/**
+Throw an exception if the given file is invalid.
+*/
+void throw_if_file_invalid(const std::string &file_name);
diff --git a/src/csv_reader.cpp b/src/csv_reader.cpp
new file mode 100644
index 0000000..4c8f2d3
--- /dev/null
+++ b/src/csv_reader.cpp
@@ -0,0 +1,141 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+// STD
+#include <algorithm>
+#include <cctype>
+#include <fstream>
+#include <sstream>
+#include <utility>
+
+// APSI
+#include <apsi/log.h>
+#include "common_utils.h"
+#include "csv_reader.h"
+
+using namespace std;
+using namespace apsi;
+using namespace apsi::util;
+
+CSVReader::CSVReader()
+{}
+
+CSVReader::CSVReader(const string &file_name) : file_name_(file_name)
+{
+    throw_if_file_invalid(file_name_);
+}
+
+auto CSVReader::read(istream &stream) const -> pair<DBData, vector<string>>
+{
+    string line;
+    DBData result;
+    vector<string> orig_items;
+
+    if (!getline(stream, line)) {
+        APSI_LOG_WARNING("Nothing to read in `" << file_name_ << "`");
+        return { UnlabeledData{}, {} };
+    } else {
+        string orig_item;
+        Item item;
+        Label label;
+        auto [has_item, has_label] = process_line(line, orig_item, item, label);
+
+        if (!has_item) {
+            APSI_LOG_WARNING("Failed to read item from `" << file_name_ << "`");
+            return { UnlabeledData{}, {} };
+        }
+
+        orig_items.push_back(move(orig_item));
+        if (has_label) {
+            result = LabeledData{ make_pair(move(item), move(label)) };
+        } else {
+            result = UnlabeledData{ move(item) };
+        }
+    }
+
+    while (getline(stream, line)) {
+        string orig_item;
+        Item item;
+        Label label;
+        auto [has_item, _] = process_line(line, orig_item, item, label);
+
+        if (!has_item) {
+            // Something went wrong; skip this item and move on to the next
+            APSI_LOG_WARNING("Failed to read item from `" << file_name_ << "`");
+            continue;
+        }
+
+        orig_items.push_back(move(orig_item));
+        if (holds_alternative<UnlabeledData>(result)) {
+            get<UnlabeledData>(result).push_back(move(item));
+        } else if (holds_alternative<LabeledData>(result)) {
+            get<LabeledData>(result).push_back(make_pair(move(item), move(label)));
+        } else {
+            // Something is terribly wrong
+            APSI_LOG_ERROR("Critical error reading data");
+            throw runtime_error("variant is in bad state");
+        }
+    }
+
+    return { move(result), move(orig_items) };
+}
+
+auto CSVReader::read() const -> pair<DBData, vector<string>>
+{
+    throw_if_file_invalid(file_name_);
+
+    ifstream file(file_name_);
+    if (!file.is_open()) {
+        APSI_LOG_ERROR("File `" << file_name_ << "` could not be opened for reading");
+        throw runtime_error("could not open file");
+    }
+
+    return read(file);
+}
+
+pair<bool, bool> CSVReader::process_line(
+    const string &line, string &orig_item, Item &item, Label &label) const
+{
+    stringstream ss(line);
+    string token;
+
+    // First is the item
+    getline(ss, token, ',');
+
+    // Trim leading whitespace
+    token.erase(
+        token.begin(), find_if(token.begin(), token.end(), [](int ch) { return !isspace(ch); }));
+
+    // Trim trailing whitespace
+    token.erase(
+        find_if(token.rbegin(), token.rend(), [](int ch) { return !isspace(ch); }).base(),
+        token.end());
+
+    if (token.empty()) {
+        // Nothing found
+        return { false, false };
+    }
+
+    // Item can be of arbitrary length; the constructor of Item will automatically hash it
+    orig_item = token;
+    item = token;
+
+    // Second is the label
+    token.clear();
+    getline(ss, token);
+
+    // Trim leading whitespace
+    token.erase(
+        token.begin(), find_if(token.begin(), token.end(), [](int ch) { return !isspace(ch); }));
+
+    // Trim trailing whitespace
+    token.erase(
+        find_if(token.rbegin(), token.rend(), [](int ch) { return !isspace(ch); }).base(),
+        token.end());
+
+    label.clear();
+    label.reserve(token.size());
+    copy(token.begin(), token.end(), back_inserter(label));
+
+    return { true, !token.empty() };
+}
diff --git a/src/csv_reader.h b/src/csv_reader.h
new file mode 100644
index 0000000..79d1d62
--- /dev/null
+++ b/src/csv_reader.h
@@ -0,0 +1,46 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+// STD
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <variant>
+#include <vector>
+
+// APSI
+#include <apsi/item.h>
+#include <apsi/psi_params.h>
+#include <apsi/util/db_encoding.h>
+
+/**
+Simple CSV file parser
+*/
+class CSVReader {
+public:
+    using UnlabeledData = std::vector<apsi::Item>;
+
+    using LabeledData = std::vector<std::pair<apsi::Item, apsi::Label>>;
+
+    using DBData = std::variant<UnlabeledData, LabeledData>;
+
+    CSVReader();
+
+    CSVReader(const std::string &file_name);
+
+    std::pair<DBData, std::vector<std::string>> read(std::istream &stream) const;
+
+    std::pair<DBData, std::vector<std::string>> read() const;
+
+private:
+    std::string file_name_;
+
+    std::pair<bool, bool> process_line(
+        const std::string &line,
+        std::string &orig_item,
+        apsi::Item &item,
+        apsi::Label &label) const;
+}; // class CSVReader
diff --git a/src/main.cpp b/src/main.cpp
index e68f60b..951178d 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -46,6 +46,7 @@
 #include
 #include
 #include
+#include "sender.h"
 
 using namespace std;
 using namespace apsi;
@@ -260,6 +261,19 @@
         }
     }
 
+    void load_csv_db(const string &csv_db_file_path, const string &params_json,
+                     size_t nonce_byte_count, bool compressed)
+    {
+        try
+        {
+            _db = try_load_csv_db(csv_db_file_path, params_json, nonce_byte_count, compressed);
+        }
+        catch (const exception &e)
+        {
+            throw runtime_error("Failed to load data from a CSV file.");
+        }
+    }
+
     void add_item(const string &input_item, const string &input_label)
     {
         Item item(input_item);
@@ -284,6 +298,23 @@
         _db->insert_or_assign(items);
     }
 
+    void add_labeled_items(const py::iterable &input_items_with_label)
+    {
+        vector<pair<Item, Label>> items_with_label;
+        for (py::handle handler : input_items_with_label) {
+            py::tuple py_tup = handler.cast<py::tuple>();
+            if (py::len(py_tup) != 2) {
+                throw runtime_error("each item_with_label must be a tuple of size 2.");
+            }
+            string label_str = py_tup[1].cast<string>();
+            items_with_label.push_back(make_pair(
+                Item(py_tup[0].cast<string>()),
+                Label(label_str.begin(), label_str.end())
+            ));
+        }
+        _db->insert_or_assign(items_with_label);
+    }
+
     py::bytes handle_oprf_request(const string &oprf_request_string)
     {
         _channel.set_in_buffer(oprf_request_string);
@@ -335,8 +366,10 @@
         .def("_init_db", &APSIServer::init_db)
         .def("_save_db", &APSIServer::save_db)
         .def("_load_db", &APSIServer::load_db)
+        .def("_load_csv_db", &APSIServer::load_csv_db)
         .def("_add_item", &APSIServer::add_item)
         .def("_add_unlabeled_items", &APSIServer::add_unlabeled_items)
+        .def("_add_labeled_items", &APSIServer::add_labeled_items)
         .def("_handle_oprf_request", &APSIServer::handle_oprf_request)
         .def("_handle_query", &APSIServer::handle_query)
         // TODO: use def_property_readonly instead
diff --git a/src/sender.cpp b/src/sender.cpp
new file mode 100644
index 0000000..113515d
--- /dev/null
+++ b/src/sender.cpp
@@ -0,0 +1,112 @@
+#include "sender.h"
+
+using namespace std;
+using namespace apsi;
+using namespace apsi::oprf;
+using namespace apsi::sender;
+
+unique_ptr<CSVReader::DBData> db_data_from_csv(const string &db_file)
+{
+    CSVReader::DBData db_data;
+    try {
+        CSVReader reader(db_file);
+        tie(db_data, ignore) = reader.read();
+    } catch (const exception &ex) {
+        APSI_LOG_WARNING("Could not open or read file `" << db_file << "`: " << ex.what());
+        return nullptr;
+    }
+
+    return make_unique<CSVReader::DBData>(move(db_data));
+}
+
+shared_ptr<SenderDB> try_load_csv_db(
+    const string &db_file_path,
+    const string &params_json,
+    size_t nonce_byte_count,
+    bool compressed)
+{
+    unique_ptr<PSIParams> params;
+    try {
+        params = make_unique<PSIParams>(PSIParams::Load(params_json));
+    } catch (const exception &ex) {
+        APSI_LOG_ERROR("APSI threw an exception creating PSIParams: " << ex.what());
+        return nullptr;
+    }
+
+    if (!params) {
+        // We must have valid parameters given
+        APSI_LOG_ERROR("Failed to set PSI parameters");
+        return nullptr;
+    }
+
+    unique_ptr<CSVReader::DBData> db_data;
+    if (db_file_path.empty() || !(db_data = db_data_from_csv(db_file_path))) {
+        // Failed to read db file
+        APSI_LOG_DEBUG("Failed to load data from a CSV file");
+        return nullptr;
+    }
+
+    return create_sender_db(*db_data, move(params), nonce_byte_count, compressed);
+}
+
+shared_ptr<SenderDB> create_sender_db(
+    const CSVReader::DBData &db_data,
+    unique_ptr<PSIParams> psi_params,
+    size_t nonce_byte_count,
+    bool compress)
+{
+    if (!psi_params) {
+        APSI_LOG_ERROR("No PSI parameters were given");
+        return nullptr;
+    }
+
+    shared_ptr<SenderDB> sender_db;
+    if (holds_alternative<CSVReader::UnlabeledData>(db_data)) {
+        try {
+            sender_db = make_shared<SenderDB>(*psi_params, 0, 0, compress);
+            sender_db->set_data(get<CSVReader::UnlabeledData>(db_data));
+
+            APSI_LOG_INFO(
+                "Created unlabeled SenderDB with " << sender_db->get_item_count() << " items");
+        } catch (const exception &ex) {
+            APSI_LOG_ERROR("Failed to create SenderDB: " << ex.what());
+            return nullptr;
+        }
+    } else if (holds_alternative<CSVReader::LabeledData>(db_data)) {
+        try {
+            auto &labeled_db_data = get<CSVReader::LabeledData>(db_data);
+
+            // Find the longest label and use that as label size
+            size_t label_byte_count =
+                max_element(labeled_db_data.begin(), labeled_db_data.end(), [](auto &a, auto &b) {
+                    return a.second.size() < b.second.size();
+                })->second.size();
+
+            sender_db =
+                make_shared<SenderDB>(*psi_params, label_byte_count, nonce_byte_count, compress);
+            sender_db->set_data(labeled_db_data);
+            APSI_LOG_INFO(
+                "Created labeled SenderDB with " << sender_db->get_item_count() << " items and "
+                                                 << label_byte_count << "-byte labels ("
+                                                 << nonce_byte_count << "-byte nonces)");
+        } catch (const exception &ex) {
+            APSI_LOG_ERROR("Failed to create SenderDB: " << ex.what());
+            return nullptr;
+        }
+    } else {
+        // Should never reach this point
+        APSI_LOG_ERROR("Loaded database is in an invalid state");
+        return nullptr;
+    }
+
+    if (compress) {
+        APSI_LOG_INFO("Using in-memory compression to reduce memory footprint");
+    }
+
+    // Read the OPRFKey and strip the SenderDB to reduce memory use; not done for now
+    // oprf_key = sender_db->strip();
+
+    APSI_LOG_INFO("SenderDB packing rate: " << sender_db->get_packing_rate());
+
+    return sender_db;
+}
diff --git a/src/sender.h b/src/sender.h
new file mode 100644
index 0000000..3172b4a
--- /dev/null
+++ b/src/sender.h
@@ -0,0 +1,23 @@
+#include <memory>
+#include <string>
+#include <apsi/log.h>
+#include <apsi/psi_params.h>
+#include <apsi/sender_db.h>
+
+#include "csv_reader.h"
+
+
+
+std::unique_ptr<CSVReader::DBData> db_data_from_csv(const std::string &db_file);
+
+std::shared_ptr<apsi::sender::SenderDB> try_load_csv_db(
+    const std::string &db_file_path,
+    const std::string &params_json,
+    size_t nonce_byte_count,
+    bool compressed);
+
+std::shared_ptr<apsi::sender::SenderDB> create_sender_db(
+    const CSVReader::DBData &db_data,
+    std::unique_ptr<apsi::PSIParams> psi_params,
+    size_t nonce_byte_count,
+    bool compress);
\ No newline at end of file
diff --git a/tests/load_csv_query.py b/tests/load_csv_query.py
new file mode 100644
index 0000000..eda021e
--- /dev/null
+++ b/tests/load_csv_query.py
@@ -0,0 +1,37 @@
+from apsi import LabeledServer, LabeledClient
+
+apsi_params = """
+{
+    "table_params": {
+        "hash_func_count": 3,
+        "table_size": 1638,
+        "max_items_per_bin": 1304
+    },
+    "item_params": {
+        "felts_per_item": 5
+    },
+    "query_params": {
+        "ps_low_degree": 44,
+        "query_powers": [1, 3, 11, 18, 45, 225]
+    },
+    "seal_params": {
+        "plain_modulus_bits": 22,
+        "poly_modulus_degree": 8192,
+        "coeff_modulus_bits": [56, 56, 56, 50]
+    }
+}
+"""
+def main() -> None:
+    server = LabeledServer()
+    server.load_csv_db("./tests/test_10.csv", apsi_params)
+
+    client = LabeledClient(apsi_params)
+    oprf_request = client.oprf_request(["828123436896012688", "952535141803615208"])
+    oprf_response = server.handle_oprf_request(oprf_request)
+    query = client.build_query(oprf_response)
+    response = server.handle_query(query)
+    result = client.extract_result(response)
+    print(result)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/tests/test_10.csv b/tests/test_10.csv
new file mode 100644
index 0000000..2e378b3
--- /dev/null
+++ b/tests/test_10.csv
@@ -0,0 +1,11 @@
+ID,REG_DATE,AGE,CONSUME
+562484070742182305,2006-01-09,33,966538.02
+693716548002263171,2019-10-22,32,710751.65
+517921705480501278,2013-01-31,70,575629.59
+119448901516529428,2018-02-04,23,38907.32
+259765066479929688,2011-03-28,60,588679.35
+828123436896012688,2020-12-07,72,926631.74
+727949477247804616,2005-06-08,29,196800.06
+166601997080595901,2015-05-21,64,329460.16
+123357540266750140,2018-01-15,79,676582.12
+952535141803615208,2017-01-15,27,143605.97
diff --git a/tests/test_base_fun.py b/tests/test_base_fun.py
new file mode 100644
index 0000000..89f25bf
--- /dev/null
+++ b/tests/test_base_fun.py
@@ -0,0 +1,35 @@
+from apsi import LabeledServer, LabeledClient
+
+apsi_params = """
+{
+    "table_params": {
+        "hash_func_count": 3,
+        "table_size": 512,
+        "max_items_per_bin": 92
+    },
+    "item_params": {"felts_per_item": 8},
+    "query_params": {
+        "ps_low_degree": 0,
+        "query_powers": [1, 3, 4, 5, 8, 14, 20, 26, 32, 38, 41, 42, 43, 45, 46]
+    },
+    "seal_params": {
+        "plain_modulus": 40961,
+        "poly_modulus_degree": 4096,
+        "coeff_modulus_bits": [40, 32, 32]
+    }
+}
+"""
+
+server = LabeledServer()
+server.init_db(apsi_params, max_label_length=10)
+server.add_items([("item", "1234567890"), ("abc", "123"), ("other", "my label")])
+
+client = LabeledClient(apsi_params)
+
+oprf_request = client.oprf_request(["item", "abc"])
+oprf_response = server.handle_oprf_request(oprf_request)
+query = client.build_query(oprf_response)
+response = server.handle_query(query)
+result = client.extract_result(response)
+
+assert result == {"item": "1234567890", "abc": "123"}
\ No newline at end of file