package com.oraclechain.eosio.crypto.ec;


import java.math.BigInteger;
import java.util.Random;

/**
 * This class represents an elliptic field element.
 */
public class EcFieldElement {

    private static final BigInteger TWO = BigInteger.valueOf(2);

    private BigInteger _x;
    private BigInteger _q;

    public EcFieldElement(BigInteger q, BigInteger x) {
        this._x = x;
        if (x.compareTo(q) >= 0) {
            throw new IllegalArgumentException("x value too large in field element");
        }
        this._q = q;
    }

    public BigInteger toBigInteger() {
        return _x;
    }

    public int getFieldSize() {
        return _q.bitLength();
    }

    public BigInteger getQ() {
        return _q;
    }

    public EcFieldElement add(EcFieldElement b) {
        return new EcFieldElement(_q, _x.add(b.toBigInteger()).mod(_q));
    }

    public EcFieldElement subtract(EcFieldElement b) {
        return new EcFieldElement(_q, _x.subtract(b.toBigInteger()).mod(_q));
    }

    public EcFieldElement multiply(EcFieldElement b) {
        return new EcFieldElement(_q, _x.multiply(b.toBigInteger()).mod(_q));
    }

    // TODO: optimize. much time is spent here. (13% total)
    public EcFieldElement divide(EcFieldElement b) {
        return new EcFieldElement(_q, _x.multiply(b.toBigInteger().modInverse(_q)).mod(_q));
    }

    public EcFieldElement negate() {
        return new EcFieldElement(_q, _x.negate().mod(_q));
    }

    public EcFieldElement square() {
        return new EcFieldElement(_q, _x.multiply(_x).mod(_q));
    }

    public EcFieldElement invert() {
        return new EcFieldElement(_q, _x.modInverse(_q));
    }

    @Override
    public String toString() {
        return this.toBigInteger().toString(16);
    }

    public EcFieldElement sqrt() {
        if (!_q.testBit(0)) {
            throw new RuntimeException("not done yet");
        }

        if (_q.testBit(1)) {
            // z = g^(u+1) + p, p = 4u + 3
            EcFieldElement z = new EcFieldElement(_q, _x.modPow(_q.shiftRight(2).add(BigInteger.ONE), _q));

            return z.square().equals(this) ? z : null;
        }

        // p mod 4 == 1
        BigInteger qMinusOne = _q.subtract(BigInteger.ONE);

        BigInteger legendreExponent = qMinusOne.shiftRight(1);
        if (!(_x.modPow(legendreExponent, _q).equals(BigInteger.ONE))) {
            return null;
        }

        BigInteger u = qMinusOne.shiftRight(2);
        BigInteger k = u.shiftLeft(1).add(BigInteger.ONE);

        BigInteger Q = this._x;
        BigInteger fourQ = Q.shiftLeft(2).mod(_q);

        BigInteger U, V;
        Random rand = new Random();
        do {
            BigInteger P;
            do {
                P = new BigInteger(_q.bitLength(), rand);
            } while (P.compareTo(_q) >= 0
                    || !(P.multiply(P).subtract(fourQ).modPow(legendreExponent, _q).equals(qMinusOne)));

            BigInteger[] result = lucasSequence(_q, P, Q, k);
            U = result[0];
            V = result[1];

            if (V.multiply(V).mod(_q).equals(fourQ)) {
                // Integer division by 2, mod q
                if (V.testBit(0)) {
                    V = V.add(_q);
                }

                V = V.shiftRight(1);

                return new EcFieldElement(_q, V);
            }
        } while (U.equals(BigInteger.ONE) || U.equals(qMinusOne));

        return null;

    }

    private static BigInteger[] lucasSequence(BigInteger p, BigInteger P, BigInteger Q, BigInteger k) {
        int n = k.bitLength();
        int s = k.getLowestSetBit();

        BigInteger Uh = BigInteger.ONE;
        BigInteger Vl = TWO;
        BigInteger Vh = P;
        BigInteger Ql = BigInteger.ONE;
        BigInteger Qh = BigInteger.ONE;

        for (int j = n - 1; j >= s + 1; --j) {
            Ql = Ql.multiply(Qh).mod(p);

            if (k.testBit(j)) {
                Qh = Ql.multiply(Q).mod(p);
                Uh = Uh.multiply(Vh).mod(p);
                Vl = Vh.multiply(Vl).subtract(P.multiply(Ql)).mod(p);
                Vh = Vh.multiply(Vh).subtract(Qh.shiftLeft(1)).mod(p);
            } else {
                Qh = Ql;
                Uh = Uh.multiply(Vl).subtract(Ql).mod(p);
                Vh = Vh.multiply(Vl).subtract(P.multiply(Ql)).mod(p);
                Vl = Vl.multiply(Vl).subtract(Ql.shiftLeft(1)).mod(p);
            }
        }

        Ql = Ql.multiply(Qh).mod(p);
        Qh = Ql.multiply(Q).mod(p);
        Uh = Uh.multiply(Vl).subtract(Ql).mod(p);
        Vl = Vh.multiply(Vl).subtract(P.multiply(Ql)).mod(p);
        Ql = Ql.multiply(Qh).mod(p);

        for (int j = 1; j <= s; ++j) {
            Uh = Uh.multiply(Vl).mod(p);
            Vl = Vl.multiply(Vl).subtract(Ql.shiftLeft(1)).mod(p);
            Ql = Ql.multiply(Ql).mod(p);
        }

        return new BigInteger[] { Uh, Vl };
    }

    public boolean equals(Object other) {
        if (other == this) {
            return true;
        }

        if (!(other instanceof EcFieldElement)) {
            return false;
        }

        EcFieldElement o = (EcFieldElement) other;
        return _q.equals(o._q) && _x.equals(o._x);
    }

    public int hashCode() {
        return _q.hashCode() ^ _x.hashCode();
    }
}