"""Module for information-theoretic and pseudorandom threshold secret sharing. Threshold secret sharing assumes secure channels for communication. Pseudorandom secret sharing (PRSS) allows one to share pseudorandom secrets without any communication, as long as the parties agree on a (unique) common public input for each secret. PRSS relies on parties having agreed upon the keys for a pseudorandom function (PRF). """ __all__ = ['random_split', 'recombine', 'pseudorandom_share', 'pseudorandom_share_zero', 'PRF'] import functools import hashlib import secrets def random_split(s, t, m): """Split each secret given in s into m random Shamir shares. The (maximum) degree for the Shamir polynomials is t, 0 <= t < n. Return matrix of shares, one row per party. """ field = type(s[0]) p = field.modulus order = field.order T = type(p) # T is int or gfpx.Polynomial n = len(s) shares = [[None] * n for _ in range(m)] for h in range(n): c = [secrets.randbelow(order) for _ in range(t)] # polynomial f(X) = s[h] + c[t-1] X + c[t-2] X^2 + ... + c[0] X^t for i in range(m): y = 0 if T is int else T(0) for c_j in c: y += c_j y *= i+1 shares[i][h] = (y + s[h].value) % p return shares @functools.lru_cache(maxsize=None) def _recombination_vector(field, xs, x_r): """Compute and store a recombination vector. A recombination vector depends on the field, the x-coordinates xs of the shares and the x-coordinate x_r of the recombination point. """ vector = [] for i, x_i in enumerate(xs): x_i = field(x_i) coefficient = field(1) for j, x_j in enumerate(xs): x_j = field(x_j) if i != j: coefficient *= (x_r - x_j) / (x_i - x_j) vector.append(coefficient.value) return vector def recombine(field, points, x_rs=0): """Recombine shares given by points into secrets. Recombination is done for x-coordinates x_rs. """ xs, shares = list(zip(*points)) if not isinstance(x_rs, list): x_rs = (x_rs,) m = len(shares) n = len(shares[0]) width = len(x_rs) T_is_field = isinstance(shares[0][0], field) # all elts assumed of same type vector = [_recombination_vector(field, xs, x_r) for x_r in x_rs] sums = [[0] * n for _ in range(width)] for i in range(m): for h in range(n): s = shares[i][h] if T_is_field: s = s.value # type(s) is int or gfpx.Polynomial for r in range(width): sums[r][h] += s * vector[r][i] for r in range(width): for h in range(n): sums[r][h] = field(sums[r][h]) if isinstance(x_rs, tuple): return sums[0] return sums @functools.lru_cache(maxsize=None) def _f_S_i(field, m, i, S): """Compute and store polynomial f_S evaluated for party i. Polynomial f_S is 1 at 0 and 0 for all parties j outside S.""" points = [(0, [1])] + [(x+1, [0]) for x in range(m) if x not in S] return recombine(field, points, i+1)[0].value def pseudorandom_share(field, m, i, prfs, uci, n): """Return pseudorandom Shamir shares for party i for n random numbers. The shares are based on the pseudorandom functions for party i, given in prfs, which maps subsets of parties to PRF instances. Input uci is used to evaluate the PRFs on a unique common input. """ sums = [0] * n # iterate over (m-1 choose t) subsets for degree t. for S, prf_S in prfs.items(): f_S_i = _f_S_i(field, m, i, S) prl = prf_S(uci, n) for h in range(n): sums[h] += prl[h] * f_S_i for h in range(n): sums[h] = field(sums[h]) return sums def pseudorandom_share_zero(field, m, i, prfs, uci, n): """Return pseudorandom Shamir shares for party i for n sharings of 0. The shares are based on the pseudorandom functions for party i, given in prfs, which maps subsets of parties to PRF instances. Input uci is used to evaluate the PRFs on a unique common input. """ T = type(field.modulus) # T is int or T is gfpx.Polynomial sums = [0] * n # iterate over (m-1 choose t) subsets for degree t. for S, prf_S in prfs.items(): f_S_i = _f_S_i(field, m, i, S) d = m - len(S) prl = prf_S(uci, n * d) for h in range(n): y = 0 if T is int else T(0) for j in range(d): y += prl[h * d + j] y *= i+1 sums[h] += y * f_S_i for h in range(n): sums[h] = field(sums[h]) return sums class PRF: """A pseudorandom function (PRF). A PRF is determined by a secret key and a public maximum. """ def __init__(self, key, bound): """Create a PRF determined by the given key and (upper) bound. The key is given as a byte string. Output values will be in range(bound). """ self.key = key self.max = bound self.byte_length = ((bound - 1).bit_length() + 7) // 8 if bound & (bound - 1): # no power of 2 self.byte_length += len(self.key) def __call__(self, s, n=None): """Return a number or length-n list of numbers in range(self.max) for input bytes s.""" if n == 0: return [] n_ = 1 if n is None else n l = self.byte_length if not l: x = [0] * n_ else: dk = hashlib.pbkdf2_hmac('sha1', self.key, s, 1, n_ * l) byteorder = 'little' from_bytes = int.from_bytes # cache bound = self.max x = [from_bytes(dk[i:i + l], byteorder) % bound for i in range(0, n_ * l, l)] return x[0] if n is None else x