/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.mapreduce.combine;

import java.io.IOException;
import java.util.Iterator;

import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapreduce.TaskCounter;
import org.apache.hadoop.util.Progress;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.mapreduce.hadoop.MRJobConfig;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.TaskContext;
import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
import org.apache.tez.runtime.library.common.sort.impl.IFile.Writer;
import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator;
import org.junit.Test;
import org.mockito.Mockito;

import static org.junit.Assert.assertEquals;

public class TestMRCombiner {

  @Test
  public void testRunOldCombiner() throws IOException, InterruptedException {
    TezConfiguration conf = new TezConfiguration();
    setKeyAndValueClassTypes(conf);
    conf.setClass("mapred.combiner.class", OldReducer.class, Object.class);
    TaskContext taskContext = getTaskContext(conf);
    MRCombiner combiner = new MRCombiner(taskContext);
    Writer writer = Mockito.mock(Writer.class);
    combiner.combine(new TezRawKeyValueIteratorTest(), writer);
    long inputRecords = taskContext.getCounters().findCounter(TaskCounter.COMBINE_INPUT_RECORDS).getValue();
    long outputRecords = taskContext.getCounters().findCounter(TaskCounter.COMBINE_OUTPUT_RECORDS).getValue();
    assertEquals(6, inputRecords);
    assertEquals(3, outputRecords);
    // verify combiner output keys and values
    verifyKeyAndValues(writer);
  }

  @Test
  public void testRunNewCombiner() throws IOException, InterruptedException {
    TezConfiguration conf = new TezConfiguration();
    setKeyAndValueClassTypes(conf);
    conf.setBoolean("mapred.mapper.new-api", true);
    conf.setClass(MRJobConfig.COMBINE_CLASS_ATTR, NewReducer.class,
        Object.class);
    TaskContext taskContext = getTaskContext(conf);
    MRCombiner combiner = new MRCombiner(taskContext);
    Writer writer = Mockito.mock(Writer.class);
    combiner.combine(new TezRawKeyValueIteratorTest(), writer);
    long inputRecords = taskContext.getCounters().findCounter(TaskCounter.COMBINE_INPUT_RECORDS).getValue();
    long outputRecords = taskContext.getCounters().findCounter(TaskCounter.COMBINE_OUTPUT_RECORDS).getValue();
    assertEquals(6, inputRecords);
    assertEquals(3, outputRecords);
    // verify combiner output keys and values
    verifyKeyAndValues(writer);
  }

  @Test
  public void testTop2RunOldCombiner() throws IOException, InterruptedException {
    TezConfiguration conf = new TezConfiguration();
    setKeyAndValueClassTypes(conf);
    conf.setClass("mapred.combiner.class", Top2OldReducer.class, Object.class);
    TaskContext taskContext = getTaskContext(conf);
    MRCombiner combiner = new MRCombiner(taskContext);
    Writer writer = Mockito.mock(Writer.class);
    combiner.combine(new TezRawKeyValueIteratorTest(), writer);
    long inputRecords = taskContext.getCounters().findCounter(TaskCounter.COMBINE_INPUT_RECORDS).getValue();
    long outputRecords = taskContext.getCounters().findCounter(TaskCounter.COMBINE_OUTPUT_RECORDS).getValue();
    assertEquals(6, inputRecords);
    assertEquals(5, outputRecords);
  }

  @Test
  public void testTop2RunNewCombiner() throws IOException, InterruptedException {
    TezConfiguration conf = new TezConfiguration();
    setKeyAndValueClassTypes(conf);
    conf.setBoolean("mapred.mapper.new-api", true);
    conf.setClass(MRJobConfig.COMBINE_CLASS_ATTR, Top2NewReducer.class,
        Object.class);
    TaskContext taskContext = getTaskContext(conf);
    MRCombiner combiner = new MRCombiner(taskContext);
    Writer writer = Mockito.mock(Writer.class);
    combiner.combine(new TezRawKeyValueIteratorTest(), writer);
    long inputRecords = taskContext.getCounters().findCounter(TaskCounter.COMBINE_INPUT_RECORDS).getValue();
    long outputRecords = taskContext.getCounters().findCounter(TaskCounter.COMBINE_OUTPUT_RECORDS).getValue();
    assertEquals(6, inputRecords);
    assertEquals(5, outputRecords);
  }

