/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.physical.impl;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import org.apache.commons.lang3.StringUtils;
import org.apache.drill.PlanTestBase;
import org.apache.drill.TestBuilder;
import org.apache.drill.exec.physical.base.Exchange;
import org.apache.drill.exec.physical.config.UnorderedDeMuxExchange;
import org.apache.drill.exec.physical.config.HashToRandomExchange;
import org.apache.drill.exec.planner.PhysicalPlanReader;
import org.apache.drill.exec.planner.fragment.Fragment;
import org.apache.drill.exec.planner.fragment.Fragment.ExchangeFragmentPair;
import org.apache.drill.exec.planner.fragment.PlanningSet;
import org.apache.drill.exec.planner.fragment.SimpleParallelizer;
import org.apache.drill.exec.pop.PopUnitTestBase;
import org.apache.drill.exec.proto.BitControl.PlanFragment;
import org.apache.drill.exec.proto.BitControl.QueryContextInformation;
import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint;
import org.apache.drill.exec.proto.UserBitShared;
import org.apache.drill.exec.proto.UserBitShared.QueryId;
import org.apache.drill.exec.rpc.user.UserSession;
import org.apache.drill.exec.server.DrillbitContext;
import org.apache.drill.exec.server.options.OptionList;
import org.apache.drill.exec.util.Utilities;
import org.apache.drill.exec.work.QueryWorkUnit;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;

import static org.apache.drill.exec.planner.physical.HashPrelUtil.HASH_EXPR_NAME;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
 * This test starts a Drill cluster with CLUSTER_SIZE nodes and generates data for test tables.
 *
 * Test queries involve HashToRandomExchange (group by and join) and verify the following:
 *   1. The plan has the expected mux and demux exchanges inserted
 *   2. Running the query produces the expected baseline records
 *   3. Take the plan we got in (1), use SimpleParallelizer to get PlanFragments and test that the number of
 *   partition senders in a major fragment is not more than the number of Drillbit nodes in cluster and there exists
 *   at most one partition sender per Drillbit.
 */
public class TestLocalExchange extends PlanTestBase {

  /** Root folder for the generated tables; created in the class setup and deleted in {@link #cleanupTempFolder()}. */
  public static final TemporaryFolder testTempFolder = new TemporaryFolder();

  private static final int CLUSTER_SIZE = 3;

  // Quoted variants are used for counting occurrences in the raw JSON plan text, unquoted variants for comparing
  // against the parsed "pop" value of a plan node.
  private static final String MUX_EXCHANGE = "\"unordered-mux-exchange\"";
  private static final String DEMUX_EXCHANGE = "\"unordered-demux-exchange\"";
  private static final String MUX_EXCHANGE_CONST = "unordered-mux-exchange";
  private static final String DEMUX_EXCHANGE_CONST = "unordered-demux-exchange";
  private static final String HASH_EXCHANGE = "hash-to-random-exchange";

  private static final UserSession USER_SESSION = UserSession.Builder.newBuilder()
      .withCredentials(UserBitShared.UserCredentials.newBuilder().setUserName("foo").build())
      .build();

  private static final SimpleParallelizer PARALLELIZER = new SimpleParallelizer(
      1 /*parallelizationThreshold (slice_count)*/,
      6 /*maxWidthPerNode*/,
      1000 /*maxGlobalWidth*/,
      1.2 /*affinityFactor*/);

  private static final int NUM_DEPTS = 40;
  private static final int NUM_EMPLOYEES = 1000;
  private static final int NUM_MNGRS = 1;
  private static final int NUM_IDS = 1;

  private static String empTableLocation;
  private static String deptTableLocation;

  private static String groupByQuery;
  private static String joinQuery;

  private static String[] joinQueryBaselineColumns;
  private static String[] groupByQueryBaselineColumns;

  private static List<Object[]> groupByQueryBaselineValues;
  private static List<Object[]> joinQueryBaselineValues;

  /** Resizes the test cluster to {@link #CLUSTER_SIZE} Drillbits. */
  @BeforeClass
  public static void setupClusterSize() {
    updateTestCluster(CLUSTER_SIZE, null);
  }

