# -*- coding:utf-8 -*-
"""
Description:
    ECC Curve
    Prime-field arithmetic (FiniteField), elliptic-curve point arithmetic
    (EllipticCurve) and ECDSA helpers for the secp256r1 and secp256k1 curves.
Usage:
    from neocore.Cryptography.ECCurve import ECCurve
"""
import random
import binascii

try:
    from mpmath.libmp import bitcount as _bitlength
except ImportError:
    # mpmath is only needed for a bit count, which int.bit_length() provides.
    def _bitlength(n):
        return int(n).bit_length()

try:
    from logzero import logger
except ImportError:
    # Fall back to the standard library logger when logzero is not installed.
    import logging
    logger = logging.getLogger(__name__)

modpow = pow

# secp256r1 (NIST P-256) domain parameters: field prime P, curve coefficients
# A and B, base point G = (GX, GY) and N, the order of G.
_SECP256R1_P = int("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF", 16)
_SECP256R1_A = 115792089210356248762697446949407573530086143415290314195533631308867097853948  # P - 3
_SECP256R1_B = 41058363725152142129326129780047268409114441015993725554835256314039467401291
_SECP256R1_GX = 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296
_SECP256R1_GY = 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5
_SECP256R1_N = int("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551", 16)


def _byte_length(n):
    """Return the number of bytes needed to hold integers below ``n``'s bit size."""
    return (_bitlength(n) + 7) // 8


