/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce;

import junit.framework.TestCase;

import java.io.IOException;
import java.io.DataInput;
import java.io.DataOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.lib.output.NullOutputFormat;
import org.apache.hadoop.mapreduce.MRConfig;
import org.apache.hadoop.util.ReflectionUtils;

public class TestMapCollection {

  private static final Log LOG = LogFactory.getLog(
      TestMapCollection.class.getName());

  /**
   * Writable whose serialized form is a run of a single fill byte, used to
   * produce map records of an exact length. Subclasses set
   * {@code disableRead} from the configuration; when it is set, no vint
   * length prefix is written or read, so the serialized form is raw fill
   * bytes only (and {@link #readFields} is a no-op).
   */
  public static abstract class FillWritable implements Writable, Configurable {
    private int len;             // number of fill bytes this record represents
    protected boolean disableRead;
    private byte[] b;            // scratch buffer for chunked writes, lazily grown
    private final Random r;
    protected final byte fillChar;
    public FillWritable(byte fillChar) {
      this.fillChar = fillChar;
      r = new Random();
      final long seed = r.nextLong();
      LOG.info("seed: " + seed);
      r.setSeed(seed);
    }
    @Override
    public Configuration getConf() {
      return null;
    }
    /** Set the number of fill bytes this record should serialize. */
    public void setLength(int len) {
      this.len = len;
    }
    /** Records are ordered by length only; content is always the fill byte. */
    public int compareTo(FillWritable o) {
      if (o == this) return 0;
      // Integer.compare instead of the subtraction idiom; lengths are
      // non-negative so the sign is identical, but this form cannot overflow.
      return Integer.compare(len, o.len);
    }
    @Override
    public int hashCode() {
      return 37 * len;
    }
    @Override
    public boolean equals(Object o) {
      if (!(o instanceof FillWritable)) return false;
      return 0 == compareTo((FillWritable)o);
    }
    @Override
    public void readFields(DataInput in) throws IOException {
      if (disableRead) {
        return;
      }
      len = WritableUtils.readVInt(in);
      // Every payload byte must be the fill character.
      for (int i = 0; i < len; ++i) {
        assertEquals("Invalid byte at " + i, fillChar, in.readByte());
      }
    }
    @Override
    public void write(DataOutput out) throws IOException {
      if (0 == len) {
        return;
      }
      // NOTE(review): when the vint prefix is written, 'written' starts
      // negative by the prefix size, so both branches below appear to emit
      // len + vintsize fill bytes -- more than readFields consumes.
      // Presumably the collection/IFile buffers delimit records by recorded
      // length so the overrun is tolerated; confirm before changing.
      int written = 0;
      if (!disableRead) {
        WritableUtils.writeVInt(out, len);
        written -= WritableUtils.getVIntSize(len);
      }
      if (len > 1024) {
        if (null == b || b.length < len) {
          b = new byte[2 * len];
        }
        Arrays.fill(b, fillChar);
        // Emit the fill in random-sized chunks to exercise partial writes.
        do {
          final int write = Math.min(len - written, r.nextInt(len));
          out.write(b, 0, write);
          written += write;
        } while (written < len);
        assertEquals(len, written);
      } else {
        for (int i = written; i < len; ++i) {
          out.write(fillChar);
        }
      }
    }
  }

  /** Key type: a run of {@code 'K'} fill bytes of configurable length. */
  public static class KeyWritable
    extends FillWritable implements WritableComparable<FillWritable> {

    static final byte keyFill = (byte)('K' & 0xFF);

    public KeyWritable() {
      super(keyFill);
    }

    @Override
    public void setConf(Configuration conf) {
      // Zero-length keys carry no serialized form, so reads are disabled.
      final boolean skipRead = conf.getBoolean("test.disable.key.read", false);
      disableRead = skipRead;
    }
  }

  /** Value type: a run of {@code 'V'} fill bytes of configurable length. */
  public static class ValWritable extends FillWritable {

