package io.whz.synapse.util;

import android.support.annotation.NonNull;
import android.support.annotation.Nullable;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;

import io.whz.synapse.matrix.Matrix;
import io.whz.synapse.pojo.dao.DBModel;
import io.whz.synapse.pojo.neural.Model;

/**
 * Converts between the persisted {@link DBModel} and the in-memory {@link Model}.
 *
 * <p>Array-valued fields are stored in {@link DBModel} as byte blobs produced with
 * {@link ByteBuffer}'s default (big-endian) byte order:
 * <ul>
 *   <li>{@code int[]} / {@code double[]}: the raw elements back to back;</li>
 *   <li>{@code Matrix[]}: an int matrix count, then for each matrix an int row count, an int
 *       column count and {@code row * col} doubles.</li>
 * </ul>
 */
public class DbHelper {
    /** Number of bytes used to serialize one {@code int}. */
    private static final int INT_BYTES = 4;
    /** Number of bytes used to serialize one {@code double}. */
    private static final int DOUBLE_BYTES = 8;
    /** Number of bytes in each serialized matrix header (row count + column count). */
    private static final int MATRIX_HEADER_BYTES = 2 * INT_BYTES;

    /**
     * Builds an in-memory {@link Model} from its database representation.
     *
     * @param dbModel the persisted model; must not be {@code null} (checked via
     *                {@code Precondition.checkNotNull})
     * @return a new {@link Model}; array fields are {@code null} when the corresponding blob is
     *         {@code null} (and, for matrix blobs, when the blob cannot be decoded)
     */
    public static Model dbModel2Model(@NonNull DBModel dbModel) {
        Precondition.checkNotNull(dbModel);

        final Model model = new Model();

        model.setId(dbModel.getId());
        model.setName(dbModel.getName());
        model.setCreatedTime(dbModel.getCreatedTime());
        model.setLearningRate(dbModel.getLearningRate());
        model.setEpochs(dbModel.getEpochs());
        // NOTE(review): stepEpoch is filled from getEpochs(); model2DBModel below persists no
        // separate step epoch, so presumably a restored model is treated as fully trained —
        // confirm this is intended.
        model.setStepEpoch(dbModel.getEpochs());
        model.setDataSize(dbModel.getDataSize());
        model.setTimeUsed(dbModel.getTimeUsed());
        model.setEvaluate(dbModel.getEvaluate());
        model.setHiddenSizes(byteArray2IntArray(dbModel.getHiddenSizeBytes()));
        model.setAccuracies(byteArray2DoubleArray(dbModel.getAccuracyBytes()));
        model.setBiases(byteArray2MatrixArray(dbModel.getBiasBytes()));
        model.setWeights(byteArray2MatrixArray(dbModel.getWeightBytes()));

        return model;
    }

    /**
     * Builds the database representation of an in-memory {@link Model}.
     *
     * @param model the model to persist; must not be {@code null} (checked via
     *              {@code Precondition.checkNotNull})
     * @return a new {@link DBModel}; blob fields are {@code null} when the source array is
     *         {@code null} or cannot be serialized
     */
    public static DBModel model2DBModel(@NonNull Model model) {
        Precondition.checkNotNull(model);

        final DBModel dbModel = new DBModel();

        dbModel.setId(model.getId());
        dbModel.setName(model.getName());
        dbModel.setCreatedTime(model.getCreatedTime());
        dbModel.setLearningRate(model.getLearningRate());
        dbModel.setEpochs(model.getEpochs());
        dbModel.setDataSize(model.getDataSize());
        dbModel.setTimeUsed(model.getTimeUsed());
        dbModel.setEvaluate(model.getEvaluate());
        dbModel.setHiddenSizeBytes(convert2ByteArray(model.getHiddenSizes()));
        dbModel.setAccuracyBytes(convert2ByteArray(model.getAccuracies()));
        dbModel.setBiasBytes(convert2ByteArray(model.getBiases()));
        dbModel.setWeightBytes(convert2ByteArray(model.getWeights()));

        return dbModel;
    }

    /**
     * Serializes ints back to back (big-endian).
     *
     * @return the blob, or {@code null} if {@code array} is {@code null} or too large to fit in a
     *         single byte array
     */
    @Nullable
    private static byte[] convert2ByteArray(int... array) {
        if (array == null) {
            return null;
        }
        if (array.length > Integer.MAX_VALUE / INT_BYTES) {
            // The byte size would overflow a ByteBuffer's int capacity.
            return null;
        }

        final ByteBuffer buffer = ByteBuffer.allocate(array.length * INT_BYTES);
        // The int view shares the buffer's content and byte order, so this writes the same
        // bytes as calling putInt() element by element.
        buffer.asIntBuffer().put(array);
        return buffer.array();
    }

