package at.favre.lib.armadillo;

import android.os.StrictMode;
import androidx.annotation.NonNull;

import org.junit.Test;
import org.mindrot.jbcrypt.BCrypt;

import at.favre.lib.bytes.Bytes;
import at.favre.lib.crypto.HKDF;

import static org.junit.Assert.assertArrayEquals;

/**
 * Test to check if the re-implementation of the old broken bcrypt key stretcher behaves exactly
 * the same as with the new one.
 */
public class BrokenBcryptKeyStretcherTest {

    @Test
    public void testBcrypt() {
        for (int i = 4; i < 10; i++) {
            compareLegacyFallbackImpl(i, 8, 8);
            compareLegacyFallbackImpl(i, 16, 16);
            compareLegacyFallbackImpl(i, 32, 16);
            compareLegacyFallbackImpl(i, 32, 32);
            compareLegacyFallbackImpl(i, 64, 32);
            compareLegacyFallbackImpl(i, 64, 64);
            compareLegacyFallbackImpl(i, 128, 128);
            compareLegacyFallbackImpl(i, 4, 128);
            compareLegacyFallbackImpl(i, 128, 4);
            compareLegacyFallbackImpl(i, 64, 16);
        }
    }

    private void compareLegacyFallbackImpl(int cost, int saltLength, int pwByteLength) {
        byte[] salt = Bytes.random(saltLength).array();
        char[] pw = Bytes.random(pwByteLength).encodeBase64().toCharArray();
        KeyStretchingFunction b1 = new LegacyBrokenJBcryptKeyStretcher(cost);
        KeyStretchingFunction b2 = new BrokenBcryptKeyStretcher(cost);

        byte[] out1 = b1.stretch(salt, pw, 16);
        byte[] out2 = b2.stretch(salt, pw, 16);

        assertArrayEquals(String.format("Hashes do not match\n%s\nvs\n%s", Bytes.wrap(out1).encodeHex(), Bytes.wrap(out2).encodeHex()), out1, out2);
    }

    /**
     * This is the old broken implementation of Bcrypt key stretcher as reference
     *
     * @deprecated this is only for testing purpose
     */
    @SuppressWarnings("DeprecatedIsStillUsed")
    @Deprecated
    static final class LegacyBrokenJBcryptKeyStretcher implements KeyStretchingFunction {
        private static final int BCRYPT_MIN_ROUNDS = 8;
        private final int iterations;

        /**
         * Creates a new instance with desired rounds.
         *
         * @param log2Rounds this is the log2(Iterations). e.g. 12 ==> 2^12 = 4,096 iterations, the higher, the slower
         *                   cannot be smaller than 8
         */
        LegacyBrokenJBcryptKeyStretcher(int log2Rounds) {
            this.iterations = Math.max(BCRYPT_MIN_ROUNDS, log2Rounds);
        }

        @Override
        public byte[] stretch(byte[] salt, char[] password, int outLengthByte) {
            try {
                return HKDF.fromHmacSha256().expand(bcrypt(password, salt, iterations), "bcrypt".getBytes(), outLengthByte);
            } catch (Exception e) {
                throw new IllegalStateException("could not stretch with bcrypt", e);
            }
        }

        /**
         * Computes the Bcrypt hash of a password.
         *
         * @param password  the password to hash.
         * @param salt      the salt
         * @param logRounds log2(Iterations). e.g. 12 ==> 2^12 = 4,096 iterations
         * @return the Bcrypt hash of the password
         */
        private static byte[] bcrypt(char[] password, byte[] salt, int logRounds) {
            StrictMode.noteSlowCall("bcrypt is a very expensive call and should not be done on the main thread");
            return Bytes.from(BCrypt.hashpw(String.valueOf(password) + Bytes.wrap(salt).encodeHex(), generateSalt(salt, logRounds))).array();
        }

        @NonNull
        private static String generateSalt(byte[] salt, int logRounds) {
            StringBuilder saltBuilder = new StringBuilder();
            saltBuilder.append("$2a$");
            if (logRounds < 10) {
                saltBuilder.append("0");
            }
            if (logRounds > 30) {
                throw new IllegalArgumentException("log_rounds exceeds maximum (30)");
            }
            saltBuilder.append(Integer.toString(logRounds));
            saltBuilder.append("$");
            saltBuilder.append(Bytes.wrap(HKDF.fromHmacSha256().expand(salt, "bcrypt".getBytes(), 16)).encodeHex());
            return saltBuilder.toString();
        }
    }
}