  /**
   * Creates the temporary root folder. Deliberately not annotated with {@code @BeforeClass}: it is invoked from
   * {@link #generateTestDataAndQueries()} so that the folder is guaranteed to exist before it is used.
   */
  public static void setupTempFolder() throws IOException {
    testTempFolder.create();
  }

  /**
   * Generate data for two tables (each table consists of several JSON files), the test queries over them and the
   * expected results of those queries.
   */
  @BeforeClass
  public static void generateTestDataAndQueries() throws Exception {
    // JUnit does not define the relative order of the @BeforeClass methods of a class, so create the temporary
    // folder here, right before its first use, instead of relying on a sibling @BeforeClass having run already.
    setupTempFolder();

    // Table 1 consists of five columns: "emp_id", "emp_name", "dept_id", "mng_id" and "some_id"
    empTableLocation = testTempFolder.newFolder().getAbsolutePath();

    // Write 100 records for each new file
    final int empNumRecsPerFile = 100;
    for (int fileIndex = 0; fileIndex < NUM_EMPLOYEES / empNumRecsPerFile; fileIndex++) {
      final File file = new File(empTableLocation + File.separator + fileIndex + ".json");
      final int firstRecord = fileIndex * empNumRecsPerFile;
      final int endRecord = firstRecord + empNumRecsPerFile;
      // try-with-resources closes the writer even if formatting or writing a record fails
      try (PrintWriter printWriter = new PrintWriter(file, "UTF-8")) {
        for (int recordIndex = firstRecord; recordIndex < endRecord; recordIndex++) {
          String record = String.format("{ \"emp_id\" : %d, \"emp_name\" : \"Employee %d\", \"dept_id\" : %d, \"mng_id\" : %d, \"some_id\" : %d }",
              recordIndex, recordIndex, recordIndex % NUM_DEPTS, recordIndex % NUM_MNGRS, recordIndex % NUM_IDS);
          printWriter.println(record);
        }
        failOnWriteError(printWriter, file);
      }
    }

    // Table 2 consists of two columns "dept_id" and "dept_name"
    deptTableLocation = testTempFolder.newFolder().getAbsolutePath();

    // Write 4 records for each new file
    final int deptNumRecsPerFile = 4;
    for (int fileIndex = 0; fileIndex < NUM_DEPTS / deptNumRecsPerFile; fileIndex++) {
      final File file = new File(deptTableLocation + File.separator + fileIndex + ".json");
      final int firstRecord = fileIndex * deptNumRecsPerFile;
      final int endRecord = firstRecord + deptNumRecsPerFile;
      try (PrintWriter printWriter = new PrintWriter(file, "UTF-8")) {
        for (int recordIndex = firstRecord; recordIndex < endRecord; recordIndex++) {
          String record = String.format("{ \"dept_id\" : %d, \"dept_name\" : \"Department %d\" }",
              recordIndex, recordIndex);
          printWriter.println(record);
        }
        failOnWriteError(printWriter, file);
      }
    }

    // Initialize test queries
    groupByQuery = String.format("SELECT dept_id, count(*) as numEmployees FROM dfs.`%s` GROUP BY dept_id", empTableLocation);
    joinQuery = String.format("SELECT e.emp_name, d.dept_name FROM dfs.`%s` e JOIN dfs.`%s` d ON e.dept_id = d.dept_id",
        empTableLocation, deptTableLocation);

    // Generate and store output data for test queries. Used when verifying the output of queries run using
    // different configurations.

    groupByQueryBaselineColumns = new String[] { "dept_id", "numEmployees" };

    groupByQueryBaselineValues = Lists.newArrayList();
    // dept_id is generated by the expression 'recordIndex % NUM_DEPTS' above. 'recordIndex' runs from 0 to
    // NUM_EMPLOYEES, so we expect the number of occurrences of each dept_id to be NUM_EMPLOYEES/NUM_DEPTS
    // (1000/40 = 25)
    final int numOccurrances = NUM_EMPLOYEES / NUM_DEPTS;
    for (int i = 0; i < NUM_DEPTS; i++) {
      groupByQueryBaselineValues.add(new Object[] { (long) i, (long) numOccurrances });
    }

    joinQueryBaselineColumns = new String[] { "emp_name", "dept_name" };

    joinQueryBaselineValues = Lists.newArrayList();
    for (int i = 0; i < NUM_EMPLOYEES; i++) {
      final String employee = String.format("Employee %d", i);
      final String dept = String.format("Department %d", i % NUM_DEPTS);
      joinQueryBaselineValues.add(new Object[] { employee, dept });
    }
  }