    public ValWritable() {
      super((byte)('V' & 0xFF));
    }

    @Override
    public void setConf(Configuration conf) {
      // Mirrors KeyWritable, keyed off the value-side switch.
      final boolean skipRead = conf.getBoolean("test.disable.val.read", false);
      disableRead = skipRead;
    }
  }

  /**
   * RawComparator ordering keys by serialized length; as a side effect it
   * asserts that every payload byte of both serialized keys is the key
   * fill character.
   */
  public static class VariableComparator
      implements RawComparator<KeyWritable>, Configurable {
    private boolean readLen;
    public VariableComparator() { }
    @Override
    public void setConf(Configuration conf) {
      // When key reads are disabled the serialized key has no vint length
      // prefix, so no prefix bytes should be skipped when scanning.
      readLen = !conf.getBoolean("test.disable.key.read", false);
    }
    @Override
    public Configuration getConf() { return null; }
    @Override
    public int compare(KeyWritable k1, KeyWritable k2) {
      return k1.compareTo(k2);
    }
    @Override
    public int compare(byte[] b1, int s1, int l1,
                       byte[] b2, int s2, int l2) {
      // Size of the vint length prefix, if present.
      final int n1;
      final int n2;
      if (readLen) {
        n1 = WritableUtils.decodeVIntSize(b1[s1]);
        n2 = WritableUtils.decodeVIntSize(b2[s2]);
      } else {
        n1 = 0;
        n2 = 0;
      }
      // The serialized key occupies [s, s + l); verify every payload byte.
      // The previous upper bound (l - n) ignored the start offset, so the
      // check was silently skipped whenever the key began past offset 0.
      for (int i = s1 + n1; i < s1 + l1; ++i) {
        assertEquals("Invalid key at " + s1, (int)KeyWritable.keyFill, b1[i]);
      }
      for (int i = s2 + n2; i < s2 + l2; ++i) {
        assertEquals("Invalid key at " + s2, (int)KeyWritable.keyFill, b2[i]);
      }
      // Order by serialized length only.
      return l1 - l2;
    }
  }

  /**
   * Reducer that tallies every incoming record and asserts in cleanup()
   * that the total matches the configured record count.
   */
  public static class SpillReducer
      extends Reducer<KeyWritable,ValWritable,NullWritable,NullWritable> {

    private int numrecs;   // records seen across all reduce() invocations
    private int expected;  // configured total to verify in cleanup()

    @Override
    protected void setup(Context job) {
      numrecs = 0;
      expected = job.getConfiguration().getInt("test.spillmap.records", 100);
    }

    @Override
    protected void reduce(KeyWritable k, Iterable<ValWritable> values,
        Context context) throws IOException, InterruptedException {
      // Drain the value iterator; only the count matters.
      for (ValWritable ignored : values) {
        ++numrecs;
      }
    }

    @Override
    protected void cleanup(Context context)
        throws IOException, InterruptedException {
      assertEquals("Unexpected record count", expected, numrecs);
    }
  }

  /** Zero-length, location-free split; exists only to drive map tasks. */
  public static class FakeSplit extends InputSplit implements Writable {
    @Override
    public long getLength() { return 0L; }
    @Override
    public String[] getLocations() { return new String[0]; }
    @Override
    public void write(DataOutput out) throws IOException {
      // nothing to serialize
    }
    @Override
    public void readFields(DataInput in) throws IOException {
      // nothing to deserialize
    }
  }

  /** Strategy for choosing per-record key and value lengths. */
  public abstract static class RecordFactory implements Configurable {
    public Configuration getConf() { return null; }
    /** Length in bytes of the key for record {@code i}. */
    public abstract int keyLen(int i);
    /** Length in bytes of the value for record {@code i}. */
    public abstract int valLen(int i);
  }

  /** RecordFactory producing the same key/value lengths for every record. */
  public static class FixedRecordFactory extends RecordFactory {
    private int keylen;
    private int vallen;

