/*
 * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
 * in compliance with the License. A copy of the License is located at
 * 
 * http://aws.amazon.com/apache2.0
 * 
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */

package com.amazonaws.encryptionsdk.internal;

import java.io.Serializable;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.math.BigInteger;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.lang3.ArrayUtils;
import org.bouncycastle.util.encoders.Base64;

/**
 * Internal utility methods. This class only has static members and cannot be instantiated.
 */
public final class Utils {
    // SecureRandom objects can both be expensive to initialize and incur synchronization costs.
    // Keeping one instance per thread minimizes initializations and avoids lock contention.
    private static final ThreadLocal<SecureRandom> LOCAL_RANDOM = ThreadLocal.withInitial(() -> {
        final SecureRandom rnd = new SecureRandom();
        rnd.nextBoolean(); // Force seeding
        return rnd;
    });

    private Utils() {
        // Prevent instantiation
    }

    /*
     * In some areas we need to be able to assign a total order over Java objects - generally with some primary sort,
     * but we need a fallback sort that always works in order to ensure that we don't falsely claim objects A and B
     * are equal just because the primary sort declares them to have equal rank.
     *
     * To do this, we define a fallback sort that assigns an arbitrary order to all objects. This order is
     * implemented by first comparing identity hashcodes, and in the rare case where we are asked to compare two
     * objects with equal identity hashcodes, we explicitly assign an index to each of them and sort based on
     * that index.
     *
     * The index table must be keyed on object *identity*. java.util.WeakHashMap is not suitable because it
     * compares keys with equals()/hashCode(): two distinct objects that are equals() to each other would share a
     * single index, and the comparison would then fail for a != b. Instead, the keys are weak references that hash
     * and compare by referent identity, and entries whose referent has been garbage collected are purged through
     * a ReferenceQueue.
     */
    private static final AtomicLong FALLBACK_COUNTER = new AtomicLong(0);
    // Only accessed while holding the Utils.class monitor (see getFallbackObjectId).
    private static final Map<IdentityWeakKey, Long> FALLBACK_COMPARATOR_MAP = new HashMap<>();
    private static final ReferenceQueue<Object> FALLBACK_REFERENCE_QUEUE = new ReferenceQueue<>();

    /**
     * Weak reference whose {@code equals}/{@code hashCode} are based on the identity of its referent. It can be
     * used as a map key that neither keeps the referent reachable nor consults the referent's own
     * {@code equals}/{@code hashCode}.
     */
    private static final class IdentityWeakKey extends WeakReference<Object> {
        // Cached at construction so the key keeps a stable hash (and can be removed) after the referent is cleared.
        private final int hash;

        IdentityWeakKey(final Object referent, final ReferenceQueue<Object> queue) {
            super(referent, queue);
            this.hash = System.identityHashCode(referent);
        }

        @Override
        public int hashCode() {
            return hash;
        }