  /**
   * PrintWriter never throws on write failures; it only records them. Surface such a failure as an exception so
   * that a truncated data file fails the setup instead of producing confusing query results later.
   */
  private static void failOnWriteError(PrintWriter writer, File file) throws IOException {
    if (writer.checkError()) {
      throw new IOException("Failed to write test data to " + file);
    }
  }

  /**
   * Sets the session options shared by all tests of this class.
   *
   * @param isMuxOn value for `planner.enable_mux_exchange`
   * @param isDeMuxOn value for `planner.enable_demux_exchange`
   */
  public static void setupHelper(boolean isMuxOn, boolean isDeMuxOn) throws Exception {
    // set slice target to 1, so that we can have more parallelization for testing
    test("ALTER SESSION SET `planner.slice_target`=1");
    // disable the broadcast join to produce plans with HashToRandomExchanges.
    test("ALTER SESSION SET `planner.enable_broadcast_join`=false");
    test("ALTER SESSION SET `planner.enable_mux_exchange`=" + isMuxOn);
    test("ALTER SESSION SET `planner.enable_demux_exchange`=" + isDeMuxOn);
  }

  @Test
  public void testGroupByMultiFields() throws Exception {
    // Test multifield hash generation

    // Use the common setup so that this test sees the same session options as the other tests, regardless of
    // which test ran before it.
    setupHelper(true, false);

    final String groupByMultipleQuery = String.format("SELECT dept_id, mng_id, some_id, count(*) as numEmployees FROM dfs.`%s` e GROUP BY dept_id, mng_id, some_id", empTableLocation);
    final String[] groupByMultipleQueryBaselineColumns = new String[] { "dept_id", "mng_id", "some_id", "numEmployees" };

    final int numOccurrances = NUM_EMPLOYEES / NUM_DEPTS;

    final String plan = getPlanInString("EXPLAIN PLAN FOR " + groupByMultipleQuery, JSON_FORMAT);
    System.out.println("Plan: " + plan);

    jsonExchangeOrderChecker(plan, false, 1, "castint\\(hash64asdouble\\(.*, hash64asdouble\\(.*, hash64asdouble\\(.*\\) \\) \\) \\) ");

    // Run the query and verify the output
    final TestBuilder testBuilder = testBuilder()
        .sqlQuery(groupByMultipleQuery)
        .unOrdered()
        .baselineColumns(groupByMultipleQueryBaselineColumns);

    // mng_id and some_id are 'recordIndex % 1', i.e. always 0, so the groups are the same as for dept_id alone
    for (int i = 0; i < NUM_DEPTS; i++) {
      testBuilder.baselineValues(new Object[] { (long) i, (long) 0, (long) 0, (long) numOccurrances });
    }

    testBuilder.go();
  }

  @Test
  public void testGroupBy_NoMux_NoDeMux() throws Exception {
    testGroupByHelper(false, false);
  }

  @Test
  public void testJoin_NoMux_NoDeMux() throws Exception {
    testJoinHelper(false, false);
  }

  @Test
  public void testGroupBy_Mux_NoDeMux() throws Exception {
    testGroupByHelper(true, false);
  }

  @Test
  public void testJoin_Mux_NoDeMux() throws Exception {
    testJoinHelper(true, false);
  }

  @Test
  public void testGroupBy_NoMux_DeMux() throws Exception {
    testGroupByHelper(false, true);
  }

  @Test
  public void testJoin_NoMux_DeMux() throws Exception {
    testJoinHelper(false, true);
  }

  @Test
  public void testGroupBy_Mux_DeMux() throws Exception {
    testGroupByHelper(true, true);
  }

  @Test
  public void testJoin_Mux_DeMux() throws Exception {
    testJoinHelper(true, true);
  }