    public FixedRecordFactory() { }

    /** Record the fixed lengths in {@code conf} for setConf to pick up. */
    public static void setLengths(Configuration conf, int keylen, int vallen) {
      conf.setInt("test.fixedrecord.keylen", keylen);
      conf.setInt("test.fixedrecord.vallen", vallen);
      // A zero-length side serializes nothing, so its reads must be disabled.
      conf.setBoolean("test.disable.key.read", keylen == 0);
      conf.setBoolean("test.disable.val.read", vallen == 0);
    }

    public void setConf(Configuration conf) {
      keylen = conf.getInt("test.fixedrecord.keylen", 0);
      vallen = conf.getInt("test.fixedrecord.vallen", 0);
    }

    public int keyLen(int i) { return keylen; }
    public int valLen(int i) { return vallen; }
  }

  /**
   * InputFormat producing synthetic key/value records whose lengths come
   * from the configured {@link RecordFactory}; emits
   * {@code test.spillmap.records} records per split.
   */
  public static class FakeIF extends InputFormat<KeyWritable,ValWritable> {

    public FakeIF() { }

    @Override
    public List<InputSplit> getSplits(JobContext ctxt) throws IOException {
      final int numSplits = ctxt.getConfiguration().getInt(
          "test.mapcollection.num.maps", -1);
      // Fail with a clear message rather than the bare
      // IllegalArgumentException new ArrayList(-1) would throw when the
      // job is misconfigured.
      assertTrue("test.mapcollection.num.maps not set", numSplits >= 0);
      List<InputSplit> splits = new ArrayList<InputSplit>(numSplits);
      for (int i = 0; i < numSplits; ++i) {
        splits.add(i, new FakeSplit());
      }
      return splits;
    }

    @Override
    public RecordReader<KeyWritable,ValWritable> createRecordReader(
        InputSplit ignored, final TaskAttemptContext taskContext) {
      return new RecordReader<KeyWritable,ValWritable>() {
        private RecordFactory factory;
        private final KeyWritable key = new KeyWritable();
        private final ValWritable val = new ValWritable();
        private int current;   // records handed out so far
        private int records;   // total records to emit
        @Override
        public void initialize(InputSplit split, TaskAttemptContext context) {
          final Configuration conf = context.getConfiguration();
          key.setConf(conf);
          val.setConf(conf);
          factory = ReflectionUtils.newInstance(
              conf.getClass("test.mapcollection.class",
                FixedRecordFactory.class, RecordFactory.class), conf);
          assertNotNull(factory);
          current = 0;
          records = conf.getInt("test.spillmap.records", 100);
        }
        @Override
        public boolean nextKeyValue() {
          key.setLength(factory.keyLen(current));
          val.setLength(factory.valLen(current));
          return current++ < records;
        }
        @Override
        public KeyWritable getCurrentKey() { return key; }
        @Override
        public ValWritable getCurrentValue() { return val; }
        @Override
        public float getProgress() { return (float) current / records; }
        @Override
        public void close() {
          // After the terminating false from nextKeyValue, current is
          // records + 1; anything else means the framework under-iterated.
          assertEquals("Unexpected count", records, current - 1);
        }
      };
    }
  }

