/**
 * Copyright 2016 Twitter. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package com.twitter.graphjet.algorithms.salsa;

import java.util.Random;

import com.google.common.annotations.VisibleForTesting;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.twitter.graphjet.bipartite.api.LeftIndexedBipartiteGraph;

import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.LongSet;

/**
 * Drives a full SALSA computation by alternating pluggable single-step iterations over a
 * bipartite graph: a left-to-right step, a right-to-left step, and a dedicated final
 * left-to-right step. The per-request state shared by those steps lives in
 * {@link CommonInternalState}; this class only resets it, seeds it and sequences the steps.
 *
 * @param <T> is the type of bipartite graph exposed by the internal state
 */
public class SalsaIterations<T extends LeftIndexedBipartiteGraph> {
  private static final Logger LOG = LoggerFactory.getLogger("graph");

  private final CommonInternalState<T> salsaInternalState;
  private final SalsaStats salsaStats;
  private final SingleSalsaIteration leftSalsaIteration;
  private final SingleSalsaIteration rightSalsaIteration;
  private final SingleSalsaIteration finalSalsaIteration;

  /**
   * Initialize state needed to run SALSA iterations by plugging in different kinds of iterations.
   * The stats object is taken from the given internal state, so both refer to the same instance.
   *
   * @param salsaInternalState   is the input state for the iterations to run
   * @param leftSalsaIteration   contains the logic of running the left-to-right iteration
   * @param rightSalsaIteration  contains the logic of running the right-to-left iteration
   * @param finalSalsaIteration  contains the logic of running the final left-to-right iteration
   */
  public SalsaIterations(
      CommonInternalState<T> salsaInternalState,
      SingleSalsaIteration leftSalsaIteration,
      SingleSalsaIteration rightSalsaIteration,
      SingleSalsaIteration finalSalsaIteration) {
    this.salsaInternalState = salsaInternalState;
    this.salsaStats = salsaInternalState.getSalsaStats();
    this.leftSalsaIteration = leftSalsaIteration;
    this.rightSalsaIteration = rightSalsaIteration;
    this.finalSalsaIteration = finalSalsaIteration;
  }

  /**
   * Main entry point to run the SALSA iterations. We do a monte-carlo implementation of the SALSA
   * algorithm in that we run multiple independent random walks from the queryNode, which then
   * implies that the # visits to nodes can be used for weighting. The particular implementation
   * here actually progresses all of the random walks simultaneously one step at a time. Thus, we
   * start on the left, run one step of the random walk for all walks, then start on the right, run
   * one step of the random walk for all walks and so on. The algorithm maintains visit counters
   * for nodes on the right, which are later used for picking top nodes in
   * {@link SalsaSelectResults}.
   *
   * <p>Steps are numbered from zero; even steps run a left-to-right iteration and odd steps run
   * the right-to-left iteration. The dedicated final iteration replaces the regular left-to-right
   * one only for the very last step, and only when the maximum walk length is odd and at least 3.
   * For any other length the walk simply ends on whichever regular iteration comes last.
   *
   * @param salsaRequest        is the new incoming salsa request
   * @param random              is used for making all the random choices in SALSA
   */
  public void runSalsaIterations(SalsaRequest salsaRequest, Random random) {
    LOG.info("SALSA: starting to reset internal state");
    resetWithRequest(salsaRequest, random);
    LOG.info("SALSA: done resetting internal state");

    seedLeftSideForFirstIteration();
    LOG.info("SALSA: done seeding");

    // Nothing in the loop below replaces the request, so the walk length is read only once.
    final int maxRandomWalkLength =
        salsaInternalState.getSalsaRequest().getMaxRandomWalkLength();
    SingleSalsaIteration nextIteration = leftSalsaIteration;

    for (int step = 0; step < maxRandomWalkLength; step++) {
      nextIteration.runSingleIteration();

      boolean ranForwardIteration = step % 2 == 0;
      if (ranForwardIteration) {
        // A left-to-right step is always followed by a right-to-left step.
        nextIteration = rightSalsaIteration;
      } else if (step < maxRandomWalkLength - 2) {
        // At least two more steps remain, so the next left-to-right step is a regular one.
        nextIteration = leftSalsaIteration;
      } else {
        // At most one step remains: if it runs at all, it is the final left-to-right step.
        nextIteration = finalSalsaIteration;
      }
    }
  }

