/**
 * Copyright 2016 Twitter. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package com.twitter.graphjet.algorithms.salsa;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;

import com.google.common.collect.Lists;

import org.junit.Test;

import static org.junit.Assert.assertArrayEquals;

import com.twitter.graphjet.algorithms.RecommendationInfo;
import com.twitter.graphjet.algorithms.filters.RequestedSetFilter;
import com.twitter.graphjet.algorithms.filters.ResultFilter;
import com.twitter.graphjet.algorithms.filters.ResultFilterChain;
import com.twitter.graphjet.algorithms.StaticBipartiteGraph;
import com.twitter.graphjet.algorithms.filters.TweetCardFilter;
import com.twitter.graphjet.algorithms.TweetIDMask;
import com.twitter.graphjet.algorithms.counting.tweet.TweetRecommendationInfo;
import com.twitter.graphjet.algorithms.salsa.fullgraph.Salsa;
import com.twitter.graphjet.bipartite.api.BipartiteGraph;
import com.twitter.graphjet.stats.NullStatsReceiver;

import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;

public class SalsaBitmaskTest {

  private static final TweetIDMask TWEET_ID_MASK = new TweetIDMask();
  private static long tweetNode = 11 | TweetIDMask.TWEET;
  private static long summaryNode = 12 | TweetIDMask.SUMMARY;
  private static long photoNode = 13 | TweetIDMask.PHOTO;
  private static long playerNode = 14 | TweetIDMask.PLAYER;
  private static long promotionNode = 15 | TweetIDMask.PROMOTION;
  private StaticBipartiteGraph buildTestGraph() {
    Long2ObjectMap<LongList> leftSideGraph = new Long2ObjectOpenHashMap<LongList>(3);
    leftSideGraph.put(1, new LongArrayList(new long[]{2, 3, 4, 5}));
    leftSideGraph.put(2, new LongArrayList(new long[]{tweetNode, summaryNode, photoNode, playerNode,
        2, 3}));
    leftSideGraph.put(3, new LongArrayList(new long[]{tweetNode, summaryNode, photoNode, playerNode,
        promotionNode, 4, 5}));

    Long2ObjectMap<LongList> rightSideGraph = new Long2ObjectOpenHashMap<LongList>(10);
    rightSideGraph.put(2, new LongArrayList(new long[]{1, 2}));
    rightSideGraph.put(3, new LongArrayList(new long[]{1, 2}));
    rightSideGraph.put(4, new LongArrayList(new long[]{1, 3}));
    rightSideGraph.put(5, new LongArrayList(new long[]{1, 3}));
    rightSideGraph.put(tweetNode, new LongArrayList(new long[]{2, 3}));
    rightSideGraph.put(summaryNode, new LongArrayList(new long[]{2, 3}));
    rightSideGraph.put(photoNode, new LongArrayList(new long[]{2, 3}));
    rightSideGraph.put(playerNode, new LongArrayList(new long[]{2, 3}));
    rightSideGraph.put(promotionNode, new LongArrayList(new long[]{3}));

    return new StaticBipartiteGraph(leftSideGraph, rightSideGraph);

  }

  private void testFilter(TweetCardFilter filter, Long[] expected) {
    Random random = new Random(918324701982347L);
    long queryNode = 1;
    BipartiteGraph bipartiteGraph = buildTestGraph();
    LongSet toBeFiltered = new LongOpenHashSet(new long[]{2, 3, 4, 5});
    int numIterations = 5;
    double resetProbability = 0.3;
    int numResults = 20;
    int numRandomWalks = 1000;
    int maxSocialProofSize = 2;
    int expectedNodesToHit = numRandomWalks * numIterations / 2;
    SalsaStats salsaStats = new SalsaStats(1, 4, 3, 6, 1, 3, 0);


    ResultFilterChain resultFilterChain = new ResultFilterChain(Lists.<ResultFilter>newArrayList(
        new RequestedSetFilter(new NullStatsReceiver()),
        filter
    ));

    SalsaRequest salsaRequest =
        new SalsaRequestBuilder(queryNode)
            .withLeftSeedNodes(null)
            .withToBeFiltered(toBeFiltered)
            .withMaxNumResults(numResults)
            .withResetProbability(resetProbability)
            .withMaxRandomWalkLength(numIterations)
            .withNumRandomWalks(numRandomWalks)
            .withMaxSocialProofSize(maxSocialProofSize)
            .withResultFilterChain(resultFilterChain)
            .build();

    SalsaResponse salsaResponse = new Salsa(
        bipartiteGraph,
        expectedNodesToHit,
        new NullStatsReceiver())
        .computeRecommendations(salsaRequest, random);

    Iterator<RecommendationInfo> recs = salsaResponse.getRankedRecommendations().iterator();
    ArrayList<Long> rets = new ArrayList<Long>();
    while (recs.hasNext()) {
      rets.add(((TweetRecommendationInfo) recs.next()).getRecommendation());
    }
    assertArrayEquals(expected, rets.toArray(new Long[]{}));
  }


  @Test
  public void testTweetsOnly() {
    Long[] tweets = new Long[]{TWEET_ID_MASK.restore(tweetNode)};

    testFilter(new TweetCardFilter(true, false, false, false, false, new NullStatsReceiver()),
        tweets);
  }

  @Test
  public void testSummaryOnly() {
    Long[] summary = new Long[]{TWEET_ID_MASK.restore(summaryNode)};

    testFilter(new TweetCardFilter(false, true, false, false, false, new NullStatsReceiver()),
        summary);
  }

  @Test
  public void testPhotoOnly() {
    Long[] photo = new Long[]{TWEET_ID_MASK.restore(photoNode)};

    testFilter(new TweetCardFilter(false, false, true, false, false, new NullStatsReceiver()),
        photo);
  }

  @Test
  public void testPlayerOnly() {
    Long[] player = new Long[]{TWEET_ID_MASK.restore(playerNode)};

    testFilter(new TweetCardFilter(false, false, false, true, false, new NullStatsReceiver()),
        player);
  }

  @Test
  public void testPromotionOnly() {
    Long[] promotion = new Long[]{TWEET_ID_MASK.restore(promotionNode)};

    testFilter(new TweetCardFilter(false, false, false, false, true, new NullStatsReceiver()),
        promotion);
  }
}