/*
 * Copyright (c) 2015-2019, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.labs.envelope.derive;

import com.cloudera.labs.envelope.spark.Contexts;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.junit.Test;

import java.util.Map;

import static com.cloudera.labs.envelope.validate.ValidationAssert.assertNoValidationFailures;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

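/**
 * Tests for {@link HashDeriver}, which appends a hash of the field values of
 * each row of a dependency as an additional field.
 */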
public class TestHashDeriver {

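  // An empty configuration should derive the single dependency with a hash of
  // each row appended under the default hash field name.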
  @Test
  public void testDefaultHash() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    Dataset<Row> derived = d.derive(dependencies);

    assertEquals(1, derived.count());
    assertEquals(testDataFrame().schema().size() + 1, derived.schema().size());
    assertTrue(Lists.newArrayList(derived.schema().fieldNames()).contains(HashDeriver.DEFAULT_HASH_FIELD_NAME));
    assertEquals(
        "4891a9d87f8f46a5c8c19c3059864146", // all hashes in this class generated by 'md5' CLI command
        derived.collectAsList().get(0).get(derived.schema().size() - 1));
  }

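  // Configuring a step name that is not among the supplied dependencies should
  // lead to a runtime failure when deriving.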
  @Test(expected = RuntimeException.class)
  public void testMissingDependency() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    configMap.put(HashDeriver.STEP_NAME_CONFIG, "dep2");
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    d.derive(dependencies);
  }

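  // Explicitly naming the only dependency should give the same digest as the
  // default configuration.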
  @Test
  public void testSpecifiedDependency() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    configMap.put(HashDeriver.STEP_NAME_CONFIG, "dep1");
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    Dataset<Row> derived = d.derive(dependencies);

    assertEquals(1, derived.count());
    assertEquals(testDataFrame().schema().size() + 1, derived.schema().size());
    assertTrue(Lists.newArrayList(derived.schema().fieldNames()).contains(HashDeriver.DEFAULT_HASH_FIELD_NAME));
    assertEquals(
        "4891a9d87f8f46a5c8c19c3059864146",
        derived.collectAsList().get(0).get(derived.schema().size() - 1));
  }

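  // With more than one dependency there is no default to fall back on, so a
  // configuration without a step name should fail at derive time.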
  @Test(expected = RuntimeException.class)
  public void testCantUseDefaultDependency() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());
    dependencies.put("dep2", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    d.derive(dependencies);
  }

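  // A custom delimiter between the concatenated field values should produce a
  // different digest than the default configuration.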
  @Test
  public void testCustomDelimiterHash() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    configMap.put(HashDeriver.DELIMITER_CONFIG, ":::");
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    Dataset<Row> derived = d.derive(dependencies);

    assertEquals(1, derived.count());
    assertEquals(testDataFrame().schema().size() + 1, derived.schema().size());
    assertTrue(Lists.newArrayList(derived.schema().fieldNames()).contains(HashDeriver.DEFAULT_HASH_FIELD_NAME));
    assertEquals(
        "d85bcfeacb088fcc7e8ded019ed48ec2",
        derived.collectAsList().get(0).get(derived.schema().size() - 1));
  }

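  // A custom string representation for null field values (here the empty
  // string) should produce a different digest than the default configuration.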
  @Test
  public void testCustomNullStringHash() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    configMap.put(HashDeriver.NULL_STRING_CONFIG, "");
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    Dataset<Row> derived = d.derive(dependencies);

    assertEquals(1, derived.count());
    assertEquals(testDataFrame().schema().size() + 1, derived.schema().size());
    assertTrue(Lists.newArrayList(derived.schema().fieldNames()).contains(HashDeriver.DEFAULT_HASH_FIELD_NAME));
    assertEquals(
        "862ff0dc2acce97b6f8bd6c369df2668",
        derived.collectAsList().get(0).get(derived.schema().size() - 1));
  }

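  // Only the fields in the include list (c1 and c2) should contribute to the
  // digest.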
  @Test
  public void testIncludeFieldsHash() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    configMap.put(HashDeriver.INCLUDE_FIELDS_CONFIG, Lists.newArrayList("c1", "c2"));
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    Dataset<Row> derived = d.derive(dependencies);

    assertEquals(1, derived.count());
    assertEquals(testDataFrame().schema().size() + 1, derived.schema().size());
    assertTrue(Lists.newArrayList(derived.schema().fieldNames()).contains(HashDeriver.DEFAULT_HASH_FIELD_NAME));
    assertEquals(
        "203ad5ffa1d7c650ad681fdff3965cd2",
        derived.collectAsList().get(0).get(derived.schema().size() - 1));
  }

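  // Fields in the exclude list (c1 and c2) should not contribute to the
  // digest.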
  @Test
  public void testExcludeFieldsHash() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    configMap.put(HashDeriver.EXCLUDE_FIELDS_CONFIG, Lists.newArrayList("c1", "c2"));
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    Dataset<Row> derived = d.derive(dependencies);

    assertEquals(1, derived.count());
    assertEquals(testDataFrame().schema().size() + 1, derived.schema().size());
    assertTrue(Lists.newArrayList(derived.schema().fieldNames()).contains(HashDeriver.DEFAULT_HASH_FIELD_NAME));
    assertEquals(
        "71fb4bf2f54627f64c60ca5e396d1ccc",
        derived.collectAsList().get(0).get(derived.schema().size() - 1));
  }
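
  // Sketch: combines two of the options exercised above, an include list with a
  // custom delimiter. Only structural invariants are asserted here, since the
  // reference digests in this class were generated externally with the 'md5'
  // CLI command.
  @Test
  public void testIncludeFieldsWithCustomDelimiterHash() {
    Map<String, Dataset<Row>> dependencies = Maps.newHashMap();
    dependencies.put("dep1", testDataFrame());

    Map<String, Object> configMap = Maps.newHashMap();
    configMap.put(HashDeriver.INCLUDE_FIELDS_CONFIG, Lists.newArrayList("c1", "c2"));
    configMap.put(HashDeriver.DELIMITER_CONFIG, ":::");
    Config config = ConfigFactory.parseMap(configMap);

    HashDeriver d = new HashDeriver();
    assertNoValidationFailures(d, config);
    d.configure(config);

    Dataset<Row> derived = d.derive(dependencies);

    assertEquals(1, derived.count());
    assertEquals(testDataFrame().schema().size() + 1, derived.schema().size());
    assertTrue(Lists.newArrayList(derived.schema().fieldNames()).contains(HashDeriver.DEFAULT_HASH_FIELD_NAME));
  }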

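  // Single-row DataFrame with a null binary field so that null handling is
  // exercised alongside string, integer, float, and boolean values.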
  private Dataset<Row> testDataFrame() {
    return Contexts.getSparkSession().createDataFrame(
        Lists.newArrayList(RowFactory.create("hello", 1, null, -1.0f, true)),
        DataTypes.createStructType(Lists.newArrayList(
            DataTypes.createStructField("c1", DataTypes.StringType, true),
            DataTypes.createStructField("c2", DataTypes.IntegerType, true),
            DataTypes.createStructField("c3", DataTypes.BinaryType, true),
            DataTypes.createStructField("c4", DataTypes.FloatType, true),
            DataTypes.createStructField("c5", DataTypes.BooleanType, true)
        ))
    );
  }

}