import random def _precompute_gf256_exp_log(): exp = [0 for i in range(255)] log = [0 for i in range(256)] poly = 1 for i in range(255): exp[i] = poly log[poly] = i # Multiply poly by the polynomial x + 1. poly = (poly << 1) ^ poly # Reduce poly by x^8 + x^4 + x^3 + x + 1. if poly & 0x100: poly ^= 0x11B return exp, log EXP_TABLE, LOG_TABLE = _precompute_gf256_exp_log() def _gf256_mul(a, b): if a == 0 or b == 0: return 0 return EXP_TABLE[ (LOG_TABLE[a] + LOG_TABLE[b]) % 255 ] def _gf256_pow(a, b): if b == 0: return 1 if a == 0: return 0 c = a for i in range(b - 1): c = _gf256_mul(c,a) return c def _gf256_add(a, b): return a ^ b def _gf256_sub(a, b): return a ^ b def _gf256_inverse(a): if a == 0: raise ZeroDivisionError() return EXP_TABLE[ (-LOG_TABLE[a]) % 255 ] def _gf256_div(a, b): if b == 0: raise ZeroDivisionError() if a == 0: return 0 r = _gf256_mul(a, _gf256_inverse(b)) assert a == _gf256_mul(r, b) return r def _fn(x, q): r = 0 for i, a in enumerate(q): r = _gf256_add(r, _gf256_mul(a,_gf256_pow(x,i))) return r def _interpolation(points, x=0): k = len(points) if k < 2: raise Exception("Minimum 2 points required") points = sorted(points, key=lambda z: z[0]) if len(set(z[0] for z in points)) != k: raise Exception("Unique points required") p_x = 0 for j in range(k): p_j_x = 1 for m in range(k): if m == j: continue a = _gf256_sub(x, points[m][0]) b = _gf256_sub(points[j][0], points[m][0]) c = _gf256_div(a, b) p_j_x = _gf256_mul(p_j_x, c) p_j_x = _gf256_mul( points[j][1], p_j_x) p_x = _gf256_add(p_x , p_j_x) return p_x def split_secret(threshold, total, secret): if not isinstance(secret, bytes): raise TypeError("Secret as byte string required") if threshold > 255: raise ValueError("threshold <= 255") if total > 255: raise ValueError("total shares <= 255") shares = dict() for i in range(total): shares[i+1] = b"" for b in secret: q = [b] for i in range(threshold - 1): a = random.SystemRandom().randint(0, 255) q.append(a) for x in range(total): shares[x+1] += bytes([_fn(x + 1, q)]) return shares def restore_secret(shares): secret = b"" share_length = None for share in shares.values(): if share_length is None: share_length = len(share) if share_length != len(share) or share_length == 0: raise Exception("Invalid shares") for i in range(share_length): secret += bytes([_interpolation([(z, shares[z][i]) for z in shares])]) return secret