  /** Runs the group-by query; it is expected to contain one mux and/or one demux exchange when enabled. */
  private static void testGroupByHelper(boolean isMuxOn, boolean isDeMuxOn) throws Exception {
    testHelper(isMuxOn, isDeMuxOn, groupByQuery,
        isMuxOn ? 1 : 0, isDeMuxOn ? 1 : 0,
        groupByQueryBaselineColumns, groupByQueryBaselineValues);
  }

  /** Runs the join query; it is expected to contain two mux and/or two demux exchanges when enabled. */
  public static void testJoinHelper(boolean isMuxOn, boolean isDeMuxOn) throws Exception {
    testHelper(isMuxOn, isDeMuxOn, joinQuery,
        isMuxOn ? 2 : 0, isDeMuxOn ? 2 : 0,
        joinQueryBaselineColumns, joinQueryBaselineValues);
  }

  /**
   * Plans and runs the given query with the given mux/demux settings and verifies the plan shape, the query output
   * and the partition sender parallelization.
   *
   * @param isMuxOn whether mux exchanges are enabled
   * @param isDeMuxOn whether demux exchanges are enabled
   * @param query query to plan and run
   * @param expectedNumMuxes expected number of mux exchanges in the plan
   * @param expectedNumDeMuxes expected number of demux exchanges in the plan
   * @param baselineColumns column names of the expected output
   * @param baselineValues expected output records (compared unordered)
   */
  private static void testHelper(boolean isMuxOn, boolean isDeMuxOn, String query,
      int expectedNumMuxes, int expectedNumDeMuxes, String[] baselineColumns, List<Object[]> baselineValues)
      throws Exception {
    setupHelper(isMuxOn, isDeMuxOn);

    final String plan = getPlanInString("EXPLAIN PLAN FOR " + query, JSON_FORMAT);
    System.out.println("Plan: " + plan);

    final int hashExprReferences = StringUtils.countMatches(plan, HASH_EXPR_NAME);
    if (isMuxOn) {
      // Every mux stanza references the hash expression twice (the project that defines it and the
      // HashToRandomExchange that consumes it) and every demux exchange references it once more.
      assertEquals("Wrong number of HashExpr references are present in the plan",
          2 * expectedNumMuxes + expectedNumDeMuxes, hashExprReferences);
      jsonExchangeOrderChecker(plan, isDeMuxOn, expectedNumMuxes, "castint\\(hash64asdouble\\(.*\\) \\) ");
    } else {
      assertEquals("HashExpr on the hash column should not happen", 0, hashExprReferences);
    }

    // Make sure the plan has mux and demux exchanges (TODO: currently testing is rudimentary,
    // need to move it to sophisticated testing once better planning test tools are available)
    assertEquals("Wrong number of MuxExchanges are present in the plan",
        expectedNumMuxes, StringUtils.countMatches(plan, MUX_EXCHANGE));

    assertEquals("Wrong number of DeMuxExchanges are present in the plan",
        expectedNumDeMuxes, StringUtils.countMatches(plan, DEMUX_EXCHANGE));

    // Run the query and verify the output
    final TestBuilder testBuilder = testBuilder()
        .sqlQuery(query)
        .unOrdered()
        .baselineColumns(baselineColumns);

    for (Object[] baselineRecord : baselineValues) {
      testBuilder.baselineValues(baselineRecord);
    }

    testBuilder.go();

    testHelperVerifyPartitionSenderParallelization(plan, isMuxOn, isDeMuxOn);
  }

