/*
 * Copyright 2013 Square Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.squareup.jnagmp;

import com.squareup.jnagmp.LibGmp.mpz_t;
import com.sun.jna.Memory;
import com.sun.jna.Native;
import com.sun.jna.NativeLong;
import com.sun.jna.Pointer;
import java.math.BigInteger;

import static com.squareup.jnagmp.LibGmp.__gmpz_clear;
import static com.squareup.jnagmp.LibGmp.__gmpz_cmp_si;
import static com.squareup.jnagmp.LibGmp.__gmpz_divexact;
import static com.squareup.jnagmp.LibGmp.__gmpz_export;
import static com.squareup.jnagmp.LibGmp.__gmpz_gcd;
import static com.squareup.jnagmp.LibGmp.__gmpz_import;
import static com.squareup.jnagmp.LibGmp.__gmpz_init;
import static com.squareup.jnagmp.LibGmp.__gmpz_invert;
import static com.squareup.jnagmp.LibGmp.__gmpz_jacobi;
import static com.squareup.jnagmp.LibGmp.__gmpz_neg;
import static com.squareup.jnagmp.LibGmp.__gmpz_powm;
import static com.squareup.jnagmp.LibGmp.__gmpz_powm_sec;
import static com.squareup.jnagmp.LibGmp.readSizeT;
import static java.lang.Math.max;
import static java.lang.Math.min;

/**
 * High level Java API for accessing {@link LibGmp} safely.
 *
 * <p>All public entry points are static and dispatch to a per-thread {@code Gmp} instance (see
 * {@link #INSTANCE}) which owns a small set of reusable native operands and a scratch buffer.</p>
 */
public final class Gmp {

  /** The error raised while initializing {@link LibGmp}, or {@code null} if loading succeeded. */
  private static final UnsatisfiedLinkError LOAD_ERROR;

  static {
    UnsatisfiedLinkError localLoadError = null;
    try {
      LibGmp.init();
    } catch (UnsatisfiedLinkError e) {
      // Remember the failure so checkLoaded() can report it later instead of failing class init.
      localLoadError = e;
    }
    LOAD_ERROR = localLoadError;
  }

  /**
   * Verifies this library is loaded properly.
   *
   * @throws UnsatisfiedLinkError if the library failed to load properly.
   * @throws AssertionError if the library loaded but a test computation returned a wrong answer.
   */
  public static void checkLoaded() {
    if (LOAD_ERROR != null) {
      throw LOAD_ERROR;
    }
    // Make a test call, sometimes the error won't occur until you try the native method.
    // 2 ^ 3 = 8, 8 mod 5 = 3
    BigInteger two = BigInteger.valueOf(2);
    BigInteger three = BigInteger.valueOf(3);
    BigInteger four = BigInteger.valueOf(4);
    BigInteger five = BigInteger.valueOf(5);
    BigInteger answer;

    answer = modPowInsecure(two, three, five);
    if (!three.equals(answer)) {
      throw new AssertionError("libgmp is loaded but modPowInsecure returned the wrong answer");
    }

    answer = modPowSecure(two, three, five);
    if (!three.equals(answer)) {
      throw new AssertionError("libgmp is loaded but modPowSecure returned the wrong answer");
    }

    // 4 is a perfect square, so the symbol (4|5) must be 1.
    int symbol = kronecker(four, five);
    if (symbol != 1) {
      throw new AssertionError("libgmp is loaded but kronecker returned the wrong answer");
    }
  }

  /**
   * Calculate kronecker symbol a|p. Generalization of legendre and jacobi.
   *
   * @param a an integer
   * @param p the modulus
   * @return a|p
   */
  public static int kronecker(BigInteger a, BigInteger p) {
    return INSTANCE.get().kroneckerImpl(a, p);
  }

  /**
   * Calculate (base ^ exponent) % modulus; faster, VULNERABLE TO TIMING ATTACKS.
   *
   * @param base the base, must be positive
   * @param exponent the exponent
   * @param modulus the modulus
   * @return the (base ^ exponent) % modulus
   * @throws ArithmeticException if modulus is non-positive, or the exponent is negative and the
   * base cannot be inverted
   */
  public static BigInteger modPowInsecure(BigInteger base, BigInteger exponent,
      BigInteger modulus) {
    if (modulus.signum() <= 0) {
      throw new ArithmeticException("modulus must be positive");
    }
    return INSTANCE.get().modPowInsecureImpl(base, exponent, modulus);
  }

