Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ __pycache__/
*.whl
external/
*.so
.vscode/
29 changes: 17 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
cmake_minimum_required(VERSION 3.13.4)

project(_pyapsi)
find_package(Threads REQUIRED)
find_package(SEAL 4.1 REQUIRED)
find_package(APSI 0.11.0 REQUIRED)
#find_package(jsoncpp REQUIRED)

include(FetchContent)
# include(FetchContent)

set(CMAKE_POSITION_INDEPENDENT_CODE ON)
find_package(pybind11 REQUIRED)
#FetchContent_Declare(
# pybind11
# GIT_REPOSITORY https://github.com/pybind/pybind11.git
# GIT_TAG v2.9.2
#)
# FetchContent_MakeAvailable(pybind11)

FetchContent_Declare(
pybind11
GIT_REPOSITORY https://github.com/pybind/pybind11.git
GIT_TAG v2.9.2
)
FetchContent_MakeAvailable(pybind11)
# add_subdirectory(external/apsi/)
set(MAIN_SOURCES src/sender.cpp src/common_utils.cpp src/csv_reader.cpp src/main.cpp)
set(MAIN_HEADERS src/sender.h src/common_utils.h src/csv_reader.h src/base_clp.h )

add_subdirectory(external/apsi/)
pybind11_add_module(_pyapsi ${MAIN_SOURCES} ${MAIN_HEADERS})

pybind11_add_module(_pyapsi src/main.cpp)

target_link_libraries(_pyapsi PRIVATE pybind11::module apsi)
target_link_libraries(_pyapsi PRIVATE pybind11::module APSI::apsi SEAL::seal)

target_compile_definitions(_pyapsi PRIVATE)
15 changes: 12 additions & 3 deletions apsi/servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def load_db(self, db_file_path: str) -> None:

self._load_db(db_file_path)
self.db_initialized = True

def load_csv_db(self,csv_db_file_path:str, params_json:str,
nonce_byte_count: int = 16,
compressed: bool = False
) -> None:
"""Load a database from csv file."""
p = Path(csv_db_file_path)
if not p.exists():
raise FileNotFoundError(f"DB file does not exist: {p}")
self._load_csv_db(csv_db_file_path, params_json, nonce_byte_count, compressed)
self.db_initialized = True

def handle_oprf_request(self, oprf_request: bytes) -> bytes:
"""Handle an initial APSI Client OPRF request.
Expand Down Expand Up @@ -109,9 +120,7 @@ def add_items(self, items_with_label: Iterable[Tuple[str, str]]) -> None:
value is the label.
"""
self._requires_db()
# TODO: Expose batch add in C++ PyAPSI; and add length checks accordingly
for item, label in items_with_label:
self.add_item(item=item, label=label)
self._add_labeled_items(items_with_label)


class UnlabeledServer(_BaseServer):
Expand Down
130 changes: 130 additions & 0 deletions src/base_clp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

// STD
#include <memory>
#include <string>

// TCLAP
#ifdef _MSC_VER
#pragma warning(push, 0)
#endif
#include <tclap/CmdLine.h>
#ifdef _MSC_VER
#pragma warning(pop)
#endif

// APSI
#include <apsi/log.h>


/**
Command line processor based on TCLAP. This is a base class that contains common arguments for both
parties.
*/
class BaseCLP : public TCLAP::CmdLine {
public:
BaseCLP(const std::string &description, const std::string &version)
: TCLAP::CmdLine(description, /* delim */ ' ', version)
{
std::vector<std::string> log_levels = { "all", "debug", "info", "warning", "error", "off" };
log_level_constraint_ = std::make_unique<TCLAP::ValuesConstraint<std::string>>(log_levels);
log_level_arg_ = std::make_unique<TCLAP::ValueArg<std::string>>(
"l",
"logLevel",
"One of \"all\", \"debug\", \"info\" (default), \"warning\", \"error\", \"off\"",
false,
"info",
log_level_constraint_.get(),
*this);
}

virtual ~BaseCLP()
{}

/**
Add additional arguments to the Command Line Processor.
*/
virtual void add_args() = 0;

/**
Get the value of the additional arguments.
*/
virtual void get_args() = 0;

bool parse_args(int argc, char **argv)
{
TCLAP::ValueArg<std::size_t> threads_arg(
"t",
"threads",
"Number of threads to use",
/* req */ false,
/* value */ 0,
/* type desc */ "unsigned integer");
add(threads_arg);

TCLAP::ValueArg<std::string> logfile_arg(
"f", "logFile", "Log file path", false, "", "file path");
add(logfile_arg);

TCLAP::SwitchArg silent_arg("s", "silent", "Do not write output to console", false);
add(silent_arg);

// No need to add log_level_arg_, already added in constructor

// Additional arguments
add_args();

try {
parse(argc, argv);

silent_ = silent_arg.getValue();
log_file_ = logfile_arg.getValue();
threads_ = threads_arg.getValue();
log_level_ = log_level_arg_->getValue();

apsi::Log::SetConsoleDisabled(silent_);
apsi::Log::SetLogFile(log_file_);
apsi::Log::SetLogLevel(log_level_);

get_args();
} catch (...) {
return false;
}

return true;
}

std::size_t threads() const
{
return threads_;
}

const std::string &log_level() const
{
return log_level_;
}

const std::string &log_file() const
{
return log_file_;
}

bool silent() const
{
return silent_;
}

private:
// Parameters from command line
std::size_t threads_;
std::string log_level_;
std::string log_file_;
bool silent_;

// Parameters with constraints
std::unique_ptr<TCLAP::ValueArg<std::string>> log_level_arg_;
std::unique_ptr<TCLAP::ValuesConstraint<std::string>> log_level_constraint_;
};
144 changes: 144 additions & 0 deletions src/common_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "common_utils.h"