  /**
   * Walks the "graph" array of a JSON plan and verifies every mux stanza, i.e. the sequence of consecutive nodes
   * <pre>
   *   project (defines the hash expression) -> unordered-mux-exchange -> hash-to-random-exchange
   *     [-> unordered-demux-exchange] -> project (one field less than the opening project)
   * </pre>
   *
   * @param plan JSON plan text
   * @param isDemuxEnabled whether a demux exchange is expected after each hash-to-random-exchange
   * @param expectedNumMuxes expected number of projects defining the hash expression
   * @param hashExprPattern regular expression the whole hash expression must match
   */
  private static void jsonExchangeOrderChecker(String plan, boolean isDemuxEnabled, int expectedNumMuxes, String hashExprPattern) throws Exception {
    final JSONObject planObj = (JSONObject) new JSONParser().parse(plan);
    assertNotNull("Corrupted query plan: null", planObj);
    final JSONArray graphArray = (JSONArray) planObj.get("graph");
    assertNotNull("No graph array present", graphArray);

    final String hashExprRef = "`" + HASH_EXPR_NAME + "`";
    // Position of the closing project relative to the opening project of a stanza
    final int closingProjectOffset = isDemuxEnabled ? 4 : 3;

    boolean inStanza = false;
    // Number of nodes seen since the opening project of the current stanza (0 for the opening project itself)
    int stanzaOffset = 0;
    int openingProjectExprCount = 0;
    int muxesCount = 0;

    for (Object object : graphArray) {
      final JSONObject popObj = (JSONObject) object;
      final Object popName = popObj.get("pop");

      if ("project".equals(popName) && popObj.containsKey("exprs")) {
        final JSONArray exprsArray = (JSONArray) popObj.get("exprs");
        for (Object exprObj : exprsArray) {
          final JSONObject expr = (JSONObject) exprObj;
          if (hashExprRef.equals(expr.get("ref"))) {
            // This project defines the hash expression, so it opens a new stanza
            final String hashField = (String) expr.get("expr");
            assertNotNull("HashExpr field can not be null", hashField);
            assertTrue("HashExpr field does not match pattern", hashField.matches(hashExprPattern));
            inStanza = true;
            stanzaOffset = 0;
            // Remember the size of the opening project only. Recording it for every project (as was done
            // before) overwrote it with the closing project's own size and made the check below vacuous.
            openingProjectExprCount = exprsArray.size();
            muxesCount++;
            break;
          }
        }
      }

      if (!inStanza) {
        continue;
      }

      if (stanzaOffset == 1) {
        // next after project with hashexpr
        assertTrue("UnorderedMux should follow Project with HashExpr",
            MUX_EXCHANGE_CONST.equals(popName));
      } else if (stanzaOffset == 2) {
        assertTrue("HashToRandomExchange should follow UnorderedMux which should follow Project with HashExpr",
            HASH_EXCHANGE.equals(popName));
        // HashToRandomExchange must distribute on the precomputed hash expression
        assertTrue("HashToRandomExchnage should use hashExpr",
            hashExprRef.equals(popObj.get("expr")));
      } else if (isDemuxEnabled && stanzaOffset == 3) {
        // if Demux is enabled it also should use HashExpr
        assertTrue("UnorderdDemuxExchange should follow HashToRandomExchange",
            DEMUX_EXCHANGE_CONST.equals(popName));
        assertTrue("UnorderdDemuxExchange should use hashExpr",
            hashExprRef.equals(popObj.get("expr")));
      }

      if (stanzaOffset == closingProjectOffset) {
        // it should be a project that drops the hash expression again, i.e. has one field less than the
        // opening project
        assertTrue("Should be project without hashexpr", "project".equals(popName));
        final JSONArray exprsArray = (JSONArray) popObj.get("exprs");
        assertNotNull("Project should have some fields", exprsArray);
        assertEquals("Number of fields in closing project should be one less then in starting project",
            openingProjectExprCount - 1, exprsArray.size());

        // The stanza is complete; get ready for another batch of those exchanges
        inStanza = false;
        openingProjectExprCount = 0;
      }
      stanzaOffset++;
    }
    assertEquals("Number of Project/Mux/HashExchange/... ", expectedNumMuxes, muxesCount);
  }