  /**
   * Calculate (base ^ exponent) % modulus; slower, hardened against timing attacks.
   *
   * <p> NOTE: this method REQUIRES modulus to be odd, due to a crash-bug in libgmp. This is not a
   * problem for RSA where the modulus is always odd.</p>
   *
   * @param base the base, must be positive
   * @param exponent the exponent
   * @param modulus the modulus
   * @return the (base ^ exponent) % modulus
   * @throws ArithmeticException if modulus is non-positive, or the exponent is negative and the
   * base cannot be inverted
   * @throws IllegalArgumentException if modulus is even
   */
  public static BigInteger modPowSecure(BigInteger base, BigInteger exponent, BigInteger modulus) {
    if (modulus.signum() <= 0) {
      throw new ArithmeticException("modulus must be positive");
    }
    if (!modulus.testBit(0)) {
      throw new IllegalArgumentException("modulus must be odd");
    }
    return INSTANCE.get().modPowSecureImpl(base, exponent, modulus);
  }

  /**
   * Calculate val^-1 % modulus.
   *
   * @param val must be positive
   * @param modulus the modulus
   * @return val^-1 % modulus
   * @throws ArithmeticException if modulus is non-positive or val is not invertible
   */
  public static BigInteger modInverse(BigInteger val, BigInteger modulus) {
    if (modulus.signum() <= 0) {
      throw new ArithmeticException("modulus must be positive");
    }
    return INSTANCE.get().modInverseImpl(val, modulus);
  }

  /**
   * Divide dividend by divisor. This method only returns correct answers when the division produces
   * no remainder. Correct answers should not be expected when the division would result in a
   * remainder.
   *
   * @param dividend the value to divide
   * @param divisor the value to divide by, must be non-zero
   * @return dividend / divisor
   * @throws ArithmeticException if divisor is zero
   */
  public static BigInteger exactDivide(BigInteger dividend, BigInteger divisor) {
    if (divisor.signum() == 0) {
      throw new ArithmeticException("BigInteger divide by zero");
    }
    return INSTANCE.get().exactDivImpl(dividend, divisor);
  }

  /**
   * Return the greatest common divisor of value1 and value2. The result is always positive even if
   * one or both input operands are negative. Except if both inputs are zero; then this method
   * defines gcd(0,0) = 0.
   *
   * @param value1 the first operand
   * @param value2 the second operand
   * @return greatest common divisor of value1 and value2
   */
  public static BigInteger gcd(BigInteger value1, BigInteger value2) {
    return INSTANCE.get().gcdImpl(value1, value2);
  }

  /**
   * VISIBLE FOR TESTING. Reuse the same buffers over and over to minimize allocations and native
   * boundary crossings.
   */
  static final ThreadLocal<Gmp> INSTANCE = new ThreadLocal<Gmp>() {
    @Override protected Gmp initialValue() {
      return new Gmp();
    }
  };

  /** Initial bit size of the scratch buffer. */
  private static final int INITIAL_BUF_BITS = 2048;
  /** Initial byte size of the scratch buffer. */
  private static final int INITIAL_BUF_SIZE = INITIAL_BUF_BITS / 8;

  /** Maximum number of operands we need for any operation. */
  private static final int MAX_OPERANDS = 4;
  /** Room for every shared operand struct plus one trailing size_t (see {@link #countPtr}). */
  private static final int SHARED_MEM_SIZE = mpz_t.SIZE * MAX_OPERANDS + Native.SIZE_T_SIZE;

  /**
   * Operands that can be reused over and over to avoid costly initialization and tear down. Backed
   * by {@link #sharedMem}.
   */
  private final mpz_t[] sharedOperands = new mpz_t[MAX_OPERANDS];

  /** The out size_t pointer for export. Backed by {@link #sharedMem}. */
  private final Pointer countPtr;

  /** A fixed, shared, reusable memory buffer. */
  private final Memory sharedMem = new Memory(SHARED_MEM_SIZE) {
    /** Must explicitly destroy the gmp_t structs before freeing the underlying memory. */
    @Override protected void finalize() {
      for (mpz_t sharedOperand : sharedOperands) {
        // Slots stay null if the constructor failed part-way through initialization.
        if (sharedOperand != null) {
          __gmpz_clear(sharedOperand);
        }
      }
      super.finalize();
    }
  };