  /**
   * Distributes the requested random walks over the left-side seed nodes before the first
   * iteration. Only seeds with at least one edge are used. If the query node has edges and the
   * request carries no explicit weight for it, a weight derived from the request's query node
   * weight fraction is added for it. Each usable seed then receives a number of walks
   * proportional to its share of the total weight, rounded up.
   *
   * <p>Side effects: the seed map obtained from the request is modified in place when a weight
   * for the query node is added, the non-zero seed set and current left nodes of the internal
   * state are filled in, and the direct-neighbor and seed-node counts are recorded in the stats.
   * Because every share is rounded up, the walks started in total can exceed the requested
   * number of random walks.
   */
  @VisibleForTesting
  protected void seedLeftSideForFirstIteration() {
    SalsaRequest request = salsaInternalState.getSalsaRequest();
    LeftIndexedBipartiteGraph graph = salsaInternalState.getBipartiteGraph();

    long queryNode = request.getQueryNode();
    // Looked up once and reused both for the stats and for the seeding decision below.
    int queryNodeDegree = graph.getLeftNodeDegree(queryNode);
    salsaStats.setNumDirectNeighbors(queryNodeDegree);

    Long2DoubleMap seedNodesWithWeight = request.getLeftSeedNodesWithWeight();
    LongSet nonZeroSeedSet = salsaInternalState.getNonZeroSeedSet();

    // Seeds without any edges cannot start a walk, so they contribute neither weight nor walks.
    double totalWeight = 0.0;
    for (Long2DoubleMap.Entry entry : seedNodesWithWeight.long2DoubleEntrySet()) {
      long seedNode = entry.getLongKey();
      if (graph.getLeftNodeDegree(seedNode) > 0) {
        totalWeight += entry.getDoubleValue();
        nonZeroSeedSet.add(seedNode);
      }
    }

    // If there is a pre-specified weight, we let it take precedence, but if not, then we reset
    // weights in accordance with the fraction of weight requested for the query node.
    if (!seedNodesWithWeight.containsKey(queryNode) && queryNodeDegree > 0) {
      // With no other usable seed the query node is the only source, so any positive weight works.
      double queryNodeWeight = 1.0;
      if (totalWeight > 0.0) {
        // Chosen so that queryNodeWeight / (totalWeight + queryNodeWeight) equals the fraction.
        // NOTE(review): a fraction of exactly 1.0 makes the divisor zero and the weight infinite,
        // which in turn yields zero walks for every seed -- confirm that SalsaRequest only ever
        // carries fractions strictly below 1.0.
        double queryNodeWeightFraction = request.getQueryNodeWeightFraction();
        queryNodeWeight =
            totalWeight * queryNodeWeightFraction / (1.0 - queryNodeWeightFraction);
      }
      seedNodesWithWeight.put(queryNode, queryNodeWeight);
      totalWeight += queryNodeWeight;
      nonZeroSeedSet.add(queryNode);
    }

    int numRandomWalks = request.getNumRandomWalks();
    for (long leftNode : nonZeroSeedSet) {
      // Rounded up so that every seed with a positive weight starts at least one walk.
      int numWalksToStart =
          (int) Math.ceil(seedNodesWithWeight.get(leftNode) / totalWeight * numRandomWalks);
      salsaInternalState.getCurrentLeftNodes().put(leftNode, numWalksToStart);
    }

    salsaStats.setNumSeedNodes(salsaInternalState.getCurrentLeftNodes().size());
  }

  /**
   * Resets all internal state to answer new incoming request. The shared internal state is reset
   * first, followed by the left, right and final iterations, each of which receives the same
   * request and the same random number generator.
   *
   * @param salsaRequest        is the new incoming salsa request
   * @param random              is used for making all the random choices in SALSA
   */
  @VisibleForTesting
  protected void resetWithRequest(SalsaRequest salsaRequest, Random random) {
    salsaInternalState.resetWithRequest(salsaRequest);
    leftSalsaIteration.resetWithRequest(salsaRequest, random);
    rightSalsaIteration.resetWithRequest(salsaRequest, random);
    finalSalsaIteration.resetWithRequest(salsaRequest, random);
  }
}