/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.horn.utils;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.commons.io.FloatVectorWritable;
import org.apache.hama.commons.math.DenseFloatVector;

/**
 * Command-line tool that converts an MNIST-style pair of IDX files (images and
 * labels) into a Hadoop {@link SequenceFile} of {@link FloatVectorWritable}
 * records.
 *
 * <p>Each record's value is a vector of {@code rows * cols + 10} floats: every
 * pixel intensity rescaled from [0, 255] to [0, 1], followed by a one-hot
 * encoding of the sample's label (labels 0-9). Every record is written with a
 * default-constructed {@link LongWritable} key.
 */
public class MNISTConverter {

  /** IDX magic number for an unsigned-byte, 3-dimensional file (the images). */
  private static final int IMAGES_MAGIC = 0x00000803;

  /** IDX magic number for an unsigned-byte, 1-dimensional file (the labels). */
  private static final int LABELS_MAGIC = 0x00000801;

  /** Number of label classes; each label is one-hot encoded into this many slots. */
  private static final int NUM_LABELS = 10;

  /** Maps a raw pixel intensity in [0, 255] to [0, 1]. */
  private static float rescale(float x) {
    return 1 - (255 - x) / 255;
  }

  /**
   * Entry point.
   *
   * @param args {@code <TRAINING_DATA> <LABELS_DATA> <OUTPUT_PATH>}: the IDX image
   *          file, the IDX label file (both read from the local filesystem), and
   *          the output sequence-file path (resolved against the configured
   *          Hadoop {@link FileSystem})
   * @throws IOException if an input file is missing, truncated, has an unexpected
   *           header, or contains a label outside [0, 10); or if writing fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 3) {
      System.out.println("Usage: <TRAINING_DATA> <LABELS_DATA> <OUTPUT_PATH>");
      System.out
          .println("ex) train-images.idx3-ubyte train-labels.idx1-ubyte /tmp/mnist.seq");
      System.exit(1);
    }

    String trainingData = args[0];
    String labelsData = args[1];
    String output = args[2];

    HamaConfiguration conf = new HamaConfiguration();
    conf.set("dfs.block.size", "11554432");
    FileSystem fs = FileSystem.get(conf);

    // Read and validate all input before creating the output file, so malformed
    // or truncated input never leaves a partial sequence file behind.
    int count;
    byte[][] images;
    byte[] labels;
    try (DataInputStream imagesIn = openIdx(trainingData);
        DataInputStream labelsIn = openIdx(labelsData)) {
      checkMagic(imagesIn.readInt(), IMAGES_MAGIC, trainingData);
      count = imagesIn.readInt();
      int rows = imagesIn.readInt();
      int cols = imagesIn.readInt();
      checkMagic(labelsIn.readInt(), LABELS_MAGIC, labelsData);
      int labelCount = labelsIn.readInt();

      if (count < 0 || rows <= 0 || cols <= 0
          || (long) rows * cols > Integer.MAX_VALUE) {
        throw new IOException(trainingData + ": invalid IDX header (count="
            + count + ", rows=" + rows + ", cols=" + cols + ")");
      }
      if (labelCount != count) {
        throw new IOException(labelsData + " holds " + labelCount
            + " labels but " + trainingData + " holds " + count + " images");
      }
      // Use the dimensions declared in the header rather than assuming 28x28.
      int pixels = rows * cols;

      System.out.println("Writing " + count + " samples on " + output);

      images = new byte[count][pixels];
      labels = new byte[count];
      for (int n = 0; n < count; n++) {
        imagesIn.readFully(images[n]);
        labels[n] = labelsIn.readByte();
        int label = labels[n] & 0xff;
        if (label >= NUM_LABELS) {
          // Would otherwise encode as an all-zero "one-hot" vector.
          throw new IOException(labelsData + ": label " + label + " at index " + n
              + " is outside [0, " + NUM_LABELS + ")");
        }
      }
    }

    try (@SuppressWarnings("deprecation")
    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, new Path(
        output), LongWritable.class, FloatVectorWritable.class)) {
      // append() serializes immediately, so one key instance can be reused.
      LongWritable key = new LongWritable();
      for (int i = 0; i < count; i++) {
        float[] vals = toFeatureVector(images[i], labels[i] & 0xff);
        writer.append(key, new FloatVectorWritable(new DenseFloatVector(vals)));
      }
    }
  }

  /** Opens a local IDX file for buffered, big-endian reading. */
  private static DataInputStream openIdx(String file) throws IOException {
    return new DataInputStream(new BufferedInputStream(new FileInputStream(
        new File(file))));
  }

  /** Fails with a descriptive error if an IDX header's magic number is unexpected. */
  private static void checkMagic(int actual, int expected, String file)
      throws IOException {
    if (actual != expected) {
      throw new IOException(String.format(
          "%s: unexpected IDX magic number 0x%08x (expected 0x%08x)", file,
          actual, expected));
    }
  }

  /**
   * Builds one sample vector: the rescaled pixels followed by a one-hot encoding
   * of {@code label}, which must lie in [0, NUM_LABELS).
   */
  private static float[] toFeatureVector(byte[] image, int label) {
    float[] vals = new float[image.length + NUM_LABELS];
    for (int j = 0; j < image.length; j++) {
      vals[j] = rescale(image[j] & 0xff);
    }
    // Java zero-initializes the array, so only the label's slot needs setting.
    vals[image.length + label] = 1.0f;
    return vals;
  }

}