        @Override
        public boolean equals(final Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof IdentityWeakKey)) {
                return false;
            }
            final Object referent = get();
            // A cleared reference is only equal to itself (handled above).
            return referent != null && referent == ((IdentityWeakKey) other).get();
        }
    }

    /**
     * Returns the fallback ordering index for {@code object}, assigning a fresh one from
     * {@link #FALLBACK_COUNTER} the first time this exact instance is seen.
     */
    private static synchronized long getFallbackObjectId(final Object object) {
        expungeStaleFallbackEntries();

        // The lookup probe is not registered with the queue; it is discarded right after the lookup.
        final Long existing = FALLBACK_COMPARATOR_MAP.get(new IdentityWeakKey(object, null));
        if (existing != null) {
            return existing;
        }

        final long id = FALLBACK_COUNTER.incrementAndGet();
        FALLBACK_COMPARATOR_MAP.put(new IdentityWeakKey(object, FALLBACK_REFERENCE_QUEUE), id);
        return id;
    }

    /**
     * Removes map entries whose referents have been garbage collected. Caller must hold the Utils.class monitor.
     */
    private static void expungeStaleFallbackEntries() {
        Object stale;
        while ((stale = FALLBACK_REFERENCE_QUEUE.poll()) != null) {
            // HashMap checks key identity first, so the cleared key is still found via its cached hash.
            FALLBACK_COMPARATOR_MAP.remove(stale);
        }
    }

    /**
     * Provides an <i>arbitrary</i> but consistent total ordering over all objects. This comparison function will
     * return 0 if and only if a == b, and otherwise will return arbitrarily either -1 or 1, but will do so in a way
     * that results in a consistent total order.
     *
     * @param a
     *            the first object; may be {@code null} ({@code null} orders before any non-null object)
     * @param b
     *            the second object; may be {@code null}
     * @return -1 or 1 (consistently) if a != b; 0 if a == b.
     */
    public static int compareObjectIdentity(Object a, Object b) {
        if (a == b) {
            return 0;
        }

        if (a == null) {
            return -1;
        }

        if (b == null) {
            return 1;
        }

        int hashCompare = Integer.compare(System.identityHashCode(a), System.identityHashCode(b));
        if (hashCompare != 0) {
            return hashCompare;
        }

        // Unfortunately these objects have identical hashcodes, so we need to find some other way to compare them.
        // We'll do this by mapping them to an incrementing counter, and comparing their assigned IDs instead.
        int fallbackCompare = Long.compare(getFallbackObjectId(a), getFallbackObjectId(b));
        if (fallbackCompare == 0) {
            throw new AssertionError("Failed to assign unique order to objects");
        }

        return fallbackCompare;
    }

    /**
     * Adds two longs, clamping the result to {@link Long#MAX_VALUE} or {@link Long#MIN_VALUE} instead of
     * wrapping around on overflow.
     *
     * @param a
     *            first addend
     * @param b
     *            second addend
     * @return {@code a + b}, saturated to the range of {@code long}
     */
    public static long saturatingAdd(long a, long b) {
        long r = a + b;

        // Two positives overflowed: the wrapped sum is negative, hence smaller than a.
        if (a > 0 && b > 0 && r < a) {
            return Long.MAX_VALUE;
        }

        // Two negatives overflowed: the wrapped sum is non-negative, hence larger than a.
        if (a < 0 && b < 0 && r > a) {
            return Long.MIN_VALUE;
        }

        // If the signs between a and b differ, overflow is impossible.
        return r;
    }

    /**
     * Comparator that performs a lexicographical comparison of byte arrays, treating them as unsigned.
     */
    public static class ComparingByteArrays implements Comparator<byte[]>, Serializable {
        // We don't really need to be serializable, but it doesn't hurt, and FindBugs gets annoyed if we're not.
        private static final long serialVersionUID = 0xdf641037ffe509e2L;

        // The buffer comparator is stateless, so one shared instance avoids an allocation per comparison.
        private static final ComparingByteBuffers BUFFER_COMPARATOR = new ComparingByteBuffers();

        @Override
        public int compare(final byte[] o1, final byte[] o2) {
            return BUFFER_COMPARATOR.compare(ByteBuffer.wrap(o1), ByteBuffer.wrap(o2));
        }
    }

    /**
     * Comparator that lexicographically compares the remaining bytes (position to limit) of two buffers, treating
     * each byte as unsigned. A buffer that is a proper prefix of the other orders first. The positions and limits
     * of the arguments are not modified: comparison happens on slices.
     */
    public static class ComparingByteBuffers implements Comparator<ByteBuffer>, Serializable {
        private static final long serialVersionUID = 0xa3c4a7300fbbf043L;

        @Override
        public int compare(final ByteBuffer o1, final ByteBuffer o2) {
            // Slicing rebases index 0 at each buffer's current position without touching the caller's buffers.
            final ByteBuffer left = o1.slice();
            final ByteBuffer right = o2.slice();

            final int commonLength = Math.min(left.remaining(), right.remaining());

            for (int i = 0; i < commonLength; i++) {
                // Perform zero-extension as we want to treat the bytes as unsigned
                final int v1 = left.get(i) & 0xFF;
                final int v2 = right.get(i) & 0xFF;

                if (v1 != v2) {
                    // Both values are in [0, 255], so the subtraction cannot overflow.
                    return v1 - v2;
                }
            }

            // The longer buffer is bigger (0x00 comes after end-of-buffer)
            return left.remaining() - right.remaining();
        }
    }

    /**
     * Throws {@link NullPointerException} with message {@code paramName} if {@code object} is null.
     *
     * @param object
     *            value to be null-checked
     * @param paramName
     *            message for the potential {@link NullPointerException}
     * @return {@code object}
     * @throws NullPointerException
     *             if {@code object} is null
     */
    public static <T> T assertNonNull(final T object, final String paramName) throws NullPointerException {
        if (object == null) {
            throw new NullPointerException(paramName + " must not be null");
        }
        return object;
    }

    /**
     * Returns a possibly truncated version of {@code arr} which is guaranteed to be exactly
     * {@code len} elements long. If {@code arr} is already exactly {@code len} elements long, then
     * {@code arr} is returned without copy or modification. If {@code arr} is longer than
     * {@code len}, then a truncated copy is returned. If {@code arr} is shorter than {@code len}
     * then this throws an {@link IllegalArgumentException}.
     */
    public static byte[] truncate(final byte[] arr, final int len) throws IllegalArgumentException {
        if (arr.length == len) {
            return arr;
        } else if (arr.length > len) {
            return Arrays.copyOf(arr, len);
        } else {
            throw new IllegalArgumentException("arr is not at least " + len + " elements long");
        }
    }

    /**
     * Returns the calling thread's {@link SecureRandom} instance, creating and seeding it on first use.
     */
    public static SecureRandom getSecureRandom() {
        return LOCAL_RANDOM.get();
    }

    /**
     * Generate the AAD bytes to use when encrypting/decrypting content. The
     * generated AAD is a block of bytes containing the provided message
     * identifier, the string identifier, the sequence number, and the length of
     * the content.
     * 
     * @param messageId
     *            the unique message identifier for the ciphertext.
     * @param idString
     *            the string describing the type of content processed; encoded as UTF-8.
     * @param seqNum
     *            the sequence number, written as a 4-byte big-endian int.
     * @param len
     *            the length of the content, written as an 8-byte big-endian long.
     * @return
     *         the bytes containing the generated AAD.
     */
    static byte[] generateContentAad(final byte[] messageId, final String idString, final int seqNum, final long len) {
        final byte[] idBytes = idString.getBytes(StandardCharsets.UTF_8);
        final int aadLen = messageId.length + idBytes.length + Integer.BYTES + Long.BYTES;
        // ByteBuffer.allocate defaults to big-endian byte order.
        final ByteBuffer aad = ByteBuffer.allocate(aadLen);

        aad.put(messageId);
        aad.put(idBytes);
        aad.putInt(seqNum);
        aad.putLong(len);

        return aad.array();
    }

    /**
     * Builds (but does not throw) an {@link IllegalArgumentException} stating that {@code field} cannot be negative.
     */
    static IllegalArgumentException cannotBeNegative(String field) {
        return new IllegalArgumentException(field + " cannot be negative");
    }

    /**
     * Equivalent to calling {@link ByteBuffer#flip()} but in a manner which is
     * safe when compiled on Java 9 or newer but used on Java 8 or older.
     */
    public static ByteBuffer flip(final ByteBuffer buff) {
        ((Buffer) buff).flip();
        return buff;
    }

    /**
     * Equivalent to calling {@link ByteBuffer#clear()} but in a manner which is
     * safe when compiled on Java 9 or newer but used on Java 8 or older.
     */
    public static ByteBuffer clear(final ByteBuffer buff) {
        ((Buffer) buff).clear();
        return buff;
    }

    /**
     * Equivalent to calling {@link ByteBuffer#position(int)} but in a manner which is
     * safe when compiled on Java 9 or newer but used on Java 8 or older.
     */
    public static ByteBuffer position(final ByteBuffer buff, final int newPosition) {
        ((Buffer) buff).position(newPosition);
        return buff;
    }

    /**
     * Equivalent to calling {@link ByteBuffer#limit(int)} but in a manner which is
     * safe when compiled on Java 9 or newer but used on Java 8 or older.
     */
    public static ByteBuffer limit(final ByteBuffer buff, final int newLimit) {
        ((Buffer) buff).limit(newLimit);
        return buff;
    }

    /**
     * Takes a Base64-encoded String, decodes it, and returns contents as a byte array.
     *
     * @param encoded Base64 encoded String
     * @return decoded data as a byte array; an empty array for an empty input
     */
    public static byte[] decodeBase64String(final String encoded) {
        return encoded.isEmpty() ? ArrayUtils.EMPTY_BYTE_ARRAY : Base64.decode(encoded);
    }

    /**
     * Takes data in a byte array, encodes them in Base64, and returns the result as a String.
     *
     * @param data The data to encode.
     * @return Base64 string that encodes the {@code data}.
     */
    public static String encodeBase64String(final byte[] data) {
        return Base64.toBase64String(data);
    }

    /**
     * Removes the leading zero sign byte from the byte array representation of a BigInteger (if present)
     * and left pads with zeroes to produce a byte array of the given length.
     * Padding always uses zero bytes (no sign extension), so the result is only a faithful fixed-width
     * encoding for non-negative values.
     *
     * @param bigInteger The BigInteger to convert to a byte array
     * @param length The length of the byte array, must be at least
     *              as long as the BigInteger byte array without the sign byte
     * @return The byte array
     * @throws IllegalArgumentException if the value does not fit in {@code length} bytes
     */
    public static byte[] bigIntegerToByteArray(final BigInteger bigInteger, final int length) {
        byte[] rawBytes = bigInteger.toByteArray();
        // If rawBytes is already the correct length, return it.
        if (rawBytes.length == length) {
            return rawBytes;
        }

        // If we're exactly one byte too large, but we have a leading zero byte, remove it and return.
        if (rawBytes.length == length + 1 && rawBytes[0] == 0) {
            return Arrays.copyOfRange(rawBytes, 1, rawBytes.length);
        }

        if (rawBytes.length > length) {
            throw new IllegalArgumentException("Length must be at least as long as the BigInteger byte array " +
                    "without the sign byte");
        }

        final byte[] paddedResult = new byte[length];
        System.arraycopy(rawBytes, 0, paddedResult, length - rawBytes.length, rawBytes.length);
        return paddedResult;
    }

    /**
     * Returns true if the prefix of the given length for the input arrays are equal.
     * This method will return as soon as the first difference is found, and is thus not constant-time.
     *
     * @param a      The first array.
     * @param b      The second array.
     * @param length The length of the prefix to compare.
     * @return True if the prefixes are equal, false otherwise (including when either array is null or
     *         shorter than {@code length}).
     */
    public static boolean arrayPrefixEquals(final byte[] a, final byte[] b, final int length) {
        if (a == null || b == null || a.length < length || b.length < length) {
            return false;
        }
        for (int x = 0; x < length; x++) {
            if (a[x] != b[x]) {
                return false;
            }
        }
        return true;
    }
}