  /** Reusable scratch buffer for moving data between byte[] and mpz_t. */
  private Memory scratchBuf = new Memory(INITIAL_BUF_SIZE);

  /** Carves {@link #sharedMem} into the shared operands and the export count slot. */
  private Gmp() {
    int offset = 0;
    for (int i = 0; i < MAX_OPERANDS; ++i) {
      this.sharedOperands[i] = new mpz_t(sharedMem.share(offset, mpz_t.SIZE));
      __gmpz_init(sharedOperands[i]);
      offset += mpz_t.SIZE;
    }
    this.countPtr = sharedMem.share(offset, Native.SIZE_T_SIZE);
    offset += Native.SIZE_T_SIZE;
    assert offset == SHARED_MEM_SIZE;
  }

  private int kroneckerImpl(BigInteger a, BigInteger p) {
    mpz_t aPeer = getPeer(a, sharedOperands[0]);
    mpz_t pPeer = getPeer(p, sharedOperands[1]);
    return __gmpz_jacobi(aPeer, pPeer);
  }

  private BigInteger modPowInsecureImpl(BigInteger base, BigInteger exp, BigInteger mod) {
    return modPowImpl(base, exp, mod, false);
  }

  private BigInteger modPowSecureImpl(BigInteger base, BigInteger exp, BigInteger mod) {
    return modPowImpl(base, exp, mod, true);
  }

  /**
   * Shared implementation of both modPow flavors.
   *
   * @param secure {@code true} to use {@code __gmpz_powm_sec}, {@code false} for
   * {@code __gmpz_powm}
   * @throws ArithmeticException if the exponent is negative and the base cannot be inverted
   */
  private BigInteger modPowImpl(BigInteger base, BigInteger exp, BigInteger mod, boolean secure) {
    // A negative exponent is handled as (base^-1) ^ |exp|.
    boolean invert = exp.signum() < 0;
    if (invert) {
      exp = exp.negate();
    }

    mpz_t basePeer = getPeer(base, sharedOperands[0]);
    mpz_t expPeer = getPeer(exp, sharedOperands[1]);
    mpz_t modPeer = getPeer(mod, sharedOperands[2]);

    if (invert) {
      // Always write the inverse into our own scratch operand. basePeer may be the native peer of
      // a caller-owned GmpInteger, which must never be modified (and would be left undefined if
      // no inverse exists). When base was imported, basePeer already is sharedOperands[0].
      int res = __gmpz_invert(sharedOperands[0], basePeer, modPeer);
      if (res == 0) {
        throw new ArithmeticException("val not invertible");
      }
      basePeer = sharedOperands[0];
    }

    if (secure) {
      __gmpz_powm_sec(sharedOperands[3], basePeer, expPeer, modPeer);
    } else {
      __gmpz_powm(sharedOperands[3], basePeer, expPeer, modPeer);
    }

    // The result size should be <= modulus size, but round up to the nearest byte.
    int requiredSize = (mod.bitLength() + 7) / 8;
    return new BigInteger(mpzSgn(sharedOperands[3]), mpzExport(sharedOperands[3], requiredSize));
  }

  private BigInteger modInverseImpl(BigInteger val, BigInteger mod) {
    mpz_t valPeer = getPeer(val, sharedOperands[0]);
    mpz_t modPeer = getPeer(mod, sharedOperands[1]);

    int res = __gmpz_invert(sharedOperands[2], valPeer, modPeer);
    if (res == 0) {
      throw new ArithmeticException("val not invertible");
    }

    // The result size should be <= modulus size, but round up to the nearest byte.
    int requiredSize = (mod.bitLength() + 7) / 8;
    return new BigInteger(mpzSgn(sharedOperands[2]), mpzExport(sharedOperands[2], requiredSize));
  }

  private BigInteger exactDivImpl(BigInteger dividend, BigInteger divisor) {
    mpz_t dividendPeer = getPeer(dividend, sharedOperands[0]);
    mpz_t divisorPeer = getPeer(divisor, sharedOperands[1]);

    __gmpz_divexact(sharedOperands[2], dividendPeer, divisorPeer);

    // The result size is never larger than the bit length of the dividend minus that of the divisor
    // plus 1 (but is at least 1 bit long to hold the case that the two values are exactly equal).
    // NOTE: this is a bit count, while mpzExport interprets it as a byte count, so the scratch
    // buffer is over-provisioned here rather than under-provisioned. Left as-is on purpose: the
    // native result for a non-exact division is not guaranteed to fit the tighter bound.
    int requiredSize = max(dividend.bitLength() - divisor.bitLength() + 1, 1);
    return new BigInteger(mpzSgn(sharedOperands[2]), mpzExport(sharedOperands[2], requiredSize));
  }