  /**
   * Build and run a fixed-record-length collection job.
   *
   * @param name     label used in logging
   * @param keylen   fixed key length in bytes
   * @param vallen   fixed value length in bytes
   * @param records  number of records the single map emits
   * @param ioSortMB size of the map-side sort buffer in MB
   * @param spillPer spill threshold as a fraction of the sort buffer
   */
  private static void runTest(String name, int keylen, int vallen,
      int records, int ioSortMB, float spillPer)
      throws Exception {
    final Configuration initial = new Configuration();
    initial.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100);
    final Job job = Job.getInstance(initial);
    final Configuration conf = job.getConfiguration();
    conf.setInt(MRJobConfig.IO_SORT_MB, ioSortMB);
    conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(spillPer));
    conf.setClass("test.mapcollection.class", FixedRecordFactory.class,
        RecordFactory.class);
    FixedRecordFactory.setLengths(conf, keylen, vallen);
    conf.setInt("test.spillmap.records", records);
    runTest(name, job);
  }

  /** Wire the synthetic formats into the job, run it locally, and assert success. */
  private static void runTest(String name, Job job) throws Exception {
    final Configuration conf = job.getConfiguration();
    job.setNumReduceTasks(1);
    conf.set(MRConfig.FRAMEWORK_NAME, MRConfig.LOCAL_FRAMEWORK_NAME);
    conf.setInt(MRJobConfig.IO_SORT_FACTOR, 1000);
    conf.set("fs.defaultFS", "file:///");
    conf.setInt("test.mapcollection.num.maps", 1);
    job.setInputFormatClass(FakeIF.class);
    job.setOutputFormatClass(NullOutputFormat.class);
    job.setMapperClass(Mapper.class);
    job.setReducerClass(SpillReducer.class);
    job.setMapOutputKeyClass(KeyWritable.class);
    job.setMapOutputValueClass(ValWritable.class);
    job.setSortComparatorClass(VariableComparator.class);

    LOG.info("Running " + name);
    assertTrue("Job failed!", job.waitForCompletion(false));
  }

  @Test
  public void testValLastByte() throws Exception {
    // last byte of record/key is the last/first byte in the spill buffer
    // args: (name, keylen, vallen, records, ioSortMB, spillPer)
    runTest("vallastbyte", 128, 896, 1344, 1, 0.5f);
    runTest("keylastbyte", 512, 1024, 896, 1, 0.5f);
  }

  @Test
  public void testLargeRecords() throws Exception {
    // maps emitting records larger than mapreduce.task.io.sort.mb
    // args: (name, keylen, vallen, records, ioSortMB, spillPer)
    runTest("largerec", 100, 1024*1024, 5, 1, .8f);
    runTest("largekeyzeroval", 1024*1024, 0, 5, 1, .8f);
  }

  @Test
  public void testSpillPer2B() throws Exception {
    // set non-default, 100% speculative spill boundary
    // args: (name, keylen, vallen, records, ioSortMB, spillPer)
    runTest("fullspill2B", 1, 1, 10000, 1, 1.0f);
    runTest("fullspill200B", 100, 100, 10000, 1, 1.0f);
    runTest("fullspillbuf", 10 * 1024, 20 * 1024, 256, 1, 1.0f);
    runTest("lt50perspill", 100, 100, 10000, 1, 0.3f);
  }

  @Test
  public void testZeroVal() throws Exception {
    // test key/value at zero-length; FixedRecordFactory disables reads for
    // whichever side is zero, so nothing is (de)serialized for it
    runTest("zeroval", 1, 0, 10000, 1, .8f);
    runTest("zerokey", 0, 1, 10000, 1, .8f);
    runTest("zerokeyval", 0, 0, 10000, 1, .8f);
    runTest("zerokeyvalfull", 0, 0, 10000, 1, 1.0f);
  }

  @Test
  public void testSingleRecord() throws Exception {
    // a single record, with and without any serialized bytes
    runTest("singlerecord", 100, 100, 1, 1, 1.0f);
    runTest("zerokeyvalsingle", 0, 0, 1, 1, 1.0f);
  }

  @Test
  public void testLowSpill() throws Exception {
    // very low spill threshold (1/256) with large keys and small values
    runTest("lowspill", 4000, 96, 20, 1, 0.00390625f);
  }

  @Test
  public void testSplitMetaSpill() throws Exception {
    // many tiny records so spills are driven by record metadata, not data
    runTest("splitmetaspill", 7, 1, 131072, 1, 0.8f);
  }

  /**
   * RecordFactory that steps key/value lengths once: records with index up
   * to and including {@code steprec} get the "pre" lengths, later records
   * get the "post" lengths.
   */
  public static class StepFactory extends RecordFactory {
    public int prekey;
    public int postkey;
    public int preval;
    public int postval;
    public int steprec;

    /** Record the step parameters in {@code conf} for setConf to pick up. */
    public static void setLengths(Configuration conf, int prekey, int postkey,
        int preval, int postval, int steprec) {
      conf.setInt("test.stepfactory.prekey", prekey);
      conf.setInt("test.stepfactory.postkey", postkey);
      conf.setInt("test.stepfactory.preval", preval);
      conf.setInt("test.stepfactory.postval", postval);
      conf.setInt("test.stepfactory.steprec", steprec);
    }

    public void setConf(Configuration conf) {
      prekey = conf.getInt("test.stepfactory.prekey", 0);
      postkey = conf.getInt("test.stepfactory.postkey", 0);
      preval = conf.getInt("test.stepfactory.preval", 0);
      postval = conf.getInt("test.stepfactory.postval", 0);
      steprec = conf.getInt("test.stepfactory.steprec", 0);
    }

    public int keyLen(int i) {
      // the step applies strictly after steprec
      return i > steprec ? postkey : prekey;
    }

    public int valLen(int i) {
      return i > steprec ? postval : preval;
    }
  }

  @Test
  public void testPostSpillMeta() throws Exception {
    // write larger records until spill, then write records that generate
    // no writes into the serialization buffer
    Configuration conf = new Configuration();
    conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100);
    final Job job = Job.getInstance(conf);
    conf = job.getConfiguration();
    conf.setInt(MRJobConfig.IO_SORT_MB, 1);
    // 2^20 * spill = 14336 bytes available post-spill, at most 896 meta
    conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(.986328125f));
    conf.setClass("test.mapcollection.class", StepFactory.class,
        RecordFactory.class);
    // 4000B keys / 96B vals through record 252, zero-length afterwards
    StepFactory.setLengths(conf, 4000, 0, 96, 0, 252);
    conf.setInt("test.spillmap.records", 1000);
    // zero-length records serialize nothing, so disable reads on both sides
    conf.setBoolean("test.disable.key.read", true);
    conf.setBoolean("test.disable.val.read", true);
    runTest("postspillmeta", job);
  }

  @Test
  public void testLargeRecConcurrent() throws Exception {
    // step to records larger than the sort buffer while a spill is pending
    Configuration conf = new Configuration();
    conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100);
    final Job job = Job.getInstance(conf);
    conf = job.getConfiguration();
    conf.setInt(MRJobConfig.IO_SORT_MB, 1);
    conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(.986328125f));
    conf.setClass("test.mapcollection.class", StepFactory.class,
        RecordFactory.class);
    // 4000B/96B records through record 251, then 255KB/1KB records
    StepFactory.setLengths(conf, 4000, 261120, 96, 1024, 251);
    conf.setInt("test.spillmap.records", 255);
    // all records have nonzero length, so reads stay enabled on both sides
    conf.setBoolean("test.disable.key.read", false);
    conf.setBoolean("test.disable.val.read", false);
    runTest("largeconcurrent", job);
  }

  /**
   * RecordFactory emitting log-uniform random key/value lengths within a
   * configured range.  Note that setConf stores the spans
   * (max - min), not the absolute maxima.
   */
  public static class RandomFactory extends RecordFactory {
    public int minkey;
    public int maxkey;   // span above minkey (configured max - min)
    public int minval;
    public int maxval;   // span above minval (configured max - min)
    private final Random r = new Random();
    /** Log-uniform random int, typically in [1, max); 0 when max < 2. */
    private static int nextRand(Random r, int max) {
      return (int)Math.exp(r.nextDouble() * Math.log(max));
    }
    public void setConf(Configuration conf) {
      r.setSeed(conf.getLong("test.randomfactory.seed", 0L));
      minkey = conf.getInt("test.randomfactory.minkey", 0);
      maxkey = conf.getInt("test.randomfactory.maxkey", 0) - minkey;
      minval = conf.getInt("test.randomfactory.minval", 0);
      maxval = conf.getInt("test.randomfactory.maxval", 0) - minval;
    }
    /** Draw random [min, max) length ranges for keys and values. */
    public static void setLengths(Configuration conf, Random r, int max) {
      int k1 = nextRand(r, max);
      int k2 = nextRand(r, max);
      if (k1 > k2) {
        // swap so k1 <= k2 (previously both ended up equal to the smaller
        // draw because the temporary was never assigned back)
        final int tmp = k1;
        k1 = k2;
        k2 = tmp;
      }
      int v1 = nextRand(r, max);
      int v2 = nextRand(r, max);
      if (v1 > v2) {
        final int tmp = v1;
        v1 = v2;
        v2 = tmp;
      }
      // ++ ensures max strictly exceeds min even when the draws were equal
      setLengths(conf, k1, ++k2, v1, ++v2);
    }
    public static void setLengths(Configuration conf, int minkey, int maxkey,
        int minval, int maxval) {
      assert minkey < maxkey;
      assert minval < maxval;
      conf.setInt("test.randomfactory.minkey", minkey);
      conf.setInt("test.randomfactory.maxkey", maxkey);
      conf.setInt("test.randomfactory.minval", minval);
      conf.setInt("test.randomfactory.maxval", maxval);
      // a zero minimum permits zero-length records, which serialize nothing
      conf.setBoolean("test.disable.key.read", minkey == 0);
      conf.setBoolean("test.disable.val.read", minval == 0);
    }
    public int keyLen(int i) {
      // maxkey already holds the span (see setConf); subtracting minkey
      // again, as the original did, shrank the range.
      return minkey + nextRand(r, maxkey);
    }
    public int valLen(int i) {
      return minval + nextRand(r, maxval);
    }
  }

  @Test
  public void testRandom() throws Exception {
    // randomized record lengths and spill threshold, seed logged for replay
    Configuration conf = new Configuration();
    conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100);
    final Job job = Job.getInstance(conf);
    conf = job.getConfiguration();
    conf.setInt(MRJobConfig.IO_SORT_MB, 1);
    conf.setClass("test.mapcollection.class", RandomFactory.class,
        RecordFactory.class);
    final Random r = new Random();
    final long seed = r.nextLong();
    LOG.info("SEED: " + seed);
    r.setSeed(seed);
    // spill threshold kept at >= 10% of the sort buffer
    final float spillPer = Math.max(0.1f, r.nextFloat());
    conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(spillPer));
    RandomFactory.setLengths(conf, r, 1 << 14);
    conf.setInt("test.spillmap.records", r.nextInt(500));
    conf.setLong("test.randomfactory.seed", r.nextLong());
    runTest("random", job);
  }

  @Test
  public void testRandomCompress() throws Exception {
    // same randomized run as testRandom, but with map output compression on
    Configuration conf = new Configuration();
    conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100);
    final Job job = Job.getInstance(conf);
    conf = job.getConfiguration();
    conf.setInt(MRJobConfig.IO_SORT_MB, 1);
    conf.setBoolean(MRJobConfig.MAP_OUTPUT_COMPRESS, true);
    conf.setClass("test.mapcollection.class", RandomFactory.class,
        RecordFactory.class);
    final Random r = new Random();
    final long seed = r.nextLong();
    LOG.info("SEED: " + seed);
    r.setSeed(seed);
    // spill threshold kept at >= 10% of the sort buffer
    final float spillPer = Math.max(0.1f, r.nextFloat());
    conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(spillPer));
    RandomFactory.setLengths(conf, r, 1 << 14);
    conf.setInt("test.spillmap.records", r.nextInt(500));
    conf.setLong("test.randomfactory.seed", r.nextLong());
    runTest("randomCompress", job);
  }

}