  // Verify the number of partition senders in a major fragment is not more than the cluster size and each endpoint
  // in the cluster has at most one fragment from a given major fragment that has the partition sender.
  private static void testHelperVerifyPartitionSenderParallelization(
      String plan, boolean isMuxOn, boolean isDeMuxOn) throws Exception {

    final DrillbitContext drillbitContext = getDrillbitContext();
    final PhysicalPlanReader planReader = drillbitContext.getPlanReader();
    final Fragment rootFragment = PopUnitTestBase.getRootFragmentFromPlanString(planReader, plan);

    final List<Integer> deMuxFragments = Lists.newLinkedList();
    final List<Integer> htrFragments = Lists.newLinkedList();
    final PlanningSet planningSet = new PlanningSet();

    // Create a planningSet to get the assignment of major fragment ids to fragments.
    PARALLELIZER.initFragmentWrappers(rootFragment, planningSet);

    findFragmentsWithPartitionSender(rootFragment, planningSet, deMuxFragments, htrFragments);

    final QueryContextInformation queryContextInfo = Utilities.createQueryContextInfo("dummySchemaName");
    final QueryWorkUnit qwu = PARALLELIZER.getFragments(new OptionList(), drillbitContext.getEndpoint(),
        QueryId.getDefaultInstance(),
        drillbitContext.getBits(), planReader, rootFragment, USER_SESSION, queryContextInfo);

    // Collect, per major fragment, the endpoints of all minor fragments that contain a hash partition sender
    final ArrayListMultimap<Integer, DrillbitEndpoint> partitionSenderMap = ArrayListMultimap.create();
    for (PlanFragment planFragment : qwu.getFragments()) {
      if (planFragment.getFragmentJson().contains("hash-partition-sender")) {
        partitionSenderMap.put(planFragment.getHandle().getMajorFragmentId(), planFragment.getAssignment());
      }
    }

    if (isMuxOn) {
      verifyAssignment(htrFragments, partitionSenderMap);
    }

    if (isDeMuxOn) {
      verifyAssignment(deMuxFragments, partitionSenderMap);
    }
  }

  /**
   * Helper method to find the major fragment ids of fragments that have PartitionSender.
   * A fragment can have PartitionSender if sending exchange of the current fragment is a
   *   1. DeMux Exchange -> goes in deMuxFragments
   *   2. HashToRandomExchange -> goes into htrFragments
   * The fragment tree is traversed recursively through the receiving exchanges of each fragment.
   */
  private static void findFragmentsWithPartitionSender(Fragment currentRootFragment, PlanningSet planningSet,
      List<Integer> deMuxFragments, List<Integer> htrFragments) {

    if (currentRootFragment == null) {
      return;
    }

    final Exchange sendingExchange = currentRootFragment.getSendingExchange();
    if (sendingExchange != null) {
      final int majorFragmentId = planningSet.get(currentRootFragment).getMajorFragmentId();
      if (sendingExchange instanceof UnorderedDeMuxExchange) {
        deMuxFragments.add(majorFragmentId);
      } else if (sendingExchange instanceof HashToRandomExchange) {
        htrFragments.add(majorFragmentId);
      }
    }

    for (ExchangeFragmentPair e : currentRootFragment.getReceivingExchangePairs()) {
      findFragmentsWithPartitionSender(e.getNode(), planningSet, deMuxFragments, htrFragments);
    }
  }

  /** Helper method to verify the number of PartitionSenders in a given fragment endpoint assignments */
  private static void verifyAssignment(List<Integer> fragmentList,
      ArrayListMultimap<Integer, DrillbitEndpoint> partitionSenderMap) {

    // We expect at least one entry in the list
    assertTrue("Expected at least one major fragment with a partition sender", !fragmentList.isEmpty());

    for (Integer majorFragmentId : fragmentList) {
      // we expect the fragment that has DeMux/HashToRandom as sending exchange to have parallelization with not more
      // than the number of nodes in the cluster and each node in the cluster can have at most one assignment
      final List<DrillbitEndpoint> assignments = partitionSenderMap.get(majorFragmentId);
      assertNotNull(assignments);
      assertTrue(String.format("No partition senders found in major fragment [%d]", majorFragmentId),
          !assignments.isEmpty());
      assertTrue(String.format("Number of partition senders in major fragment [%d] is more than expected", majorFragmentId), CLUSTER_SIZE >= assignments.size());

      // Make sure there are no duplicates in assigned endpoints (i.e at most one partition sender per endpoint)
      assertEquals("Some endpoints have more than one fragment that has ParitionSender",
          assignments.size(), ImmutableSet.copyOf(assignments).size());
    }
  }

  /** Deletes the temporary root folder together with the generated tables. */
  @AfterClass
  public static void cleanupTempFolder() throws IOException {
    testTempFolder.delete();
  }
}