/**
 * Copyright 2016 Twitter. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package com.twitter.graphjet.algorithms.salsa;

import com.google.common.collect.Lists;

import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import com.twitter.graphjet.algorithms.BipartiteGraphTestHelper;
import com.twitter.graphjet.algorithms.filters.DirectInteractionsFilter;
import com.twitter.graphjet.algorithms.NodeInfo;
import com.twitter.graphjet.algorithms.filters.RequestedSetFilter;
import com.twitter.graphjet.algorithms.filters.ResultFilterChain;
import com.twitter.graphjet.algorithms.salsa.fullgraph.SalsaInternalState;
import com.twitter.graphjet.bipartite.api.BipartiteGraph;
import com.twitter.graphjet.stats.NullStatsReceiver;

import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;

/**
 * Tests for {@link SalsaNodeVisitor.SimpleNodeVisitor} and
 * {@link SalsaNodeVisitor.NodeVisitorWithSocialProof}.
 *
 * <p>Both tests drive their visitor through the same fixed sequence of (left, right) visits. Each
 * then compares the visited-right-nodes map exposed by the shared {@link SalsaInternalState} with a
 * hand-built expectation.
 */
public class SalsaNodeVisitorTest {
  /** (leftNode, rightNode) pairs handed to each visitor, in the order they are visited. */
  private static final int[][] VISITS = {{1, 2}, {2, 3}, {1, 3}};

  private SalsaRequest request;
  private SalsaInternalState internalState;

  /**
   * Runs before every test. It builds a request for query node 1 over the small test bipartite
   * graph, then resets a freshly constructed {@link SalsaInternalState} with that request.
   */
  @Before
  public void setUp() {
    BipartiteGraph graph = BipartiteGraphTestHelper.buildSmallTestBipartiteGraph();
    long queryNode = 1;

    int walkLength = 5;
    int walkCount = 1000;
    // Capacity hint handed to SalsaInternalState: walk count * walk length / 2 (integer division).
    int expectedNodesToHit = walkCount * walkLength / 2;

    LongSet toBeFiltered = new LongOpenHashSet(8);
    SalsaStats stats = new SalsaStats();
    byte[] validSocialProofTypes = new byte[] {0, 1, 2, 3};
    ResultFilterChain filters = new ResultFilterChain(Lists.newArrayList(
        new RequestedSetFilter(new NullStatsReceiver()),
        new DirectInteractionsFilter(graph, new NullStatsReceiver())
    ));

    request = new SalsaRequestBuilder(queryNode)
        .withLeftSeedNodes(null)
        .withToBeFiltered(toBeFiltered)
        .withMaxNumResults(3)
        .withResetProbability(0.3)
        .withMaxRandomWalkLength(walkLength)
        .withNumRandomWalks(walkCount)
        .withMaxSocialProofSize(2)
        .withValidSocialProofTypes(validSocialProofTypes)
        .withResultFilterChain(filters)
        .build();

    internalState = new SalsaInternalState(graph, stats, expectedNodesToHit);
    internalState.resetWithRequest(request);
  }

  /**
   * Feeds {@link #VISITS} to a {@link SalsaNodeVisitor.SimpleNodeVisitor}. It then checks that
   * exactly right nodes 2 and 3 were recorded.
   *
   * <p>Node 2 is reached once and node 3 twice. Their expected {@link NodeInfo} values differ only
   * in the second constructor argument (1 vs 2). That argument is presumably the accumulated visit
   * weight; confirm against {@link NodeInfo}.
   */
  @Test
  public void testSimpleNodeVisitor() throws Exception {
    SalsaNodeVisitor.SimpleNodeVisitor visitor =
        new SalsaNodeVisitor.SimpleNodeVisitor(internalState.getVisitedRightNodes());
    visitor.resetWithRequest(request);

    for (int[] visit : VISITS) {
      visitor.visitRightNode(visit[0], visit[1], (byte) 0, 0L, 1);
    }

    Long2ObjectMap<NodeInfo> expected =
        expectedVisitedRightNodes(new NodeInfo(2, 1, 1), new NodeInfo(3, 2, 1));
    assertEquals(expected, internalState.getVisitedRightNodes());
  }

  /**
   * Feeds {@link #VISITS} to a {@link SalsaNodeVisitor.NodeVisitorWithSocialProof}.
   *
   * <p>The expected entries match the simple-visitor case, with one addition: node 3 also carries
   * left node 2 (edge type 0) as social proof. Why left node 1 is not expected there as well is
   * not visible from this file. Presumably it is the visitor's own bookkeeping, such as the first
   * visitor being implicit; verify in {@link SalsaNodeVisitor}.
   */
  @Test
  public void testNodeVisitorWithSocialProof() throws Exception {
    SalsaNodeVisitor.NodeVisitorWithSocialProof visitor =
        new SalsaNodeVisitor.NodeVisitorWithSocialProof(internalState.getVisitedRightNodes());
    visitor.resetWithRequest(request);

    for (int[] visit : VISITS) {
      visitor.visitRightNode(visit[0], visit[1], (byte) 0, 0L, 1);
    }

    NodeInfo expectedNode2 = new NodeInfo(2, 1, 1);
    NodeInfo expectedNode3 = new NodeInfo(3, 2, 1);
    assertTrue(expectedNode3.addToSocialProof(2, (byte) 0, 0L, 1));

    assertEquals(
        expectedVisitedRightNodes(expectedNode2, expectedNode3),
        internalState.getVisitedRightNodes());
  }

  /** Builds the expected visited-right-nodes map holding entries for right nodes 2 and 3. */
  private static Long2ObjectMap<NodeInfo> expectedVisitedRightNodes(
      NodeInfo forNode2, NodeInfo forNode3) {
    Long2ObjectMap<NodeInfo> map = new Long2ObjectOpenHashMap<NodeInfo>(2);
    map.put(2, forNode2);
    map.put(3, forNode3);
    return map;
  }
}