/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapred;

import java.io.IOException;
import java.util.List;

import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.IFile.Writer;
import org.apache.hadoop.mapred.MemoryBlock.KeyValuePairIterator;
import org.apache.hadoop.mapred.Task.TaskReporter;
import org.apache.hadoop.util.PriorityQueue;

class ReducePartition<K extends BytesWritable, V extends BytesWritable> extends
    BasicReducePartition<K, V> {

  class KeyValueSortedArray extends
      PriorityQueue<KeyValuePairIterator> implements
      KeyValueSpillIterator {
    int totalRecordsNum = 0;
    MemoryBlockIndex memoryBlockIndex;
    boolean first = true;

    public KeyValueSortedArray( List<MemoryBlock> sortedMemBlks,
         int collectedRecordsNum) {
      totalRecordsNum = collectedRecordsNum;
      initHeap(sortedMemBlks);
      memoryBlockIndex = new MemoryBlockIndex();
    }

    public void reset( List<MemoryBlock> sortedMemBlks,
         int collectedRecordsNum) {
      totalRecordsNum = collectedRecordsNum;
      initHeap(sortedMemBlks);
      memoryBlockIndex = new MemoryBlockIndex();
    }

    private void initHeap(List<MemoryBlock> sortedMemBlks) {
      initialize(sortedMemBlks.size());
      clear();
      for (MemoryBlock memBlk : sortedMemBlks) {
        put(memBlk.iterator());
      }
      first = true;
    }

    public MemoryBlockIndex next() {
      if (totalRecordsNum == 0 || size() == 0) {
        return null;
      }

      if (!first) {
        KeyValuePairIterator top = top();
        if (top.hasNext()) {
          top.next();
          adjustTop();
        } else {
          MemoryBlock freeMemBlk = pop().getMemoryBlock();
          putMemoryBlockFree(freeMemBlk);
          if (size() == 0) {
            return null;
          }
        }
      } else {
        first = false;
      }

      KeyValuePairIterator top = top();
      MemoryBlock memBlock = top.getMemoryBlock();
      memoryBlockIndex.setMemoryBlockIndex(memBlock, top
          .getCurrentReadPos());
      return memoryBlockIndex;
    }

    @Override
    protected boolean lessThan(final Object a, final Object b) {
      KeyValuePairIterator kv1 = (KeyValuePairIterator) a;
      KeyValuePairIterator kv2 = (KeyValuePairIterator) b;
      int cmpRet =
          LexicographicalComparerHolder.compareBytes(kvbuffer, kv1
              .getCurrentOffset(), kv1.getCurrentKeyLen(), kv2
              .getCurrentOffset(), kv2.getCurrentKeyLen());
      return cmpRet < 0;
    }

    @Override
    public int getRecordNumber() {
      return totalRecordsNum;
    }
  }

  protected KeyValueSortedArray keyValueSortArray;

  public ReducePartition(int reduceNum,
      MemoryBlockAllocator memoryBlockAllocator, byte[] kvBuffer,
      BlockMapOutputCollector<K, V> collector, TaskReporter reporter)
      throws IOException {
    super(reduceNum, memoryBlockAllocator, kvBuffer, collector, reporter);
  }

  public void putMemoryBlockFree(MemoryBlock memoryBlock) {
    memoryBlockAllocator.freeMemoryBlock(memoryBlock);
  }

  public int collect(K key, V value) throws IOException {
    int keyLen = key.getLength();
    int valLen = value.getLength();

    int recordSize = keyLen + valLen;
    
    if (currentBlock == null || recordSize > currentBlock.left()) {
      MemoryBlock newBlock =
          memoryBlockAllocator.allocateMemoryBlock(recordSize);
      if (newBlock == null) {
        collector.spillSingleRecord(key, value, partition);
        return recordSize;
      }
      if(currentBlock != null) {
        closeCurrentMemoryBlock();
      }
      currentBlock = newBlock;
    }
    currentBlock.collectKV(kvbuffer, key, value);
    collectedRecordsNum = collectedRecordsNum + 1;
    collectedBytesSize = collectedBytesSize + recordSize;
    return recordSize;
  }

  public KeyValueSpillIterator getKeyValueSpillIterator() {
    return keyValueSortArray;
  }

  public IndexRecord spill(JobConf job, FSDataOutputStream out,
      Class<K> keyClass, Class<V> valClass, CompressionCodec codec,
      Counter spillCounter) throws IOException {
    IFile.Writer<K, V> writer = null;
    IndexRecord rec = new IndexRecord();
    long segmentStart = out.getPos();
    try {
      writer = new Writer<K, V>(job, out, keyClass, valClass, codec,
          spillCounter);
      // spill directly
      KeyValueSpillIterator kvSortedArray = this.getKeyValueSpillIterator();
      MemoryBlockIndex memBlkIdx = kvSortedArray.next();
      while (memBlkIdx != null) {
        int pos = memBlkIdx.getIndex();
        MemoryBlock memBlk = memBlkIdx.getMemoryBlock();
        writer.append(kvbuffer, memBlk.offsets[pos], memBlk.keyLenArray[pos],
            memBlk.valueLenArray[pos]);
        memBlkIdx = kvSortedArray.next();
      }
    } finally {
      // close the writer
      if (null != writer)
        writer.close();
    }
    rec.startOffset = segmentStart;
    rec.rawLength = writer.getRawLe