  private void setKeyAndValueClassTypes(TezConfiguration conf) {
    conf.setClass(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS,
        Text.class, Object.class);
    conf.setClass(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS,
        IntWritable.class, Object.class);
  }

  private TaskContext getTaskContext(TezConfiguration conf)
      throws IOException {
    UserPayload payload = TezUtils.createUserPayloadFromConf(conf);
    TaskContext taskContext = Mockito.mock(InputContext.class);
    Mockito.when(taskContext.getUserPayload()).thenReturn(payload);
    Mockito.when(taskContext.getCounters()).thenReturn(new TezCounters());
    Mockito.when(taskContext.getApplicationId()).thenReturn(
        ApplicationId.newInstance(123456, 1));
    return taskContext;
  }

  private void verifyKeyAndValues(Writer writer) throws IOException {
    Mockito.verify(writer, Mockito.atLeastOnce()).append(new Text("tez"),
        new IntWritable(3));
    Mockito.verify(writer, Mockito.atLeastOnce()).append(new Text("apache"),
        new IntWritable(1));
    Mockito.verify(writer, Mockito.atLeastOnce()).append(new Text("hadoop"),
        new IntWritable(2));
  }

  private static class TezRawKeyValueIteratorTest implements
      TezRawKeyValueIterator {

    private int i = -1;
    private String[] keys = { "tez", "tez", "tez", "apache", "hadoop", "hadoop" };

    @Override
    public boolean next() throws IOException {
      boolean hasNext = hasNext();
      if (hasNext) {
        i += 1;
      }

      return hasNext;
    }

    public boolean hasNext() throws IOException {
      if (i < (keys.length -  1)) {
        return true;
      }
      return false;
    }

    @Override
    public DataInputBuffer getValue() throws IOException {
      DataInputBuffer value = new DataInputBuffer();
      IntWritable intValue = new IntWritable(1);
      DataOutputBuffer out = new DataOutputBuffer();
      intValue.write(out);
      value.reset(out.getData(), out.getLength());
      return value;
    }

    @Override
    public Progress getProgress() {
      return null;
    }

    @Override
    public boolean isSameKey() throws IOException {
      return false;
    }

    @Override
    public DataInputBuffer getKey() throws IOException {
      DataInputBuffer key = new DataInputBuffer();
      Text text = new Text(keys[i]);
      DataOutputBuffer out = new DataOutputBuffer();
      text.write(out);
      key.reset(out.getData(), out.getLength());
      return key;
    }

    @Override
    public void close() throws IOException {
    }
  }

  private static class OldReducer implements
      Reducer<Text, IntWritable, Text, IntWritable> {

    @Override
    public void configure(JobConf arg0) {
    }

    @Override
    public void close() throws IOException {
    }

    @Override
    public void reduce(Text key, Iterator<IntWritable> value,
        OutputCollector<Text, IntWritable> collector, Reporter reporter)
        throws IOException {
      int count = 0;
      while (value.hasNext()) {
        count += value.next().get();
      }
      collector.collect(new Text(key.toString()), new IntWritable(count));
    }
  }

  private static class NewReducer extends
      org.apache.hadoop.mapreduce.Reducer<Text, IntWritable, Text, IntWritable> {
    @Override
    protected void reduce(Text key, Iterable<IntWritable> values,
        Context context) throws IOException, InterruptedException {
      int count = 0;
      for (IntWritable value : values) {
        count += value.get();
      }
      context.write(new Text(key.toString()), new IntWritable(count));
    }
  }

  private static class Top2OldReducer extends OldReducer {
    @Override
    public void reduce(Text key, Iterator<IntWritable> value,
        OutputCollector<Text, IntWritable> collector, Reporter reporter)
        throws IOException {
      int i = 0;
      while (value.hasNext()) {
        int val = value.next().get();
        if (i++ < 2) {
          collector.collect(new Text(key.toString()), new IntWritable(val));
        }
      }
    }
  }

  private static class Top2NewReducer extends NewReducer {
    @Override
    protected void reduce(Text key, Iterable<IntWritable> values,
        Context context) throws IOException, InterruptedException {
      int i = 0;
      for (IntWritable value : values) {
        if (i++ < 2) {
          context.write(new Text(key.toString()), value);
        } else {
          break;
        }
      }
    }
  }
}