  private BigInteger gcdImpl(BigInteger value1, BigInteger value2) {
    mpz_t value1Peer = getPeer(value1, sharedOperands[0]);
    mpz_t value2Peer = getPeer(value2, sharedOperands[1]);

    __gmpz_gcd(sharedOperands[2], value1Peer, value2Peer);

    // The result is no larger than the smaller input, unless one input is zero: gcd(0, x) = |x|,
    // so in that case the result can be as large as the larger input. Getting this wrong matters
    // when x is a GmpInteger, because then nothing else has grown the scratch buffer to fit it.
    int bits1 = value1.bitLength();
    int bits2 = value2.bitLength();
    boolean hasZeroOperand = value1.signum() == 0 || value2.signum() == 0;
    int resultBits = hasZeroOperand ? max(bits1, bits2) : min(bits1, bits2);
    // Round up to the nearest byte.
    int requiredSize = (resultBits + 7) / 8;
    return new BigInteger(mpzSgn(sharedOperands[2]), mpzExport(sharedOperands[2], requiredSize));
  }

  /**
   * If {@code value} is a {@link GmpInteger}, return its peer. Otherwise, import {@code value} into
   * {@code sharedPeer} and return {@code sharedPeer}.
   *
   * <p>The returned operand may belong to the caller's {@link GmpInteger}; treat it as read-only.
   * </p>
   */
  private mpz_t getPeer(BigInteger value, mpz_t sharedPeer) {
    if (value instanceof GmpInteger) {
      return ((GmpInteger) value).getPeer();
    }
    mpzImport(sharedPeer, value.signum(), value.abs().toByteArray());
    return sharedPeer;
  }

  /**
   * Loads a sign and big-endian magnitude into a native operand.
   *
   * @param ptr the operand to overwrite
   * @param signum if negative, the imported magnitude is negated
   * @param bytes the magnitude, most significant byte first
   */
  void mpzImport(mpz_t ptr, int signum, byte[] bytes) {
    ensureBufferSize(bytes.length);
    scratchBuf.write(0, bytes, 0, bytes.length);
    // count = bytes.length words, order = 1, size = 1 byte per word, endian = 1, nails = 0.
    __gmpz_import(ptr, bytes.length, 1, 1, 1, 0, scratchBuf);
    if (signum < 0) {
      __gmpz_neg(ptr, ptr);
    }
  }

  /**
   * Copies the magnitude of a native operand out into a fresh byte array.
   *
   * @param ptr the operand to read
   * @param requiredSize the number of bytes the scratch buffer must be able to hold; callers must
   * pass an upper bound on the exported size
   * @return the exported bytes; its length is the count reported by the native export call
   */
  private byte[] mpzExport(mpz_t ptr, int requiredSize) {
    ensureBufferSize(requiredSize);
    // Same word layout as mpzImport: order = 1, size = 1, endian = 1, nails = 0.
    __gmpz_export(scratchBuf, countPtr, 1, 1, 1, 0, ptr);

    int count = readSizeT(countPtr);
    byte[] result = new byte[count];
    scratchBuf.read(0, result, 0, count);
    return result;
  }

  private static final NativeLong ZERO = new NativeLong();

  /**
   * Returns the sign of a native operand, normalized to -1, 0 or 1.
   *
   * @param ptr the operand to inspect
   */
  int mpzSgn(mpz_t ptr) {
    int result = __gmpz_cmp_si(ptr, ZERO);
    if (result < 0) {
      return -1;
    } else if (result > 0) {
      return 1;
    }
    return 0;
  }

  /**
   * Replaces the scratch buffer with a larger one if it holds fewer than {@code size} bytes. The
   * new capacity is the current capacity doubled until it fits; existing contents are not copied.
   */
  private void ensureBufferSize(int size) {
    if (scratchBuf.size() < size) {
      long newSize = scratchBuf.size();
      while (newSize < size) {
        newSize <<= 1;
      }
      scratchBuf = new Memory(newSize);
    }
  }
}