    /**
     * Serializes doubles back to back (big-endian).
     *
     * @return the blob, or {@code null} if {@code array} is {@code null} or too large to fit in a
     *         single byte array
     */
    @Nullable
    private static byte[] convert2ByteArray(double... array) {
        if (array == null) {
            return null;
        }
        if (array.length > Integer.MAX_VALUE / DOUBLE_BYTES) {
            // The byte size would overflow a ByteBuffer's int capacity.
            return null;
        }

        final ByteBuffer buffer = ByteBuffer.allocate(array.length * DOUBLE_BYTES);
        buffer.asDoubleBuffer().put(array);
        return buffer.array();
    }

    /**
     * Serializes matrices as {@code [count][row, col, row*col doubles]...}.
     *
     * <p>A {@code null} element is not supported: sizing it throws {@link NullPointerException}.
     *
     * @return the blob, or {@code null} if {@code matrices} is {@code null}, the total size
     *         exceeds a single byte array, or writing fails
     */
    @Nullable
    private static byte[] convert2ByteArray(Matrix... matrices) {
        if (matrices == null) {
            return null;
        }

        // Size in long arithmetic so very large models cannot silently wrap around.
        long total = INT_BYTES; // leading matrix count
        for (Matrix matrix : matrices) {
            total += calMatrixLen(matrix);
        }
        if (total > Integer.MAX_VALUE) {
            return null;
        }

        try {
            final ByteBuffer buffer = ByteBuffer.allocate((int) total);

            buffer.putInt(matrices.length);

            for (Matrix matrix : matrices) {
                buffer.putInt(matrix.getRow());
                buffer.putInt(matrix.getCol());

                for (double d : matrix.getArray()) {
                    buffer.putDouble(d);
                }
            }

            return buffer.array();
        } catch (Exception e) {
            e.printStackTrace();
            // Never hand back a partially written blob; match the other converters.
            return null;
        }
    }

    /** Returns the serialized size in bytes of one matrix: header plus its doubles. */
    private static long calMatrixLen(@NonNull Matrix matrix) {
        return MATRIX_HEADER_BYTES + (long) matrix.getArray().length * DOUBLE_BYTES;
    }

    /**
     * Decodes a blob written by {@link #convert2ByteArray(Matrix...)}.
     *
     * <p>The count and dimensions read from the blob are validated against the bytes actually
     * remaining. A corrupted blob therefore yields {@code null}, instead of a huge allocation
     * ({@link OutOfMemoryError} is not an {@link Exception}) or a matrix built from an overflowed
     * {@code row * col}.
     *
     * @return the matrices, or {@code null} if {@code array} is {@code null} or malformed
     */
    @Nullable
    private static Matrix[] byteArray2MatrixArray(byte... array) {
        if (array == null) {
            return null;
        }

        try {
            final ByteBuffer buffer = ByteBuffer.wrap(array);

            final int len = buffer.getInt();
            // Each matrix needs at least its header, which bounds how many can be present.
            if (len < 0 || len > buffer.remaining() / MATRIX_HEADER_BYTES) {
                return null;
            }

            final Matrix[] res = new Matrix[len];

            for (int i = 0; i < len; ++i) {
                final int row = buffer.getInt();
                final int col = buffer.getInt();
                final long cells = (long) row * col;

                if (row < 0 || col < 0 || cells > buffer.remaining() / DOUBLE_BYTES) {
                    return null;
                }

                final double[] doubles = new double[(int) cells];

                for (int j = 0; j < doubles.length; ++j) {
                    doubles[j] = buffer.getDouble();
                }

                res[i] = Matrix.array(doubles, row);
            }

            return res;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Decodes big-endian ints written by {@link #convert2ByteArray(int...)}. Trailing bytes that
     * do not form a whole int are ignored.
     *
     * @return the ints, or {@code null} if {@code array} is {@code null}
     */
    @Nullable
    private static int[] byteArray2IntArray(byte... array) {
        if (array == null) {
            return null;
        }

        // The int view holds exactly array.length / 4 elements, so the bulk get fills res
        // completely and cannot underflow.
        final int[] res = new int[array.length / INT_BYTES];
        ByteBuffer.wrap(array).asIntBuffer().get(res);
        return res;
    }

    /**
     * Decodes big-endian doubles written by {@link #convert2ByteArray(double...)}. Trailing bytes
     * that do not form a whole double are ignored.
     *
     * @return the doubles, or {@code null} if {@code array} is {@code null}
     */
    @Nullable
    private static double[] byteArray2DoubleArray(byte... array) {
        if (array == null) {
            return null;
        }

        final double[] res = new double[array.length / DOUBLE_BYTES];
        ByteBuffer.wrap(array).asDoubleBuffer().get(res);
        return res;
    }
}