/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.parallelism;


import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * This ParallelInference implementation provides inference functionality without launching additional threads, so inference happens in the calling thread.
 *
 * To instantiate this implementation one should use InferenceMode.INPLACE in ParallelInference.Builder
 *
 * PLEASE NOTE: This implementation does not create additional threads
 * PLEASE NOTE: This implementation uses shared parameters for models on per-device basis
 *
 * @author [email protected]
 */
@Slf4j
public class InplaceParallelInference extends ParallelInference {
    protected List<ModelHolder> holders = new CopyOnWriteArrayList<>();
    protected ModelSelector selector = new ModelSelector();

    protected final Object locker = new Object();

    @Override
    protected void init() {
        for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) {
            val h = ModelHolder.builder()
                    .sourceModel(model)
                    .workers(workers)
                    .loadBalanceMode(loadBalanceMode)
                    .targetDeviceId(e)
                    .rootDevice(e == Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue())
                    .build();
            h.init();

            // adding for simplified access
            holders.add(h);

            // and adding it to actual
            selector.addModelHolder(e, h);
        }
    }

    @Override
    public synchronized void updateModel(@NonNull Model model) {
        for (val h:holders)
            h.updateModel(model);
    }

    @Override
    protected synchronized Model[] getCurrentModelsFromWorkers() {
        val models = new Model[holders.size()];
        int cnt = 0;
        for (val h:holders) {
            models[cnt++] = h.sourceModel;
        }

        return models;
    }

    @Override
    public INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
        return selector.output(input, inputMasks);
    }

    /**
     * This method does forward pass and returns output provided by OutputAdapter
     *
     * @param adapter
     * @param input
     * @param inputMasks
     * @param <T>
     * @return
     */
    public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray[] input, INDArray[] inputMasks, INDArray[] labelsMasks) {
        val holder = selector.getModelForThisThread();
        Model model = null;
        boolean acquired = false;
        try {
            model = holder.acquireModel();
            acquired = true;
            return adapter.apply(model, input, inputMasks, labelsMasks);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        } finally {
            if (model != null && acquired)
                holder.releaseModel(model);
        }
    }


    protected static class ModelSelector {
        // this map stores collection of shared
        protected Map<Integer, ModelHolder> map = new HashMap<>();

        protected final LoadBalanceMode loadBalanceMode;

        public ModelSelector() {
            this(LoadBalanceMode.ROUND_ROBIN);
        }

        public ModelSelector(LoadBalanceMode loadBalanceMode) {
            this.loadBalanceMode = loadBalanceMode;
        }

        protected void addModelHolder(@NonNull Integer device, @NonNull ModelHolder holder) {
            map.put(device, holder);
        }

        public ModelHolder getModelForThread(long threadId) {
            // first of all we get mapped device for this thread
            val device = Nd4j.getAffinityManager().getDeviceForThread(threadId);

            // each device has it's own queue
            val q = map.get(device);

            // and we're returning holder right away
            return q;
        }

        public INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
            return getModelForThisThread().output(input, inputMasks);
        }

        public ModelHolder getModelForThisThread() {
            return getModelForThread(Thread.currentThread().getId());
        }
    }

    @NoArgsConstructor
    @AllArgsConstructor
    @lombok.Builder
    protected static class ModelHolder {
        protected Model sourceModel;
        @lombok.Builder.Default protected int workers = 4;
        @lombok.Builder.Default protected List<Model> replicas = new ArrayList<>();
        @lombok.Builder.Default protected boolean rootDevice = true;
        @lombok.Builder.Default protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.ROUND_ROBIN;
        protected int targetDeviceId;

        protected final AtomicLong position = new AtomicLong(0);
        protected final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();

        // this queue is used in FIFO mode
        protected final BlockingQueue<Model> queue = new LinkedBlockingQueue<>();

        @lombok.Builder.Default protected transient boolean isCG = false;
        @lombok.Builder.Default protected transient boolean isMLN = false;


        protected synchronized void init() {
            if (workers < 1)
                throw new ND4JIllegalStateException("Workers must be positive value");

            replicas.clear();

            isCG = sourceModel instanceof ComputationGraph;
            isMLN = sourceModel instanceof MultiLayerNetwork;

            // we clone params only if we're not on the same device
            val params = rootDevice ? sourceModel.params() : sourceModel.params().unsafeDuplication(true);

            // and moving it to specified device (only if NOT root
            if (!rootDevice)
                Nd4j.getAffinityManager().replicateToDevice(targetDeviceId, params);

            for (int e = 0; e < workers; e++) {
                if (sourceModel instanceof ComputationGraph) {
                    // building configuration with shared parameters
                    val model = new ComputationGraph(ComputationGraphConfiguration.fromJson(((ComputationGraph) sourceModel).getConfiguration().toJson()));
                    model.init(params, false);
                    Nd4j.getExecutioner().commit();

                    // storing model for future reuse
                    replicas.add(model);

                    if (loadBalanceMode == LoadBalanceMode.FIFO)
                        queue.add(model);
                } else if (sourceModel instanceof MultiLayerNetwork) {
                    val model = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(((MultiLayerNetwork) sourceModel).getLayerWiseConfigurations().toJson()));
                    model.init(params, false);
                    Nd4j.getExecutioner().commit();

                    replicas.add(model);

                    if (loadBalanceMode == LoadBalanceMode.FIFO)
                        queue.add(model);
                }
            }
        }


        protected Model acquireModel() throws InterruptedException {
            try {
                modelLock.readLock().lock();

                switch (loadBalanceMode) {
                    case FIFO: {
                            return queue.take();
                        }
                    case ROUND_ROBIN:
                        return replicas.get((int) (position.getAndIncrement() % replicas.size()));
                    default:
                        throw new ND4JIllegalStateException("Unknown LoadBalanceMode was specified: [" + loadBalanceMode + "]");
                }
            } finally {
                modelLock.readLock().unlock();
            }
        }

        protected void releaseModel(Model model) {
            try {
                modelLock.readLock().lock();

                switch (loadBalanceMode) {
                    case FIFO:
                        queue.add(model);
                        break;
                    case ROUND_ROBIN:
                        break;
                    default:
                        throw new ND4JIllegalStateException("Unknown LoadBalanceMode was specified: [" + loadBalanceMode + "]");
                }
            } finally {
                modelLock.readLock().unlock();
            }
        }

        protected INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
            try {
                modelLock.readLock().lock();
                if (isCG) {
                    // acquiring model from pool
                    val model = acquireModel();

                    // doing inference
                    INDArray[] output;
                    try{
                        output = ((ComputationGraph) model).output(false, input, inputMasks);
                    } finally {
                        // releasing model
                        releaseModel(model);
                    }
                    return output;
                } else if (isMLN) {
                    if (input.length > 1 || (inputMasks != null && inputMasks.length > 1))
                        throw new ND4JIllegalStateException("MultilayerNetwork can't have multiple inputs");

                    val model = acquireModel();
                    INDArray result;
                    try {
                        result = ((MultiLayerNetwork) model).output(input[0], false, (inputMasks == null ? null : inputMasks[0]), null);
                    } finally {
                        releaseModel(model);
                    }
                    return new INDArray[]{result};
                } else
                    throw new UnsupportedOperationException();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            } finally {
                modelLock.readLock().unlock();
            }
        }

        protected void updateModel(@NonNull Model model) {
            try {
                modelLock.writeLock().lock();

                this.sourceModel = model;

                init();
            } finally {
                modelLock.writeLock().unlock();
            }
        }

    }
}