// STD
#include <iomanip>
#include <iostream>
#if defined(_MSC_VER)
#include <windows.h>
#endif
#if defined(__GNUC__) && (__GNUC__ < 8) && !defined(__clang__)
#include <experimental/filesystem>
#else
#include <filesystem>
#endif

// APSI
#include <apsi/log.h>
#include <apsi/psi_params.h>
#include <apsi/util/utils.h>
#include "base_clp.h"

using namespace std;
#if defined(__GNUC__) && (__GNUC__ < 8) && !defined(__clang__)
namespace fs = std::experimental::filesystem;
#else
namespace fs = std::filesystem;
#endif
using namespace seal;
using namespace apsi;
using namespace apsi::util;

/**
This only turns on showing colors for Windows.
*/
void prepare_console()
{
#ifndef _MSC_VER
return; // Nothing to do on Linux.
#else
HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
if (hConsole == INVALID_HANDLE_VALUE)
return;

DWORD dwMode = 0;
if (!GetConsoleMode(hConsole, &dwMode))
return;

dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING;
SetConsoleMode(hConsole, dwMode);
#endif
}

vector<string> generate_timespan_report(
const vector<Stopwatch::TimespanSummary> &timespans, int max_name_length)
{
vector<string> report;

for (const auto &timespan : timespans) {
stringstream ss;
ss << setw(max_name_length) << left << timespan.event_name << ": " << setw(5) << right
<< timespan.event_count << " instances. ";
if (timespan.event_count == 1) {
ss << "Duration: " << setw(6) << right << static_cast<int>(timespan.avg) << "ms";
} else {
ss << "Average: " << setw(6) << right << static_cast<int>(timespan.avg)
<< "ms Minimum: " << setw(6) << right << timespan.min << "ms Maximum: " << setw(6)
<< right << timespan.max << "ms";
}

report.push_back(ss.str());
}

return report;
}

vector<string> generate_event_report(
const vector<Stopwatch::Timepoint> &timepoints, int max_name_length)
{
vector<string> report;

Stopwatch::time_unit last = Stopwatch::start_time;
for (const auto &timepoint : timepoints) {
stringstream ss;

int64_t since_start = chrono::duration_cast<chrono::milliseconds>(
timepoint.time_point - Stopwatch::start_time)
.count();
int64_t since_last =
chrono::duration_cast<chrono::milliseconds>(timepoint.time_point - last).count();

ss << setw(max_name_length) << left << timepoint.event_name << ": " << setw(6) << right
<< since_start << "ms since start, " << setw(6) << right << since_last
<< "ms since last single event.";
last = timepoint.time_point;
report.push_back(ss.str());
}

return report;
}

void print_timing_report(const Stopwatch &stopwatch)
{
vector<string> timing_report;
vector<Stopwatch::TimespanSummary> timings;
stopwatch.get_timespans(timings);

if (timings.size() > 0) {
timing_report =
generate_timespan_report(timings, stopwatch.get_max_timespan_event_name_length());

APSI_LOG_INFO("Timespan event information");
for (const auto &timing : timing_report) {
APSI_LOG_INFO(timing.c_str());
}
}

vector<Stopwatch::Timepoint> timepoints;
stopwatch.get_events(timepoints);

if (timepoints.size() > 0) {
timing_report = generate_event_report(timepoints, stopwatch.get_max_event_name_length());

APSI_LOG_INFO("Single event information");
for (const auto &timing : timing_report) {
APSI_LOG_INFO(timing.c_str());
}
}
}

void throw_if_file_invalid(const string &file_name)
{
fs::path file(file_name);

if (!fs::exists(file)) {
APSI_LOG_ERROR("File `" << file.string() << "` does not exist");
throw logic_error("file does not exist");
}
if (!fs::is_regular_file(file)) {
APSI_LOG_ERROR("File `" << file.string() << "` is not a regular file");
throw logic_error("invalid file");
}
}
38 changes: 38 additions & 0 deletions src/common_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

// STD
#include <string>
#include <vector>

// APSI
#include <apsi/util/stopwatch.h>

/**
Prepare console for color output.
*/
void prepare_console();

/**
Generate timing report for timespans.
*/
std::vector<std::string> generate_timespan_report(
const std::vector<apsi::util::Stopwatch::TimespanSummary> &timespans, int max_name_length);

/**
Generate timing report for single events.
*/
std::vector<std::string> generate_event_report(
const std::vector<apsi::util::Stopwatch::Timepoint> &timepoints, int max_name_length);

/**
Print timings.
*/
void print_timing_report(const apsi::util::Stopwatch &stopwatch);

/**
Throw an exception if the given file is invalid.
*/
void throw_if_file_invalid(const std::string &file_name);
Loading