# (gcd, c, d) = GCD(a, b)  ===>  a*c + b*d == gcd
def GCD(a, b):
    """Extended Euclidean algorithm.

    Returns ``(gcd, c, d)`` such that ``a*c + b*d == gcd``.
    """
    if a == 0:
        return (b, 0, 1)
    d1, x1, y1 = GCD(b % a, a)
    return (d1, y1 - (b // a) * x1, x1)


def modinv(x, m):
    """Return the multiplicative inverse of ``x`` modulo ``m``, in ``[0, m)``.

    Raises:
        ZeroDivisionError: if ``x`` is not invertible modulo ``m``
            (``gcd(x, m) != 1``), e.g. when ``x % m == 0``. Previously a
            meaningless value was returned, which made division by zero
            silently yield 0.
    """
    # Reduce first so the gcd is always non-negative, even for negative x.
    gcd, c, _ = GCD(x % m, m)
    if gcd != 1:
        raise ZeroDivisionError("%d is not invertible modulo %d" % (x, m))
    return c % m


def samefield(a, b):
    """
    Return True if the values ``a`` and ``b`` belong to the same field object.
    """
    return a.field == b.field


def test_bit(num, index):
    """Return True if bit number ``index`` of ``num`` is set."""
    return bool(num & (1 << index))


# This is a helper, not a test: stop pytest from collecting it when this
# module's names are star-imported into a test module.
test_bit.__test__ = False


def randbytes(n):
    """Yield ``n`` pseudo-random integers in ``[0, 256)``."""
    for _ in range(0, n):
        yield random.getrandbits(8)


def next_random_integer(size_in_bits):
    """Return a pseudo-random integer in ``[0, 2 ** size_in_bits)``.

    Uses :mod:`random` (not cryptographically secure). Within this module it
    is only used to pick Lucas-sequence parameters in :func:`sqrtCQ`.

    Raises:
        ValueError: if ``size_in_bits`` is negative.
    """
    if size_in_bits < 0:
        raise ValueError('size in bits must be greater than zero')
    if size_in_bits == 0:
        return 0
    # The previous byte-array version masked the last byte but decoded it as
    # big-endian, producing values up to 8 bits too large.
    return random.getrandbits(size_in_bits)


def _lucas_sequence(n, P, Q, k):
    """Return the modular Lucas sequence values ``(U_k, V_k)``.

    Given a Lucas sequence defined by P, Q, returns the kth values for U and
    V, both reduced modulo n. Q^k is tracked internally but not returned.

    Raises:
        ValueError: if ``n < 2``, ``k < 0`` or ``P*P - 4*Q == 0``.
    """
    D = P * P - 4 * Q
    if n < 2:
        raise ValueError("n must be >= 2")
    if k < 0:
        raise ValueError("k must be >= 0")
    if D == 0:
        raise ValueError("D must not be zero")
    if k == 0:
        return 0, 2 % n

    U = 1
    V = P
    Qk = Q
    b = _bitlength(k)

    if Q == 1:
        # For strong tests
        while b > 1:
            U = (U * V) % n
            V = (V * V - 2) % n
            b -= 1
            if (k >> (b - 1)) & 1:
                t = U * D
                U = U * P + V
                # Halve modulo odd n: make the value even before shifting.
                if U & 1:
                    U += n
                U >>= 1
                V = V * P + t
                if V & 1:
                    V += n
                V >>= 1
    elif P == 1 and Q == -1:
        # For Selfridge parameters
        while b > 1:
            U = (U * V) % n
            if Qk == 1:
                V = (V * V - 2) % n
            else:
                V = (V * V + 2) % n
                Qk = 1
            b -= 1
            if (k >> (b - 1)) & 1:
                t = U * D
                U = U + V
                if U & 1:
                    U += n
                U >>= 1
                V = V + t
                if V & 1:
                    V += n
                V >>= 1
                Qk = -1
    else:
        # The general case with any P and Q
        while b > 1:
            U = (U * V) % n
            V = (V * V - 2 * Qk) % n
            Qk *= Qk
            b -= 1
            if (k >> (b - 1)) & 1:
                t = U * D
                U = U * P + V
                if U & 1:
                    U += n
                U >>= 1
                V = V * P + t
                if V & 1:
                    V += n
                V >>= 1
                Qk *= Q
            Qk %= n
    U %= n
    V %= n
    return U, V


def sqrtCQ(val, CQ):
    """Return a square root of ``val`` modulo the odd prime ``CQ``, or None.

    For CQ = 3 (mod 4) the root is ``val ** ((CQ + 1) / 4)``. Otherwise a
    Lucas-sequence method with a randomly chosen P is used.

    Returns:
        An integer ``z`` with ``z * z % CQ == val % CQ``, or None when
        ``val`` is not a quadratic residue modulo ``CQ``.
    """
    val %= CQ
    if val == 0:
        return 0

    if test_bit(CQ, 1):
        # CQ = 3 (mod 4)
        z = modpow(val, (CQ >> 2) + 1, CQ)
        return z if (z * z) % CQ == val else None

    qMinusOne = CQ - 1
    legendreExponent = qMinusOne >> 1
    if modpow(val, legendreExponent, CQ) != 1:
        logger.error("legendaire exponent error")
        return None

    u = qMinusOne >> 2
    k = (u << 1) + 1  # == (CQ + 1) / 2
    Q = val
    fourQ = (Q << 2) % CQ

    # do { ... } while (U == 1 || U == CQ - 1). The previous version started
    # with U = None, so this loop never ran and the function returned None.
    while True:
        # Pick P such that P^2 - 4Q is a quadratic non-residue.
        P = next_random_integer(CQ.bit_length())
        while P >= CQ or modpow(P * P - fourQ, legendreExponent, CQ) != qMinusOne:
            P = next_random_integer(CQ.bit_length())

        U, V = _lucas_sequence(CQ, P, Q, k)

        if (V * V) % CQ == fourQ:
            # V / 2 mod CQ is the root.
            if test_bit(V, 0):
                V += CQ
            V >>= 1
            assert (V * V) % CQ == val
            return V

        if U != 1 and U != qMinusOne:
            return None


class FiniteField:
    """
    FiniteField implements arithmetic on integers modulo ``p``.

    ``p`` is expected to be prime for division and square roots.
    """

    class Value:
        """
        Represent a value in the FiniteField.

        This class forwards all operations to the FiniteField class. The
        stored ``value`` is always reduced into ``[0, p)``.
        """

        def __init__(self, field, value):
            self.field = field
            # Keep values canonical so iszero(), comparisons and encodings
            # agree with modular equality (e.g. -0 is 0, not p).
            self.value = field.integer(value) % field.p

        # Value * int
        def __add__(self, rhs):
            return self.field.add(self, self.field.value(rhs))

        def __sub__(self, rhs):
            return self.field.sub(self, self.field.value(rhs))

        def __mul__(self, rhs):
            return self.field.mul(self, self.field.value(rhs))

        def __truediv__(self, rhs):
            return self.field.div(self, self.field.value(rhs))

        def __pow__(self, rhs):
            return self.field.pow(self, rhs)

        # int * Value
        def __radd__(self, rhs):
            return self.field.add(self.field.value(rhs), self)

        def __rsub__(self, rhs):
            return self.field.sub(self.field.value(rhs), self)

        def __rmul__(self, rhs):
            return self.field.mul(self.field.value(rhs), self)

        def __rtruediv__(self, rhs):
            # Python 3 uses __rtruediv__ for `1 / value` (e.g. in div/findpk2).
            return self.field.div(self.field.value(rhs), self)

        # Python 2 name kept for backward compatibility.
        __rdiv__ = __rtruediv__

        def __rpow__(self, rhs):
            return self.field.pow(self.field.value(rhs), self)

        def __eq__(self, rhs):
            return self.field.eq(self, self.field.value(rhs))

        def __ne__(self, rhs):
            return not (self == rhs)

        def __str__(self):
            return "0x%x" % self.value

        def __neg__(self):
            return self.field.neg(self)

        def sqrt(self, flag):
            return self.field.sqrt(self, flag)

        def sqrtCQ(self, CQ):
            return self.field.sqrtCQ(self, CQ)

        def inverse(self):
            return self.field.inverse(self)

        def iszero(self):
            return self.value == 0

    def __init__(self, p):
        self.p = p

    # Several basic operators. Each returns False (not a Value) when the
    # operands belong to different field objects.
    def add(self, lhs, rhs):
        return samefield(lhs, rhs) and self.value((lhs.value + rhs.value) % self.p)

    def sub(self, lhs, rhs):
        return samefield(lhs, rhs) and self.value((lhs.value - rhs.value) % self.p)

    def mul(self, lhs, rhs):
        return samefield(lhs, rhs) and self.value((lhs.value * rhs.value) % self.p)

    def div(self, lhs, rhs):
        # rhs.inverse() raises ZeroDivisionError when rhs is not invertible.
        return samefield(lhs, rhs) and self.value((lhs.value * rhs.inverse()) % self.p)

    def pow(self, lhs, rhs):
        return self.value(pow(int(lhs.value), int(self.integer(rhs)), self.p))

    def eq(self, lhs, rhs):
        return (lhs.value - rhs.value) % self.p == 0

    def neg(self, val):
        return self.value((-val.value) % self.p)

    def sqrt(self, val, flag):
        """
        Calculate the square root of ``val`` modulus p.

        ``flag`` selects the root by parity: the returned value's ``.value``
        has ``value % 2 == flag``. If ``val`` is not a quadratic residue the
        result is not a square root (no check is made).

        Raises:
            Exception: when ``p % 8`` is not 3, 5 or 7.
        """
        if val.iszero():
            return val
        p = self.p
        sw = p % 8
        # Integer division throughout: float exponents lose precision for
        # 256-bit primes.
        if sw == 3 or sw == 7:
            res = val ** ((p + 1) // 4)
        elif sw == 5:
            d = val ** ((p - 1) // 4)  # +1 or -1 for a quadratic residue
            if d == 1:
                res = val ** ((p + 3) // 8)
            else:
                res = (4 * val) ** ((p - 5) // 8) * 2 * val
        else:
            raise Exception("modsqrt non supported for (p%8)==1")

        return res if res.value % 2 == flag else -res

    def inverse(self, value):
        """
        Calculate the multiplicative inverse as a plain integer in [0, p).
        """
        return modinv(value.value, self.p)

    def value(self, x):
        """
        Convert an integer, hex string or FiniteField.Value to a value of this FiniteField.
        """
        if isinstance(x, FiniteField.Value) and x.field == self:
            return x
        return FiniteField.Value(self, x)

    def integer(self, x):
        """
        Return a plain integer.

        Strings are parsed as big-endian hex; Values yield their ``.value``;
        anything else is returned unchanged.
        """
        if isinstance(x, str):
            return int.from_bytes(binascii.unhexlify(x), 'big')
        return x.value if isinstance(x, FiniteField.Value) else x

    def zero(self):
        """
        Return the additive identity value, meaning: a + 0 = a.
        """
        return FiniteField.Value(self, 0)

    def one(self):
        """
        Return the multiplicative identity value, meaning: a * 1 = a.
        """
        return FiniteField.Value(self, 1)


class EllipticCurve:
    """
    EllipticCurve implements the curve y^2 = x^3 + a*x + b over a FiniteField.

    Points are represented by EllipticCurve.ECPoint. The point (0, 0) is used
    as the point at infinity (the additive identity).
    """

    class ECPoint:
        """
        Represent a point on the EllipticCurve.

        This class forwards all operations to the EllipticCurve class.
        """

        def __init__(self, curve, x, y):
            self.curve = curve
            self.x = x
            self.y = y

        # Point + Point
        def __add__(self, rhs):
            return self.curve.add(self, rhs)

        def __sub__(self, rhs):
            return self.curve.sub(self, rhs)

        # Point * int or Point * Value
        def __mul__(self, rhs):
            return self.curve.mul(self, rhs)

        # int * Point
        def __rmul__(self, lhs):
            return self.curve.mul(self, lhs)

        def __truediv__(self, rhs):
            return self.curve.div(self, rhs)

        def __eq__(self, rhs):
            # Comparing with a non-point (e.g. None) is simply "not equal".
            if not isinstance(rhs, EllipticCurve.ECPoint):
                return NotImplemented
            return self.curve.eq(self, rhs)

        def __ne__(self, rhs):
            return not (self == rhs)

        # Ordering: by x, then by y. The y comparison was previously
        # unreachable, so points sharing x were never ordered.
        def __lt__(self, other):
            if self.x.value != other.x.value:
                return self.x.value < other.x.value
            return self.y.value < other.y.value

        def __gt__(self, other):
            if self.x.value != other.x.value:
                return self.x.value > other.x.value
            return self.y.value > other.y.value

        def __le__(self, other):
            return self == other or self < other

        def __ge__(self, other):
            return self == other or self > other

        def __str__(self):
            return "(%s,%s)" % (self.x, self.y)

        def __neg__(self):
            return self.curve.neg(self)

        def iszero(self):
            return self.x.iszero() and self.y.iszero()

        def isoncurve(self):
            return self.curve.isoncurve(self)

        @property
        def IsInfinity(self):
            return self == self.curve.Infinity

        def Size(self):
            return 1 if self.IsInfinity else 33

        def encode_point(self, compressed=True, endian='little'):
            """Return the hex-encoded (``bytes``) point encoding.

            Compressed: 02/03 prefix + x; uncompressed: 04 + x + y. Each
            coordinate is converted with ``endian`` and then byte-reversed,
            so the default 'little' yields big-endian coordinates. The point
            at infinity encodes as ``b'00'``.
            """
            if self.IsInfinity:
                # Hex-encoded like every other branch.
                return binascii.hexlify(bytearray([0]))
            size = _byte_length(self.curve.field.p)
            xbytes = bytearray(self.x.value.to_bytes(size, endian))
            xbytes.reverse()
            if compressed:
                prefix = b'\x02' if self.y.value % 2 == 0 else b'\x03'
                return binascii.hexlify(bytearray(prefix) + xbytes)
            ybytes = bytearray(self.y.value.to_bytes(size, endian))
            ybytes.reverse()
            return binascii.hexlify(bytearray(b'\x04') + xbytes + ybytes)

        def ToString(self):
            # encode_point already returns hex; hexlifying again doubled it.
            return self.encode_point(compressed=True).decode('utf-8')

        def ToBytes(self):
            return self.encode_point(compressed=True)

        def Serialize(self, writer, compress=True):
            if self == self.curve.Infinity:
                writer.WriteByte(b'\x00')
            else:
                writer.WriteBytes(self.encode_point(compressed=compress))

    def __init__(self, field, a, b):
        self.field = field
        self.a = field.value(a)
        self.b = field.value(b)

    @property
    def Infinity(self):
        return self.point(0, 0)

    def add(self, p, q):
        """
        Perform elliptic curve addition.
        """
        if p.iszero():
            return q
        if q.iszero():
            return p

        # Calculate the slope of the line through p and q (tangent if p == q).
        if p == q:
            if p.y == 0:
                return self.zero()
            lft = (3 * p.x ** 2 + self.a) / (2 * p.y)
        elif p.x == q.x:
            # q == -p
            return self.zero()
        else:
            lft = (p.y - q.y) / (p.x - q.x)

        # Calculate the intersection point.
        x = lft ** 2 - (p.x + q.x)
        y = lft * (p.x - x) - p.y
        return self.point(x, y)

    # Subtraction is: a - b = a + -b
    def sub(self, lhs, rhs):
        return lhs + -rhs

    def mul(self, pt, scalar):
        """
        Scalar multiplication by double-and-add.

        ``scalar`` may be an int, a hex string or a FiniteField.Value of any
        field. Negative scalars multiply the negated point.
        """
        scalar = self.field.integer(scalar)
        if scalar < 0:
            return self.mul(-pt, -scalar)
        accumulator = self.zero()
        shifter = pt
        while scalar:
            if scalar & 1:
                accumulator += shifter
            shifter += shifter
            # Integer shift: the previous `scalar /= 2` produced floats.
            scalar >>= 1
        return accumulator

    def div(self, pt, scalar):
        """
        Scalar division: P / a = P * (1/a).

        ``scalar`` is assumed to be of type FiniteField(grouporder).Value.
        """
        return pt * (1 / scalar)

    def eq(self, lhs, rhs):
        return lhs.x == rhs.x and lhs.y == rhs.y

    def neg(self, pt):
        return self.point(pt.x, -pt.y)

    def zero(self):
        """
        Return the additive identity point (aka '0'): P + 0 = P.
        """
        return self.point(self.field.zero(), self.field.zero())

    def point(self, x, y):
        """
        Construct a point from 2 values.
        """
        return EllipticCurve.ECPoint(self, self.field.value(x), self.field.value(y))

    def isoncurve(self, p):
        """
        Verify that a point is on the curve (the zero point counts as on it).
        """
        return p.iszero() or p.y ** 2 == p.x ** 3 + self.a * p.x + self.b

    def decompress(self, x, flag):
        """
        Calculate the y coordinate given only the x value.

        There are 2 possible solutions; 'flag' selects the one whose value
        has parity ``flag``. Uses FiniteField.sqrt.
        """
        x = self.field.value(x)
        ysquare = x ** 3 + self.a * x + self.b
        return self.point(x, ysquare.sqrt(flag))

    def decode_from_reader(self, reader):
        """Read an encoded point from ``reader`` (ReadByte/ReadBytes).

        Supports the point at infinity (0x00) and compressed points
        (0x02/0x03 followed by the big-endian x coordinate).
        """
        f = reader.ReadByte()
        if f == 0:
            return self.Infinity

        # These are compressed.
        if f == 2 or f == 3:
            yTilde = f & 1
            data = reader.ReadBytes(_byte_length(self.field.p))
            X1 = int.from_bytes(bytes(data), 'big')
            return self.decompress_from_curve(X1, yTilde)

        # Uncompressed or hybrid.
        if f == 4 or f == 6 or f == 7:
            raise NotImplementedError()

        raise Exception("Invalid point incoding: %s " % f)

    def decode_from_hex(self, hex_str, unhex=True):
        """Decode a point from its encoding.

        ``hex_str`` is hex text when ``unhex`` is True, otherwise the raw
        encoded bytes. Supports infinity (0x00), compressed (0x02/0x03) and
        uncompressed (0x04) encodings; coordinates are big-endian.
        """
        if unhex:
            ba = bytearray(binascii.unhexlify(hex_str))
        else:
            ba = hex_str
        if len(ba) == 0:
            raise ValueError("Invalid point encoding: empty input")

        expected_byte_len = _byte_length(self.field.p)

        f = ba[0]

        if f == 0:
            return self.Infinity

        # These are compressed.
        if f == 2 or f == 3:
            if len(ba) != expected_byte_len + 1:
                raise Exception("Incorrrect length for encoding")
            yTilde = f & 1
            X1 = int.from_bytes(bytes(ba[1:]), 'big')
            return self.decompress_from_curve(X1, yTilde)

        # Uncompressed.
        elif f == 4:
            if len(ba) != (2 * expected_byte_len) + 1:
                raise Exception("Incorrect length for uncompressed encoding")
            x = int.from_bytes(bytes(ba[1:1 + expected_byte_len]), 'big')
            y = int.from_bytes(bytes(ba[1 + expected_byte_len:]), 'big')
            return self.point(x, y)

        # Hybrid.
        elif f == 6 or f == 7:
            raise NotImplementedError()

        else:
            raise Exception("Invalid point incoding: %s " % f)

    def decompress_from_curve(self, x, flag):
        """
        Calculate the y coordinate given only the x value.

        There are 2 possible solutions; 'flag' (0 or 1) selects the one with
        that parity. Uses sqrtCQ.

        Raises:
            ValueError: if x is not the x coordinate of any curve point.
        """
        cq = self.field.p
        x = self.field.value(x)
        ysquare = x ** 3 + self.a * x + self.b

        ysquare_root = sqrtCQ(ysquare.value, cq)
        if ysquare_root is None:
            raise ValueError("Invalid point compression")

        if ysquare_root % 2 != flag:
            beta = (cq - ysquare_root) % cq
        else:
            beta = ysquare_root

        return self.point(x, beta)


class ECDSA:
    """
    Digital Signature Algorithm using Elliptic Curves.

    ``ec`` is the EllipticCurve, ``G`` the base point and ``n`` the order of
    ``G``; scalar arithmetic happens in FiniteField(n) (``self.GFn``).
    """

    def __init__(self, ec, G, n):
        self.ec = ec
        self.G = G
        self.GFn = FiniteField(n)

    @property
    def Curve(self):
        return self.ec

    def calcpub(self, privkey):
        """
        Calculate the public key for private key x: return G*x.
        """
        return self.G * self.GFn.value(privkey)

    def sign(self, message, privkey, secret):
        """
        Sign the message using the private key and sign secret.

        For signsecret k, message m, privatekey x:
        return (r, s) with r = xcoord(G*k) mod n and s = (m+x*r)/k.

        Raises:
            ZeroDivisionError: if ``secret`` is 0 modulo n.
        """
        m = self.GFn.value(message)
        x = self.GFn.value(privkey)
        k = self.GFn.value(secret)

        R = self.G * k

        r = self.GFn.value(R.x)
        s = (m + x * r) / k

        return (r, s)

    def verify(self, message, pubkey, rnum, snum):
        """
        Verify the signature for message m, pubkey Y, signature (r,s).

        r = xcoord(R) mod n
        verify that: G*m + Y*r = R*s
        this is true because: { Y=G*x, and R=G*k, s=(m+x*r)/k }

        G*m+G*x*r = G*k*(m+x*r)/k ->
        G*(m+x*r) = G*(m+x*r)

        This implementation checks the standard form:
            r == xcoord[ G*(m/s) + Y*(r/s) ] mod n

        Returns False for r or s outside [1, n-1].
        """
        n = self.GFn.p
        rint = self.GFn.integer(rnum)
        sint = self.GFn.integer(snum)
        # Reject zero/out-of-range components; s == 0 previously made every
        # division silently return 0.
        if not (0 < rint < n and 0 < sint < n):
            return False

        m = self.GFn.value(message)
        r = self.GFn.value(rnum)
        s = self.GFn.value(snum)

        R = self.G * (m / s) + pubkey * (r / s)

        return self.GFn.value(R.x) == r

    def findpk(self, message, rnum, snum, flag):
        """
        Find pubkey Y from message m, signature (r,s).

        Y = (R*s-G*m)/r
        Note that there are 2 pubkeys related to a signature; 'flag' selects
        the parity of R's y coordinate.
        """
        m = self.GFn.value(message)
        r = self.GFn.value(rnum)
        s = self.GFn.value(snum)

        R = self.ec.decompress(r, flag)

        return R * (s / r) - self.G * (m / r)

    def findpk2(self, r1, s1, r2, s2, flag1, flag2):
        """
        Find pubkey Y from 2 different signatures on the same message.

        sigs: (r1,s1) and (r2,s2)
        returns (R1*s1-R2*s2)/(r1-r2)
        """
        R1 = self.ec.decompress(r1, flag1)
        R2 = self.ec.decompress(r2, flag2)

        rdiff = self.GFn.value(r1 - r2)

        return (R1 * s1 - R2 * s2) * (1 / rdiff)

    def crack2(self, r, s1, s2, m1, m2):
        """
        Find signsecret and privkey from a duplicate 'r'.

        signature (r,s1) for message m1
        and signature (r,s2) for message m2
        s1= (m1 + x*r)/k
        s2= (m2 + x*r)/k
        subtract -> (s1-s2) = (m1-m2)/k  ->  k = (m1-m2)/(s1-s2)
        -> privkey = (s1*k-m1)/r  .. or  (s2*k-m2)/r

        Returns (secret, privkey).
        """
        sdelta = self.GFn.value(s1 - s2)
        mdelta = self.GFn.value(m1 - m2)

        secret = mdelta / sdelta
        x1 = self.crack1(r, s1, m1, secret)
        x2 = self.crack1(r, s2, m2, secret)

        if x1 != x2:
            logger.info("x1= %s" % x1)
            logger.info("x2= %s" % x2)

        return (secret, x1)

    def crack1(self, rnum, snum, message, signsecret):
        """
        Find privkey, given signsecret k, message m, signature (r,s).

        x = (s*k-m)/r
        """
        m = self.GFn.value(message)
        r = self.GFn.value(rnum)
        s = self.GFn.value(snum)
        k = self.GFn.value(signsecret)
        return (s * k - m) / r

    @staticmethod
    def _secp256r1_curve():
        """Build a new secp256r1 EllipticCurve over its prime field."""
        return EllipticCurve(FiniteField(_SECP256R1_P), _SECP256R1_A, _SECP256R1_B)

    @staticmethod
    def secp256r1():
        """
        Create the secp256r1 curve with its standard base point and order.
        """
        ec = ECDSA._secp256r1_curve()
        # The group order N was previously replaced by the field object.
        return ECDSA(ec, ec.point(_SECP256R1_GX, _SECP256R1_GY), _SECP256R1_N)

    @staticmethod
    def decode_secp256r1(str, unhex=True, check_on_curve=True):
        """
        Decode a public key on the secp256r1 curve.

        Returns an ECDSA instance whose ``G`` is the decoded point, ``ec`` the
        secp256r1 curve and ``GFn`` the field of the curve order.

        Raises:
            Exception: when ``check_on_curve`` is set and the point is not on
                the curve.
        """
        ec = ECDSA._secp256r1_curve()
        point = ec.decode_from_hex(str, unhex=unhex)

        if check_on_curve and not point.isoncurve():
            raise Exception("Could not decode string")

        return ECDSA(ec, point, _SECP256R1_N)

    @staticmethod
    def Deserialize_Secp256r1(reader):
        """Read a secp256r1 point from ``reader`` (see decode_from_reader)."""
        return ECDSA._secp256r1_curve().decode_from_reader(reader)

    @staticmethod
    def FromBytes_Secp256r1(pubkey):
        """Decode raw public-key bytes on secp256r1.

        Accepted lengths: 33/65 (SEC encoded point), 64/72 (raw x||y, with 8
        leading bytes skipped for 72) and 96/104 (x||y followed by 32 further
        bytes; the first 8 bytes are skipped for 104; not checked on curve).
        Returns an ECDSA instance (the point is its ``G``), or None for any
        other length.
        """
        length = len(pubkey)

        if length == 33 or length == 65:
            # Raw bytes: they must not be hex-decoded again.
            return ECDSA.decode_secp256r1(pubkey, unhex=False)

        elif length == 64 or length == 72:
            skip = length - 64
            out = bytearray(b'\x04') + bytearray(pubkey[skip:])
            return ECDSA.decode_secp256r1(out, unhex=False)

        elif length == 96 or length == 104:
            skip = length - 96
            out = bytearray(b'\x04') + bytearray(pubkey[skip:skip + 64])
            return ECDSA.decode_secp256r1(out, unhex=False, check_on_curve=False)

        return None

    @staticmethod
    def secp256k1():
        """
        Create the secp256k1 curve.
        """
        # This is P from below... aka FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
        GFp = FiniteField(2 ** 256 - 2 ** 32 - 977)
        ec = EllipticCurve(GFp, 0, 7)
        return ECDSA(ec, ec.point(0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
                                  0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8),
                     2 ** 256 - 432420386565659656852420866394968145599)

    @staticmethod
    def SignSecp256R1(message, prikey, pubkey):
        """Sign ``message`` (an integer or hex-string digest) with ``prikey`` on secp256r1.

        A fresh nonce is drawn from ``random.SystemRandom`` for each attempt,
        retrying in the unlikely case r or s is zero. ``pubkey`` is accepted
        for backward compatibility but is not used: the curve's base point,
        not the public key, is the generator of a signature.

        Returns:
            (r, s) as values of the curve-order field.
        """
        ecdsa = ECDSA.secp256r1()
        n = ecdsa.GFn.p
        rng = random.SystemRandom()
        while True:
            secret = rng.randrange(1, n)
            r, s = ecdsa.sign(message, prikey, secret)
            if not r.iszero() and not s.iszero():
                return r, s