Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 19 additions & 57 deletions packages/client/libclient-py/quickmpc/share/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,72 +23,34 @@ def get_list(self, a, b, size: int) -> List[int]:

@dataclass(frozen=True)
class ChaCha20(RandomInterface):
def __exception_check(self, a, b) -> None:
if a >= b:
raise ArgumentError(
"乱数の下限は上限より小さい必要があります."
f"{a} < {b}")
if type(a) != type(b):
raise ArgumentError(
"乱数の下限と上限の型は一致させる必要があります."
f"{type(a)} != {type(b)}")

# 128bit符号付き整数最大,最小値
mx: ClassVar[int] = (1 << 128)-1
mn: ClassVar[int] = -(1 << 128)
def __get_byte_size(self, x: int) -> int:
# 整数の byte サイズを取得
return max(math.ceil(math.log2(x))//8 + 1, 32)

@methoddispatch()
def get(self, a, b):
raise ArgumentError(
"乱数の閾値はどちらもintもしくはdecimalでなければなりません."
f"a is {type(a)}, b is {type(b)}")

@get.register(int)
def __get_int(self, a: int, b: int) -> int:
# TRNGで [a,b) の乱数生成
def get(self, a, b) -> int:
self.__exception_check(a, b)
interval_byte = self.__get_byte_size(b-a)
byte_val: bytes = random(interval_byte)
int_val = int.from_bytes(byte_val, "big")
return int_val % (b - a) + a

@get.register(Decimal)
def __get_decimal(self, a: Decimal, b: Decimal) -> Decimal:
# 256bit整数を取り出して[a,b]に正規化する
self.__exception_check(a, b)
val: int = self.get(self.mn, self.mx)
return Decimal(val-self.mn)/(self.mx-self.mn)*(b-a)+a

@methoddispatch()
def get_list(self, a, b, size: int):
raise ArgumentError(
"乱数の閾値はどちらもintもしくはdecimalでなければなりません."
f"a is {type(a)}, b is {type(b)}")

@get_list.register(int)
def __get_list_int(self, a: int, b: int, size: int) -> List[int]:
# TRNGの32byteをseedとしてCSPRNGでsize分生成
byte_size: int = self.__get_byte_size(b-a)
self.__exception_check(a, b)
seed: bytes = self.__get_32byte()
bytes_list: bytes = randombytes_deterministic(size*byte_size, seed)
int_list = [int.from_bytes(bytes_list[i:i+byte_size], "big")
for i in range(0, len(bytes_list), byte_size)]
return [x % (b-a)+a for x in int_list]

@get_list.register(Decimal)
def __get_list_decimal(self, a: Decimal, b: Decimal, size: int) \
-> List[Decimal]:
# 128bit整数を取り出して[a,b]に正規化する
self.__exception_check(a, b)
valList: List[int] = self.get_list(self.mn, self.mx, size)
return [Decimal(val-self.mn)/(self.mx-self.mn)*(b-a)+a
for val in valList]

def __get_byte_size(self, x: int) -> int:
# 整数の byte サイズを取得
return max(math.ceil(math.log2(x))//8 + 1, 32)

def __get_32byte(self) -> bytes:
return random()

def __exception_check(self, a, b) -> None:
if a >= b:
raise ArgumentError(
"乱数の下限は上限より小さい必要があります."
f"{a} < {b}")
if type(a) != type(b):
raise ArgumentError(
"乱数の下限と上限の型は一致させる必要があります."
f"{type(a)} != {type(b)}")
interval_byte = self.__get_byte_size(b-a)
byte_list: bytes = random(interval_byte * size)
int_list = [int.from_bytes(byte_list[i:i+interval_byte], "big")
for i in range(0, len(byte_list), interval_byte)]
return [int_val % (b - a) + a for int_val in int_list]

177 changes: 30 additions & 147 deletions packages/client/libclient-py/quickmpc/share/share.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,173 +18,50 @@

@dataclass(frozen=True)
class Share:
__share_random_range: ClassVar[Tuple[Decimal, Decimal]] =\
(Decimal(-(1 << 64)), Decimal(1 << 64))
__SHIFT_VAL = 10**8
__SECRET_RANGE = 10**19 * __SHIFT_VAL

@methoddispatch(is_static_method=True)
@staticmethod
def __to_str(_):
logger.error("Invalid argument on stringfy.")
raise ArgumentError("不正な引数が与えられています.")

@__to_str.register(Decimal)
@staticmethod
def __decimal_to_str(val: Decimal) -> str:
# InfinityをCCで読み込めるinfに変換
return 'inf' if Decimal.is_infinite(val) else str(val)

@__to_str.register(int)
@staticmethod
def __int_to_str(val: int) -> str:
return str(val)
Scalar = Union[int, float, Decimal]

@methoddispatch(is_static_method=True)
@staticmethod
def sharize(_, __):
logger.error("Invalid argument on sharize.")
raise ArgumentError("不正な引数が与えられています.")

@methoddispatch(is_static_method=True)
@staticmethod
def recons(_):
logger.error("Invalid argument on recons.")
raise ArgumentError("不正な引数が与えられています.")

@methoddispatch(is_static_method=True)
@sharize.register(Scalar)
@staticmethod
def convert_type(_, __):
logger.error("Invalid argument on convert_type.")
raise ArgumentError("不正な引数が与えられています.")

@sharize.register(int)
@sharize.register(float)
@staticmethod
def __sharize_scalar(secrets: float, party_size: int = 3) -> List[str]:
def __sharize_scalar(secret: Scalar, party_size: int = 3) -> List[str]:
""" スカラ値のシェア化 """
secret *= Share.__SHIFT_VAL
if abs(secret) > Share.__SECRET_RANGE:
logger.error("Out of range")
raise ArgumentError("Out of range")
rnd: RandomInterface = ChaCha20()
shares: List[int] = rnd.get_list(
*Share.__share_random_range, party_size)
shares[0] += Decimal(secrets) - np.sum(shares)
shares_str: List[str] = [str(n) for n in shares]
return shares_str
shares: List[Decimal] = rnd.get_list(-Share.__SECRET_RANGE, Share.__SECRET_RANGE, party_size)
shares[0] += Decimal(secret) - np.sum(shares)
return [str(n / Share.__SHIFT_VAL) for n in shares]

@sharize.register((Dim1, float))
@sharize.register(List)
@staticmethod
def __sharize_1dimension_float(secrets: List[Union[float, Decimal]],
party_size: int = 3) \
-> List[List[str]]:
""" 1次元リストのシェア化 """
rnd: RandomInterface = ChaCha20()
secrets_size: int = len(secrets)
shares: np.ndarray = np.array([
rnd.get_list(*Share.__share_random_range, secrets_size)
for __ in range(party_size - 1)])
s1: np.ndarray = np.subtract(np.frompyfunc(Decimal, 1, 1)(secrets),
np.sum(shares, axis=0))
shares_str: List[List[str]] = \
np.vectorize(Share.__to_str)([s1, *shares]).tolist()
return shares_str
def __sharize_multidim(secrets: List[Scalar], party_size: int = 3):
return [Share.__sharize(secret, party_size) for secret in secrets]

@sharize.register((Dim1, Decimal))
@staticmethod
def __sharize_1dimension_decimal(secrets: List[Decimal],
party_size: int = 3) \
-> List[List[str]]:
return Share.__sharize_1dimension_float(secrets, party_size)

@sharize.register((Dim1, int))
@staticmethod
def __sharize_1dimension_int(secrets: List[int], party_size: int = 3) \
-> List[List[str]]:
""" 1次元リストのシェア化 """
rnd: RandomInterface = ChaCha20()
secrets_size: int = len(secrets)
max_val = (max(secrets)+1) * 2
shares: np.ndarray = np.array([
rnd.get_list(-max_val, max_val, secrets_size)
for __ in range(party_size - 1)])
s1: np.ndarray = np.subtract(np.frompyfunc(int, 1, 1)(secrets),
np.sum(shares, axis=0))
shares_str: List[List[str]] = np.vectorize(
Share.__to_str)([s1, *shares]).tolist()
return shares_str

@sharize.register(Dim2)
@staticmethod
def __sharize_2dimension(secrets: List[List[Union[float, int]]],
party_size: int = 3) -> List[List[List[str]]]:
""" 2次元リストのシェア化 """
transposed: List[Union[List[int], List[float]]] \
= np.array(secrets, dtype=object).transpose().tolist()
dst: List[List[List[str]]] = [
Share.sharize(col, party_size) for col in transposed
]
dst = np.array(dst, dtype=object).transpose(1, 2, 0).tolist()

return dst

@sharize.register(dict)
@staticmethod
def __sharize_dict(secrets: dict, party_size: int = 3) -> List[dict]:
""" 辞書型のシェア化 """
shares_str: List[dict] = [dict() for _ in range(party_size)]
for key, val in secrets.items():
for i, share_val in enumerate(Share.sharize(val, party_size)):
shares_str[i][key] = share_val
return shares_str

@sharize.register(DictList)
@methoddispatch(is_static_method=True)
@staticmethod
def __sharize_dictlist(secrets: dict, party_size: int = 3) \
-> List[List[dict]]:
""" 辞書型配列のシェア化 """
shares_str: List[List[dict]] = [[] for _ in range(party_size)]
for secret_dict in secrets:
share_dict: List[dict] = Share.sharize(secret_dict, party_size)
for ss, sd in zip(shares_str, share_dict):
ss.append(sd)
return shares_str
def recons(_):
logger.error("Invalid argument on sharize.")
raise ArgumentError("不正な引数が与えられています.")

@recons.register(Dim1)
@recons.register(List[Scalar])
@staticmethod
def __recons_list1(shares: List[Union[int, Decimal]]):
""" 1次元リストのシェアを復元 """
def __recons_scalar(shares: List[Scalar]):
return sum(shares)

@recons.register(Dim2)
@recons.register(Dim3)
@recons.register(List[List])
@staticmethod
def __recons_list(shares: List[List[Union[int, Decimal]]]) -> List:
""" リストのシェアを復元 """
secrets: List = [
Share.recons([shares_pi[i] for shares_pi in shares])
for i in range(len(shares[0]))
]
return secrets

@recons.register(DictList)
@staticmethod
def __recons_dictlist(shares: List[dict]) -> dict:
""" 辞書型を復元 """
secrets: dict = dict()
for key in shares[0].keys():
val = []
for s in shares:
val.append(s[key])
secrets[key] = Share.recons(val)
return secrets

@recons.register(DictList2)
@staticmethod
def __recons_dictlist2(shares: List[List[dict]]) -> list:
""" 辞書型配列を復元 """
secrets: list = list()
for i in range(len(shares[0])):
val = []
for s in shares:
val.append(s[i])
secrets.append(Share.recons(val))
return secrets
def __recons_multidim(multidim_shares: List[List]) -> List:
return [Share.recons(shares) for shares in multidim_shares]

@staticmethod
def get_pre_convert_func(
Expand Down Expand Up @@ -226,6 +103,12 @@ def get_convert_func(
return Share.convert_int_to_str
return float

@methoddispatch(is_static_method=True)
@staticmethod
def convert_type(_, __):
logger.error("Invalid argument on convert_type.")
raise ArgumentError("不正な引数が与えられています.")

@convert_type.register(str)
@staticmethod
def __pre_convert_type_str(
Expand Down
Loading