package dl.nn;

import org.jblas.DoubleMatrix;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

public class NeuralCoordinator
{
    private List<NeuralNetwork> m_networks;
    private NN_parameters m_params;
    private String m_modelTitle;
    private String m_model_path;
    private String m_train_path;
    private String m_test_path;
    private double m_total_nodes;
    private final int update_threshold = 20;

    public NeuralCoordinator(String model_title, String title, String dataset, NN_parameters params, List<NeuronLayer> layers, LinkedList<HiddenLayer> hiddenLayers, double L2, ICostFunction cf) throws IOException
    {
        m_modelTitle = model_title;
        m_model_path = Util.DATAPATH + dataset + "/" + Util.MODEL + title + "_" + model_title;
        m_train_path = Util.DATAPATH + dataset + "/" + Util.TRAIN + title;
        m_test_path = Util.DATAPATH + dataset + "/" + Util.TEST + title;

        for(NeuronLayer layer : layers)
        {
            m_total_nodes += layer.m_layer_size;
        }

        m_params = params;
        m_networks = new ArrayList<>(Util.LAYER_THREADS);
        m_networks.add(new NeuralNetwork(params, layers, hiddenLayers, L2, cf));
        for(int idx = 1; idx < Util.LAYER_THREADS; ++idx)
        {
            LinkedList<HiddenLayer> hiddenLayers1 = new LinkedList<>();
            hiddenLayers.forEach(e -> hiddenLayers1.add(e.clone()));

            List<NeuronLayer> layers1 = new LinkedList<>();
            layers1.addAll(hiddenLayers1);
            layers1.add(layers.get(layers.size()-1).clone());
            m_networks.add(new NeuralNetwork(params, layers1, hiddenLayers1, L2, cf));
        }
    }

    private List<Integer> initIndices(int length)
    {
        List<Integer> indices = new ArrayList<>();
        for(int idx = 0; idx < length; ++idx)
        {
            indices.add(idx);
        }
        return indices;
    }

    private void shuffle(List<Integer> indices)
    {
        for(int idx = 0; idx < indices.size(); ++idx)
        {
            int rand = Util.rand.nextInt(indices.size());
            int value = indices.get(idx);
            indices.set(idx, indices.get(rand));
            indices.set(rand, value);
        }
    }

    public void test(List<DoubleMatrix> data, double[] labels)
    {
        List<int[]> test_hashes = m_params.computeHashes(data);
        System.out.println("Finished Pre-Computing Training Hashes");
        System.out.println(m_networks.get(0).test(test_hashes, data, labels));
    }

    // training data, training labels
    public void train(final int max_epoch, List<DoubleMatrix> data, double[] labels, List<DoubleMatrix> test_data, double[] test_labels) throws Exception
    {
        assert(data.size() == labels.length);
        assert(test_data.size() == test_labels.length);

        List<int[]> input_hashes = m_params.computeHashes(data);
        System.out.println("Finished Pre-Computing Training Hashes");

        List<int[]> test_hashes = m_params.computeHashes(test_data);
        System.out.println("Finished Pre-Computing Testing Hashes");

        List<Integer> data_idx = initIndices(labels.length);
        final int m_examples_per_thread = data.size() / (Util.UPDATE_SIZE * Util.LAYER_THREADS);
        assert(data_idx.size() == labels.length);

        BufferedWriter train_writer = new BufferedWriter(new FileWriter(m_train_path, true));
        BufferedWriter test_writer = new BufferedWriter(new FileWriter(m_test_path, true));
        for(int epoch_count = 0; epoch_count < max_epoch; ++epoch_count)
        {
            m_params.clear_gradient();
            shuffle(data_idx);
            int count = 0;
            while(count < data_idx.size())
            {
                List<Thread> threads = new LinkedList<>();
                for(NeuralNetwork network : m_networks)
                {
                    if(count < data_idx.size())
                    {
                        int start = count;
                        count = Math.min(data_idx.size(), count + m_examples_per_thread);
                        int end = count;

                        Thread t = new Thread()
                        {
                            @Override
                            public void run()
                            {
                                for (int pos = start; pos < end; ++pos)
                                {
                                    network.execute(input_hashes.get(pos), data.get(pos), labels[pos], true);
                                }
                            }
                        };
                        t.start();
                        threads.add(t);
                    }
                }
                Util.join(threads);
                if(epoch_count <= update_threshold && epoch_count % (epoch_count / 10 + 1) == 0)
                {
                    m_params.rebuildTables();
                }

            }

            // Console Debug Output
            int epoch = m_params.epoch_offset() + epoch_count;
            //m_networks.stream().forEach(e -> e.updateHashTables(labels.length / Util.LAYER_THREADS));
            double activeNodes = calculateActiveNodes(m_total_nodes * data.size());
            double test_accuracy = m_networks.get(0).test(test_hashes, test_data, test_labels);
            System.out.println("Epoch " + epoch + " Accuracy: " + test_accuracy);

            // Test Output
            DecimalFormat df = new DecimalFormat("#.###");
            df.setRoundingMode(RoundingMode.FLOOR);
            test_writer.write(m_modelTitle + " " + epoch + " " + df.format(activeNodes) + " " + test_accuracy);
            test_writer.newLine();

            // Train Output
            train_writer.write(m_modelTitle + " " + epoch + " " + df.format(activeNodes) + " " + calculateTrainAccuracy(data.size()));
            train_writer.newLine();

            test_writer.flush();
            train_writer.flush();

            m_params.timeStep();
        }
        test_writer.close();
        train_writer.close();
        save_model(max_epoch, m_model_path);
    }

    public void save_model(int epoch, String path) throws IOException
    {
        m_params.save_model(epoch, Util.writerBZ2(path));
    }

    private double calculateTrainAccuracy(double size)
    {
        double count = 0;
        for(NeuralNetwork network : m_networks)
        {
            count += network.m_train_correct;
            network.m_train_correct = 0;
        }
        return count / size;
    }

    private double calculateActiveNodes(double total)
    {
        long active = 0;
        for(NeuralNetwork network : m_networks)
        {
            active += network.calculateActiveNodes();
        }
        return active / total;
    }

    private long calculateMultiplications()
    {
        long total = 0;
        for(NeuralNetwork network : m_networks)
        {
            total += network.calculateMultiplications();
        }
        return total;
    }
}