/**
 * Copyright 2014 LinkedIn Corp. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License. You may obtain a
 * copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package com.linkedin.mlease.regression.jobs;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.avro.Schema;
import org.apache.avro.Schema.Type;
import org.apache.avro.file.DataFileStream;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.mapred.AvroCollector;
import org.apache.avro.mapred.AvroJob;
import org.apache.avro.mapred.AvroKey;
import org.apache.avro.mapred.AvroMapper;
import org.apache.avro.mapred.AvroOutputFormat;
import org.apache.avro.mapred.AvroReducer;
import org.apache.avro.mapred.AvroValue;
import org.apache.avro.mapred.Pair;
import org.apache.commons.lang.mutable.MutableFloat;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Partitioner;
import org.apache.hadoop.mapred.Reporter;
import org.apache.log4j.Logger;


import com.linkedin.mlease.models.LinearModel;
import com.linkedin.mlease.avro.LinearModelAvro;
import com.linkedin.mlease.regression.avro.LambdaRhoMap;
import com.linkedin.mlease.regression.avro.RegressionPrepareOutput;
import com.linkedin.mlease.regression.avro.RegressionTrainOutput;
import com.linkedin.mlease.regression.avro.SampleTestLoglik;
import com.linkedin.mlease.regression.consumers.FindLinearModelConsumer;
import com.linkedin.mlease.regression.consumers.ReadLambdaMapConsumer;
import com.linkedin.mlease.regression.consumers.ReadLambdaRhoConsumer;
import com.linkedin.mlease.regression.liblinearfunc.LibLinear;
import com.linkedin.mlease.regression.liblinearfunc.LibLinearBinaryDataset;
import com.linkedin.mlease.regression.liblinearfunc.LibLinearDataset;
import com.linkedin.mlease.utils.LinearModelUtils;
import com.linkedin.mlease.utils.Util;
import com.linkedin.mapred.AbstractAvroJob;
import com.linkedin.mapred.AvroDistributedCacheFileReader;
import com.linkedin.mapred.AvroHdfsFileReader;
import com.linkedin.mapred.AvroHdfsFileWriter;
import com.linkedin.mapred.AvroUtils;
import com.linkedin.mapred.JobConfig;

public class RegressionAdmmTrain extends AbstractAvroJob
{
  /** Logger shared by the driver and by the nested map/reduce classes. */
  public static final Logger _logger = Logger.getLogger(RegressionAdmmTrain.class);

  // ---------------------------------------------------------------------------
  // Required configuration keys
  // ---------------------------------------------------------------------------

  /** Root directory under which every output of this job (iterations, final model, ...) is written. */
  public static final String OUTPUT_BASE_PATH = "output.base.path";

  /** Test data used to draw the test log-likelihood trajectory over the iterations. */
  public static final String TEST_PATH = "test.path";

  /** Number of data blocks; one ADMM sub-problem is solved per (block, lambda) pair. */
  public static final String NUM_BLOCKS = "num.blocks";

  /** Comma-separated list of regularization weights (lambdas). */
  public static final String LAMBDA = "lambda";

  /** Maximum number of ADMM iterations. */
  public static final String NUM_ITERS = "num.iters";

  /** Regularizer type: 1 selects L1, 2 selects L2. */
  public static final String REGULARIZER = "regularizer";

  // ---------------------------------------------------------------------------
  // Optional configuration keys
  // ---------------------------------------------------------------------------

  /** Whether the intercept is regularized together with the other coefficients. */
  public static final String PENALIZE_INTERCEPT = "penalize.intercept";

  /** Whether intermediate per-iteration directories are removed. */
  public static final String REMOVE_TMP_DIR = "remove.tmp.dir";

  /** Convergence tolerance on the largest change of the consensus model between iterations. */
  public static final String EPSILON = "epsilon";

  /** Stopping tolerance handed to liblinear. */
  public static final String LIBLINEAR_EPSILON = "liblinear.epsilon";

  /** Path of an Avro file with per-feature lambda overrides. */
  public static final String LAMBDA_MAP = "lambda.map";

  /** Treat features as binary, i.e. ignore feature values. */
  public static final String BINARY_FEATURE = "binary.feature";

  /** Flag forwarded to the binary liblinear dataset (presumably a compact feature-index encoding). */
  public static final String SHORT_FEATURE_INDEX = "short.feature.index";

  /** Flag for aggressively decreasing the liblinear tolerance threshold. */
  public static final String AGGRESSIVE_LIBLINEAR_EPSILON_DECAY = "aggressive.liblinear.epsilon.decay";

  /** Whether to compute the test log-likelihood after every iteration. */
  public static final String TEST_LOGLIK_PER_ITER = "test.loglik.per.iter";

  /**
   * Defaults to 1. When greater than 1, clicks are replicated that many times to obtain a better
   * consensus.
   */
  public static final String NUM_CLICK_REPLICATES = "num.click.replicates";

  /**
   * Defaults to 0, meaning no initialization. When greater than 0, the model is initialized with a
   * mean model and rho is boosted to rho * initialize.boost.rate.
   */
  public static final String INITIALIZE_BOOST_RATE = "initialize.boost.rate";

  /**
   * Defaults to 0, meaning rho is not adapted between iterations. When greater than 0, it is used as
   * the decay parameter of an exponential decay of rho at each iteration (except the first iteration
   * after initialization, where rho is boosted by initialize.boost.rate). Suggested value: 0.3.
   */
  public static final String RHO_ADAPT_COEFFICIENT = "rho.adapt.coefficient";

  /** Internal key that passes the actual rho adaptation rate of an iteration to the reducer. */
  public static final String RHO_ADAPT_RATE = "rho.adapt.rate";

  // ---------------------------------------------------------------------------
  // Additional key names and constants
  // ---------------------------------------------------------------------------

  /** Comma-separated list of rhos, one per lambda; optional. */
  public static final String RHO = "rho";

  /** Name of the intercept feature. */
  public static final String INTERCEPT_KEY = "intercept.key";

  /** Job-conf key holding the path of the u (scaled dual) file of the current iteration. */
  public static final String U_PATH = "u.path";

  /** Job-conf key holding the path of the current consensus model (initial value for liblinear). */
  public static final String INIT_VALUE_PATH = "init.value.path";

  /** Reporting frequency handed to liblinear. */
  public static final String REPORT_FREQUENCY = "report.frequency";

  /** Job-conf key holding the path of the lambda -> rho mapping file. */
  public static final String LAMBDA_RHO_MAP = "lambda.rho.map";

  /** Maximum number of test records evaluated when computing the test log-likelihood. */
  public static final long MAX_NTEST_EVENTS = 1000000;

  /**
   * Creates the ADMM training job.
   *
   * @param jobId  identifier of this job
   * @param config job configuration holding the keys declared above
   */
  public RegressionAdmmTrain(String jobId, JobConfig config)
  {
    super(jobId, config);
  }
  /**
   * Runs consensus ADMM training for every configured lambda.
   *
   * <p>Outline (all paths are relative to {@code output.base.path}):
   * <ol>
   *   <li>Parse lambdas/rhos and write the lambda-rho map to {@code lambda-rho/}.</li>
   *   <li>Optionally (L2 only, {@code initialize.boost.rate > 0}) warm-start z with the mean of the
   *       models produced by a {@link RegressionNaiveTrain} run.</li>
   *   <li>For each iteration i: write u and z under {@code iter-i/}, run the MapReduce job whose
   *       reducers fit one liblinear model per (lambda, block) key, then update z from the block
   *       averages and check convergence.</li>
   *   <li>Write the final z to {@code final-model/}.</li>
   * </ol>
   *
   * @throws IOException if the regularizer is neither 1 (L1) nor 2 (L2), or if the number of
   *                     configured rhos differs from the number of lambdas
   * @throws Exception   propagated from the underlying Hadoop jobs
   */
  @Override
  public void run() throws Exception
  {
    _logger.info("Now running Regression Train using ADMM...");
    JobConfig props = super.getJobConfig();
    String outBasePath = props.getString(OUTPUT_BASE_PATH);
    JobConf conf = super.createJobConf();

    // Various configs
    int nblocks = props.getInt(NUM_BLOCKS);
    int niter = props.getInt(NUM_ITERS, 10);
    // Aggressive decay of liblinear_epsilon
    boolean aggressiveLiblinearEpsilonDecay = props.getBoolean(AGGRESSIVE_LIBLINEAR_EPSILON_DECAY, false);
    // Regularizer: 1 = L1, 2 = L2
    int reg = props.getInt(REGULARIZER);
    if ((reg != 1) && (reg != 2))
    {
      throw new IOException("Only L1 and L2 regularization supported!");
    }
    int numClickReplicates = props.getInt(NUM_CLICK_REPLICATES, 1);
    boolean ignoreValue = props.getBoolean(BINARY_FEATURE, false);
    float initializeBoostRate = props.getFloat(INITIALIZE_BOOST_RATE, 0);
    float rhoAdaptCoefficient = props.getFloat(RHO_ADAPT_COEFFICIENT, 0);
    // Loop-invariant settings, read once up front
    boolean penalizeIntercept = props.getBoolean(PENALIZE_INTERCEPT, false);
    String interceptKey = props.getString(INTERCEPT_KEY, LibLinearDataset.INTERCEPT_NAME);
    double epsilon = props.getDouble(EPSILON, 0.0001);
    boolean removeTmpDir = props.getBoolean(REMOVE_TMP_DIR, false);

    // Handle lambda and rho; initialize z (keyed by String.valueOf(lambda)) with empty models
    Map<String, LinearModel> z = new HashMap<String, LinearModel>();
    List<String> lambdastr = props.getStringList(LAMBDA, ",");
    List<String> rhostr = props.getStringList(RHO, null, ",");
    if (rhostr != null && rhostr.size() != lambdastr.size())
    {
      throw new IOException("The number of rho's should be exactly the same as the number of lambda's. OR: don't claim rho!");
    }
    Map<Float, Float> lambdaRho = new HashMap<Float, Float>();
    for (int j = 0; j < lambdastr.size(); j++)
    {
      float lambda = Float.parseFloat(lambdastr.get(j));
      float rho;
      if (rhostr != null)
      {
        rho = Float.parseFloat(rhostr.get(j));
      }
      else
      {
        // default rho when none is configured
        rho = (lambda <= 100) ? 1 : 10;
      }
      lambdaRho.put(lambda, rho);
      z.put(String.valueOf(lambda), new LinearModel());
    }

    // Get specific lambda treatment for some features
    String lambdaMapPath = props.getString(LAMBDA_MAP, "");
    Map<String, Float> lambdaMap = new HashMap<String, Float>();
    if (!lambdaMapPath.equals(""))
    {
      AvroHdfsFileReader reader = new AvroHdfsFileReader(conf);
      ReadLambdaMapConsumer consumer = new ReadLambdaMapConsumer();
      reader.build(lambdaMapPath, consumer);
      consumer.done();
      lambdaMap = consumer.get();
    }
    _logger.info("Lambda Map has size = " + String.valueOf(lambdaMap.size()));
    // Write lambda_rho mapping into file
    String rhoPath = outBasePath + "/lambda-rho/part-r-00000.avro";
    writeLambdaRho(conf, rhoPath, lambdaRho);

    // Test loglik is computed only when it is requested AND the test path exists.
    boolean testLoglikPerIter = props.getBoolean(TEST_LOGLIK_PER_ITER, false);
    DataFileWriter<GenericRecord> testRecordWriter = null;
    String testPath = props.getString(TEST_PATH, "");
    testLoglikPerIter = testLoglikPerIter && Util.checkPath(testPath);
    if (testLoglikPerIter)
    {
      List<Path> testPathList = AvroUtils.enumerateFiles(conf, new Path(testPath));
      if (testPathList.size() > 0)
      {
        testPath = testPathList.get(0).toString();
        _logger.info("Sample test path = " + testPath);

        AvroHdfsFileWriter<GenericRecord> writer =
            new AvroHdfsFileWriter<GenericRecord>(conf, outBasePath
                + "/sample-test-loglik/write-test-00000.avro", SampleTestLoglik.SCHEMA$);
        testRecordWriter = writer.get();
      }
    }
    if (testRecordWriter == null)
    {
      testLoglikPerIter = false;
      _logger.info("test.loglik.per.iter=false or test path doesn't exist or is empty! So we will not output test loglik per iteration.");
    }
    else
    {
      // NOTE(review): this writer is opened and closed without appending records; it appears to
      // serve only as a probe of the output location.
      testRecordWriter.close();
    }

    MutableFloat bestTestLoglik = new MutableFloat(-9999999);
    // Initialize z by the mean model (L2 only)
    if (initializeBoostRate > 0 && reg == 2)
    {
      _logger.info("Now start mean model initializing......");
      String initalModelPath = outBasePath + "/initialModel";
      Path initalModelPathFromNaiveTrain = new Path(outBasePath, "models");
      JobConfig propsIni = JobConfig.clone(props);
      if (!propsIni.containsKey(LIBLINEAR_EPSILON))
      {
        propsIni.put(LIBLINEAR_EPSILON, 0.01);
      }
      propsIni.put(RegressionNaiveTrain.HEAVY_PER_ITEM_TRAIN, "true");
      propsIni.put(LAMBDA_MAP, lambdaMapPath);
      propsIni.put(REMOVE_TMP_DIR, "false");

      // run job
      RegressionNaiveTrain initializationJob =
          new RegressionNaiveTrain(super.getJobId() + "_ADMMInitialization", propsIni);
      initializationJob.run();

      FileSystem fs = initalModelPathFromNaiveTrain.getFileSystem(conf);
      if (fs.exists(new Path(initalModelPath)))
      {
        fs.delete(new Path(initalModelPath), true);
      }
      fs.rename(initalModelPathFromNaiveTrain, new Path(initalModelPath));
      // set up lambda
      Set<Float> lambdaSet = new HashSet<Float>();
      for (String l : lambdastr)
      {
        lambdaSet.add(Float.parseFloat(l));
      }
      // Compute Mean model as initial model
      z = LinearModelUtils.meanModel(conf, initalModelPath, nblocks, lambdaSet.size(), true);

      if (testLoglikPerIter)
      {
        updateLogLikBestModel(conf, 0, z, testPath, ignoreValue, bestTestLoglik, outBasePath, numClickReplicates);
      }
    }

    double mindiff = 99999999;
    float liblinearEpsilon = 0.01f;
    int i;
    for (i = 1; i <= niter; i++)
    {
      _logger.info("Now starting iteration " + String.valueOf(i));
      String iterPath = outBasePath + "/iter-" + String.valueOf(i);
      // set up configuration
      props.put(AbstractAvroJob.OUTPUT_PATH, iterPath);
      conf =
          createJobConf(AdmmMapper.class,
                        AdmmReducer.class,
                        Pair.getPairSchema(Schema.create(Type.INT),
                                           RegressionPrepareOutput.SCHEMA$),
                        RegressionTrainOutput.SCHEMA$);
      conf.setPartitionerClass(AdmmPartitioner.class);
      conf.setInt(NUM_BLOCKS, nblocks);
      conf.setInt(REGULARIZER, reg);
      conf.setLong(REPORT_FREQUENCY, props.getLong(REPORT_FREQUENCY, 1000000));
      conf.setBoolean(BINARY_FEATURE, ignoreValue);
      conf.setBoolean(SHORT_FEATURE_INDEX, props.getBoolean(SHORT_FEATURE_INDEX, false));
      conf.set(INTERCEPT_KEY, interceptKey);

      // compute and store u into file: u = (u + x) - z
      String uPath = iterPath + "/u/part-r-00000.avro";
      if (i == 1)
      {
        // no previous iteration: u starts empty
        LinearModelUtils.writeLinearModel(conf, uPath, new HashMap<String, LinearModel>());
        if (initializeBoostRate > 0 && reg == 2)
        {
          conf.setFloat(RHO_ADAPT_RATE, initializeBoostRate);
        }
      }
      else
      {
        String uplusxPath = outBasePath + "/iter-" + String.valueOf(i - 1) + "/model";
        computeU(conf, uPath, uplusxPath, z);
        if (rhoAdaptCoefficient > 0)
        {
          float curRhoAdaptRate = (float) Math.exp(-(i - 1) * rhoAdaptCoefficient);
          conf.setFloat(RHO_ADAPT_RATE, curRhoAdaptRate);
        }
      }
      // write z into file
      String zPath = iterPath + "/init-value/part-r-00000.avro";
      LinearModelUtils.writeLinearModel(conf, zPath, z);

      // run job
      String outpath = iterPath + "/model";
      conf.set(U_PATH, uPath);
      conf.set(INIT_VALUE_PATH, zPath);
      conf.set(LAMBDA_RHO_MAP, rhoPath);
      if (i > 1 && mindiff < 0.001 && !aggressiveLiblinearEpsilonDecay)
      {
        // need to get a more accurate estimate from liblinear
        liblinearEpsilon = liblinearEpsilon / 10;
      }
      else if (aggressiveLiblinearEpsilonDecay && i > 5)
      {
        liblinearEpsilon = liblinearEpsilon / 10;
      }
      conf.setFloat(LIBLINEAR_EPSILON, liblinearEpsilon);
      _logger.info("Liblinear Epsilon for iter = "
          + String.valueOf(i) + " is: " + String.valueOf(liblinearEpsilon));
      _logger.info("aggressiveLiblinearEpsilonDecay=" + aggressiveLiblinearEpsilonDecay);
      AvroOutputFormat.setOutputPath(conf, new Path(outpath));
      AvroUtils.addAvroCacheFiles(conf, new Path(uPath));
      AvroUtils.addAvroCacheFiles(conf, new Path(zPath));
      AvroUtils.addAvroCacheFiles(conf, new Path(rhoPath));
      // One reducer per (block, distinct lambda) key, matching the key range emitted by AdmmMapper
      conf.setNumReduceTasks(nblocks * lambdaRho.size());
      AvroJob.setInputSchema(conf, RegressionPrepareOutput.SCHEMA$);
      AvroUtils.runAvroJob(conf);

      // Block averages of x (this iteration's models) and of u
      Map<String, LinearModel> xbar =
          LinearModelUtils.meanModel(conf, outpath, nblocks, lambdaRho.size(), true);
      Map<String, LinearModel> ubar = LinearModelUtils.meanModel(conf, uPath, nblocks, lambdaRho.size(), false);
      // Snapshot z to measure how far it moves in this iteration
      Map<String, LinearModel> lastz = new HashMap<String, LinearModel>();
      for (String k : z.keySet())
      {
        lastz.put(k, z.get(k).copy());
      }
      updateZ(z, xbar, ubar, lambdaRho, lambdaMap, nblocks, reg, penalizeIntercept);
      xbar.clear();
      ubar.clear();

      // Min/max (over lambdas) of the largest coefficient change between last z and this z
      mindiff = 99999999;
      double maxdiff = 0;
      for (String k : z.keySet())
      {
        LinearModel tmp = lastz.get(k);
        if (tmp == null)
        {
          tmp = new LinearModel();
        }
        tmp.linearCombine(1, -1, z.get(k));
        double diff = tmp.maxAbsValue();
        _logger.info("For lambda=" + k + ": Max Difference between last z and this z = "
            + String.valueOf(diff));
        tmp.clear();
        if (mindiff > diff)
        {
          mindiff = diff;
        }
        if (maxdiff < diff)
        {
          maxdiff = diff;
        }
      }
      // remove tmp files of the previous iteration?
      if (removeTmpDir && i >= 2)
      {
        FileSystem fs = FileSystem.get(conf);
        fs.delete(new Path(outBasePath + "/iter-" + String.valueOf(i - 1)), true);
      }
      // Output testloglik and update best model
      if (testLoglikPerIter)
      {
        updateLogLikBestModel(conf,
                              i,
                              z,
                              testPath,
                              ignoreValue,
                              bestTestLoglik,
                              outBasePath,
                              numClickReplicates);
      }

      // Converged: z barely moved and liblinear already ran with a tight tolerance
      if (maxdiff < epsilon && liblinearEpsilon <= 0.00001)
      {
        break;
      }
    }

    // write z into file
    String zPath = outBasePath + "/final-model/part-r-00000.avro";
    LinearModelUtils.writeLinearModel(conf, zPath, z);
    // remove tmp files?
    if (removeTmpDir)
    {
      FileSystem fs = FileSystem.get(conf);
      Path initalModelPath = new Path(outBasePath + "/initialModel");
      if (fs.exists(initalModelPath))
      {
        fs.delete(initalModelPath, true);
      }
      for (int j = i - 2; j <= i; j++)
      {
        Path deletepath = new Path(outBasePath + "/iter-" + String.valueOf(j));
        if (fs.exists(deletepath))
        {
          fs.delete(deletepath, true);
        }
      }
      fs.delete(new Path(outBasePath + "/tmp-data"), true);
    }
  }

  /**
   * ADMM consensus (z) update, applied in place to {@code z} for every lambda present in {@code xbar}.
   *
   * <p>With N = nblocks, r = rho and l = lambda:
   * <ul>
   *   <li>L2 ({@code reg == 2}): z = w * (xbar + ubar), where w = N*r / (l + N*r). Features listed in
   *       {@code lambdaMap} use their own lambda in place of l.</li>
   *   <li>L1: z = S_k(xbar + ubar), element-wise soft thresholding with k = l / (N*r), or with the
   *       feature's own lambda from {@code lambdaMap}. Coefficients with |value| &lt;= k become 0.</li>
   * </ul>
   * If the intercept is not penalized, it is set to the unshrunk xbar (+ ubar) intercept.
   *
   * @param z                 consensus models keyed by lambda string; updated in place
   * @param xbar              block-averaged models of the current iteration, keyed by lambda string
   * @param ubar              block-averaged u, keyed by lambda string; empty in the first iteration
   * @param lambdaRho         lambda -> rho
   * @param lambdaMap         feature -> feature-specific lambda
   * @param nblocks           number of data blocks
   * @param reg               1 for L1, 2 for L2
   * @param penalizeIntercept whether the intercept is shrunk like the other coefficients
   */
  private void updateZ(Map<String, LinearModel> z,
                       Map<String, LinearModel> xbar,
                       Map<String, LinearModel> ubar,
                       Map<Float, Float> lambdaRho,
                       Map<String, Float> lambdaMap,
                       int nblocks,
                       int reg,
                       boolean penalizeIntercept)
  {
    for (String lambda : xbar.keySet())
    {
      LinearModel thisz = z.get(lambda);
      thisz.clear();
      float l = Float.parseFloat(lambda);
      float r = lambdaRho.get(l);
      _logger.info("Running code for regularizer = " + String.valueOf(reg));
      if (reg == 2)
      {
        // L2 regularization: shrink the average towards zero
        double weight = nblocks * r / (l + nblocks * r);
        Map<String, Double> weightmap = new HashMap<String, Double>();
        for (String k : lambdaMap.keySet())
        {
          weightmap.put(k, nblocks * r / (lambdaMap.get(k) + nblocks * r + 0.0));
        }
        thisz.linearCombine(1.0, weight, xbar.get(lambda), weightmap);
        if (!ubar.isEmpty())
        {
          thisz.linearCombine(1.0, weight, ubar.get(lambda), weightmap);
        }
      }
      else
      {
        // L1 regularization: soft-threshold the average
        double weight = l / (r * nblocks + 0.0);
        Map<String, Double> weightmap = new HashMap<String, Double>();
        for (String k : lambdaMap.keySet())
        {
          weightmap.put(k, lambdaMap.get(k) / (r * nblocks + 0.0));
        }
        thisz.linearCombine(1.0, 1.0, xbar.get(lambda));
        if (!ubar.isEmpty())
        {
          thisz.linearCombine(1.0, 1.0, ubar.get(lambda));
        }
        Map<String, Double> thisCoefficients = thisz.getCoefficients();
        for (Map.Entry<String, Double> entry : thisCoefficients.entrySet())
        {
          // feature-specific threshold when the feature has its own lambda
          Double featureWeight = weightmap.get(entry.getKey());
          double threshold = (featureWeight != null) ? featureWeight : weight;
          double val = entry.getValue();
          if (val > threshold)
          {
            entry.setValue(val - threshold);
          }
          else if (val < -threshold)
          {
            entry.setValue(val + threshold);
          }
          else
          {
            // inside the dead zone the soft-threshold operator yields exactly zero
            entry.setValue(0.0);
          }
        }
        thisz.setCoefficients(thisCoefficients);
      }
      if (!penalizeIntercept)
      {
        if (ubar.isEmpty())
        {
          thisz.setIntercept(xbar.get(lambda).getIntercept());
        }
        else
        {
          thisz.setIntercept(xbar.get(lambda).getIntercept()
              + ubar.get(lambda).getIntercept());
        }
      }
    }
  }

  /**
   * Mapper that emits every prepared training record once per lambda.
   *
   * <p>The output key is {@code blockKey * nlambdas + lambdaIndex}, where {@code blockKey} is the
   * record's integer key (presumably its data-block id). Together with {@link AdmmPartitioner}, each
   * (block, lambda) pair therefore lands in its own reducer.
   */
  public static class AdmmMapper extends
  AvroMapper<RegressionPrepareOutput, Pair<Integer, RegressionPrepareOutput>>
  {
    private ReadLambdaRhoConsumer _lambdaRhoConsumer = new ReadLambdaRhoConsumer();

    /**
     * Loads the lambda -> rho map from the distributed-cache file named by {@code lambda.rho.map}.
     * A null configuration is ignored.
     *
     * @throws IllegalStateException if the lambda-rho map cannot be read; without it the mapper would
     *                               silently emit nothing
     */
    @Override
    public void setConf(Configuration conf)
    {
      super.setConf(conf);
      if (conf == null)
      {
        return;
      }
      AvroDistributedCacheFileReader lambdaRhoReader =
          new AvroDistributedCacheFileReader(new JobConf(conf));
      try
      {
        lambdaRhoReader.build(conf.get(LAMBDA_RHO_MAP), _lambdaRhoConsumer);
        _lambdaRhoConsumer.done();
      }
      catch (IOException e)
      {
        throw new IllegalStateException("Failed to load lambda-rho map from "
            + conf.get(LAMBDA_RHO_MAP), e);
      }
      _logger.info("lambda file:" + conf.get(LAMBDA_RHO_MAP));
      _logger.info("Loaded " + String.valueOf(_lambdaRhoConsumer.get().size())
                   + " lambdas.");
    }

    /**
     * Emits {@code data} once per lambda. The record's key field is rewritten to the new key before
     * each emission.
     */
    @Override
    public void map(RegressionPrepareOutput data,
                    AvroCollector<Pair<Integer, RegressionPrepareOutput>> collector,
                    Reporter reporter) throws IOException
    {
      int nlambdas = _lambdaRhoConsumer.get().size();
      int blockKey = Integer.parseInt(data.key.toString());
      for (int i = 0; i < nlambdas; i++)
      {
        int newkey = blockKey * nlambdas + i;
        data.key = String.valueOf(newkey);
        collector.collect(new Pair<Integer, RegressionPrepareOutput>(newkey, data));
      }
    }
  }

  /**
   * Partitioner that sends each record to the reducer whose index equals the record's integer key.
   * Keys outside {@code [0, numPartitions)} are rejected.
   */
  public static class AdmmPartitioner implements
  Partitioner<AvroKey<Integer>, AvroValue<RegressionPrepareOutput>>
  {
    @Override
    public void configure(JobConf conf)
    {
      // stateless: nothing to configure
    }

    @Override
    public int getPartition(AvroKey<Integer> key,
                            AvroValue<RegressionPrepareOutput> value,
                            int numPartitions)
    {
      final int partition = key.datum();
      final boolean outOfRange = partition < 0 || partition >= numPartitions;
      if (outOfRange)
      {
        throw new RuntimeException("Map key is wrong! key has to be in the range of [0,numPartitions-1].");
      }
      return partition;
    }
  }

  /**
   * Reducer performing the ADMM x-update: it fits one liblinear model per (lambda, block) key.
   *
   * <p>The model is regularized towards the prior mean z - u. Rho is optionally scaled by
   * {@code rho.adapt.rate}. Each output record holds the key ({@code "lambda#block"}), the fitted
   * model, and u + x.
   */
  public static class AdmmReducer extends
  AvroReducer<Integer, RegressionPrepareOutput, GenericData.Record>
  {
    String                        _interceptKey;
    long                          _reportfreq;
    boolean                       _binaryFeature;
    boolean                       _shortFeatureIndex;
    float                         _liblinearEpsilon;
    String                        _uPath;
    String                        _initValuePath;
    JobConf                       _conf;
    private ReadLambdaRhoConsumer _lambdaRhoConsumer = new ReadLambdaRhoConsumer();
    private List<Float>           _lambdaOrderedList;
    private float                _rhoAdaptRate;

    /**
     * Reads the per-iteration settings and loads the lambda -> rho map. Lambdas are sorted so that a
     * lambda index from the mapper key can be decoded. A null configuration is ignored.
     *
     * @throws IllegalStateException if the lambda-rho map cannot be read
     */
    @Override
    public void setConf(Configuration conf)
    {
      super.setConf(conf);
      if (conf == null)
      {
        return;
      }
      _interceptKey = conf.get(INTERCEPT_KEY, LibLinearDataset.INTERCEPT_NAME);
      _reportfreq = conf.getLong(REPORT_FREQUENCY, 1000000);
      _binaryFeature = conf.getBoolean(BINARY_FEATURE, false);
      _shortFeatureIndex = conf.getBoolean(SHORT_FEATURE_INDEX, false);
      _liblinearEpsilon = conf.getFloat(LIBLINEAR_EPSILON, 0.01f);
      _rhoAdaptRate = conf.getFloat(RHO_ADAPT_RATE, 1.0f);
      AvroDistributedCacheFileReader lambdaRhoReader =
          new AvroDistributedCacheFileReader(new JobConf(conf));
      try
      {
        lambdaRhoReader.build(conf.get(LAMBDA_RHO_MAP), _lambdaRhoConsumer);
        _lambdaRhoConsumer.done();
      }
      catch (IOException e)
      {
        // Without the lambda-rho map no key can be decoded; fail the task instead of continuing.
        throw new IllegalStateException("Failed to load lambda-rho map from "
            + conf.get(LAMBDA_RHO_MAP), e);
      }
      _uPath = conf.get(U_PATH);
      _initValuePath = conf.get(INIT_VALUE_PATH);
      _conf = new JobConf(conf);
      Set<Float> lambdaSet = _lambdaRhoConsumer.get().keySet();
      _lambdaOrderedList = new ArrayList<Float>(lambdaSet);
      java.util.Collections.sort(_lambdaOrderedList);
    }

    /**
     * Trains the liblinear model for one (lambda, block) key.
     *
     * @throws IOException if the cache files cannot be read or model fitting fails
     */
    @Override
    public void reduce(Integer NumKey,
                       Iterable<RegressionPrepareOutput> values,
                       AvroCollector<GenericData.Record> collector,
                       Reporter reporter) throws IOException
    {
      // Decode the mapper key: NumKey = block * nlambdas + lambdaIndex
      int nlambdas = _lambdaRhoConsumer.get().size();
      float lambda = _lambdaOrderedList.get(NumKey % nlambdas);
      int partitionID = NumKey / nlambdas;
      String key = String.valueOf(lambda) + "#" + String.valueOf(partitionID);
      double rho = _lambdaRhoConsumer.get().get(lambda);
      if (_rhoAdaptRate != 1.0)
      {
        rho = rho * (double) _rhoAdaptRate;
        _logger.info("Adaptive rate is " + _rhoAdaptRate);
        _logger.info("Adaptive rho is " + rho);
      }

      // get u for this (lambda, block) and the current z (init value) for this lambda
      FindLinearModelConsumer uConsumer = new FindLinearModelConsumer(key);
      FindLinearModelConsumer initValueConsumer =
          new FindLinearModelConsumer(Util.getLambda(key));
      AvroDistributedCacheFileReader uReader = new AvroDistributedCacheFileReader(_conf);
      uReader.build(_uPath, uConsumer);
      uConsumer.done();
      _logger.info("Loaded u for the key, size:"
          + uConsumer.get().getCoefficients().size());
      AvroDistributedCacheFileReader initValueReader =
          new AvroDistributedCacheFileReader(_conf);
      initValueReader.build(_initValuePath, initValueConsumer);
      initValueConsumer.done();
      _logger.info("Loaded initial value of the model, size:"
          + initValueConsumer.get().getCoefficients().size());
      GenericData.Record output = new GenericData.Record(RegressionTrainOutput.SCHEMA$);

      // Prepare the data set
      LibLinearDataset dataset;
      if (_binaryFeature)
      {
        dataset = new LibLinearBinaryDataset(1.0, _shortFeatureIndex);
      }
      else
      {
        dataset = new LibLinearDataset(1.0);
      }
      for (RegressionPrepareOutput record : values)
      {
        dataset.addInstanceAvro(record);
      }
      dataset.finish();

      // Initial value = z
      LinearModel initvalue = initValueConsumer.get();
      Map<String, Double> initvaluemap = initvalue.toMap(LibLinearDataset.INTERCEPT_NAME);
      // Prior mean = z - u
      LinearModel priormean = uConsumer.get().copy();
      priormean.linearCombine(-1, 1, initvalue);
      Map<String, Double> priormeanmap = priormean.toMap(LibLinearDataset.INTERCEPT_NAME);

      // Run liblinear
      LibLinear liblinear = new LibLinear();
      liblinear.setReporter(reporter, _reportfreq);
      String option = "epsilon=" + String.valueOf(_liblinearEpsilon);
      try
      {
        liblinear.train(dataset, initvaluemap, priormeanmap, null, 1.0 / rho, option);
        LinearModel model = liblinear.getLinearModel();
        output.put("key", key);
        output.put("model", model.toAvro(LibLinearDataset.INTERCEPT_NAME));
        // u + x, read back by the driver (computeU) to form the next u
        LinearModel uplusx = uConsumer.get();
        uplusx.linearCombine(1, 1, model);
        output.put("uplusx", uplusx.toAvro(LibLinearDataset.INTERCEPT_NAME));
      }
      catch (Exception e)
      {
        throw new IOException("Model fitting error!", e);
      }
      collector.collect(output);
    }
  }

  /**
   * Writes one LambdaRhoMap record per (lambda, rho) entry to {@code path}.
   *
   * @param conf       Hadoop configuration used to open the output file
   * @param path       output file path
   * @param lambda_rho lambda -> rho mapping to persist
   * @throws IOException if writing fails; the writer is closed in every case
   */
  private void writeLambdaRho(JobConf conf, String path, Map<Float, Float> lambda_rho) throws IOException
  {
    AvroHdfsFileWriter<GenericRecord> writer =
        new AvroHdfsFileWriter<GenericRecord>(conf, path, LambdaRhoMap.SCHEMA$);
    DataFileWriter<GenericRecord> recordWriter = writer.get();
    try
    {
      for (Map.Entry<Float, Float> entry : lambda_rho.entrySet())
      {
        GenericRecord record = new GenericData.Record(LambdaRhoMap.SCHEMA$);
        record.put("lambda", entry.getKey());
        record.put("rho", entry.getValue());
        recordWriter.append(record);
      }
    }
    finally
    {
      recordWriter.close();
    }
  }
  /**
   * Computes the scaled dual variable for the next iteration: u = (u + x) - z.
   *
   * <p>Reads the {@code uplusx} field of every reducer output record under {@code uplusxPath}. It
   * subtracts the consensus model z of the record's lambda and writes one LinearModelAvro record per
   * key ({@code "lambda#block"}) to {@code uPath}. Records without {@code uplusx} are skipped.
   *
   * @param conf       Hadoop configuration
   * @param uPath      output file for the new u
   * @param uplusxPath directory holding the previous iteration's reducer output
   * @param z          current consensus models keyed by lambda string
   * @throws IOException on read/write failure; all streams and the writer are closed in every case
   */
  private void computeU(JobConf conf, String uPath, String uplusxPath, Map<String, LinearModel> z) throws IOException
  {
    AvroHdfsFileWriter<GenericRecord> writer =
        new AvroHdfsFileWriter<GenericRecord>(conf, uPath, LinearModelAvro.SCHEMA$);
    DataFileWriter<GenericRecord> recordwriter = writer.get();
    try
    {
      // read u+x
      for (Path path : Util.findPartFiles(conf, new Path(uplusxPath)))
      {
        DataFileStream<Object> stream = AvroUtils.getAvroDataStream(conf, path);
        try
        {
          while (stream.hasNext())
          {
            GenericData.Record record = (GenericData.Record) stream.next();
            String partitionID = Util.getStringAvro(record, "key", false);
            if (record.get("uplusx") == null)
            {
              continue;
            }
            String lambda = Util.getLambda(partitionID);
            LinearModel newu =
                new LinearModel(LibLinearDataset.INTERCEPT_NAME, (List<?>) record.get("uplusx"));
            newu.linearCombine(1.0, -1.0, z.get(lambda));
            GenericData.Record newvaluemap =
                new GenericData.Record(LinearModelAvro.SCHEMA$);
            newvaluemap.put("key", partitionID);
            newvaluemap.put("model", newu.toAvro(LibLinearDataset.INTERCEPT_NAME));
            recordwriter.append(newvaluemap);
          }
        }
        finally
        {
          stream.close();
        }
      }
    }
    finally
    {
      recordwriter.close();
    }
  }
  /**
   * Computes the weighted-average test log-likelihood of every model on the file {@code testPath}.
   *
   * <p>At most {@link #MAX_NTEST_EVENTS} records are read. Each record's {@code weight} field
   * (default 1) is added to the normalizer. If no records are read, the normalizer is 0 and the
   * returned averages are NaN.
   *
   * @param conf                 Hadoop configuration
   * @param modelmap             models keyed by lambda string
   * @param testPath             a single Avro test file
   * @param num_click_replicates forwarded to {@code evalInstanceAvro}
   * @param ignore_value         forwarded to {@code evalInstanceAvro}
   * @return average log-likelihood per model key
   * @throws IOException if the test file cannot be read; the stream is closed in every case
   */
  private Map<String, Double> testloglik(JobConf conf, Map<String, LinearModel> modelmap,
                                         String testPath,
                                         int num_click_replicates,
                                         boolean ignore_value) throws IOException
  {
    Map<String, Double> loglik = new HashMap<String, Double>();
    for (String k : modelmap.keySet())
    {
      loglik.put(k, 0.0);
    }
    double n = 0;
    long nrecords = 0;
    DataFileStream<Object> stream = AvroUtils.getAvroDataStream(conf, new Path(testPath));
    try
    {
      while (nrecords < MAX_NTEST_EVENTS && stream.hasNext())
      {
        GenericData.Record record = (GenericData.Record) stream.next();
        for (Map.Entry<String, LinearModel> entry : modelmap.entrySet())
        {
          String k = entry.getKey();
          double previous = loglik.get(k);
          loglik.put(k,
                     previous
                     + entry.getValue().evalInstanceAvro(record,
                                                         true,
                                                         num_click_replicates,
                                                         ignore_value));
        }
        double weight = 1;
        Object weightField = record.get("weight");
        if (weightField != null)
        {
          weight = Double.parseDouble(weightField.toString());
        }
        nrecords++;
        n += weight;
      }
    }
    finally
    {
      stream.close();
    }
    for (Map.Entry<String, Double> entry : loglik.entrySet())
    {
      entry.setValue(entry.getValue() / n);
    }
    _logger.info("Finished computing testloglik...Evaluated #test records=" + nrecords);
    return loglik;
  }
  /**
   * Records the test log-likelihood of every model in {@code z} for iteration {@code niter}. It also
   * keeps the best model found so far.
   *
   * <p>One SampleTestLoglik record per lambda is written to
   * {@code sample-test-loglik/iteration-<niter>.avro}. For {@code niter > 0}, a lambda whose loglik
   * beats {@code bestTestLoglik} replaces the whole {@code best-model/} directory with its model and
   * updates {@code bestTestLoglik}. Only a single best model, across all lambdas, is kept.
   *
   * @throws IOException on read/write failure; the loglik writer is closed in every case
   */
  private void updateLogLikBestModel(JobConf conf, int niter, Map<String, LinearModel> z, String testPath,
                                     boolean ignoreValue, MutableFloat bestTestLoglik, String outBasePath,
                                     int numClickReplicates) throws IOException
  {
    // NOTE(review): the test loglik is always evaluated with a single click replicate, and
    // numClickReplicates is not used here. Confirm this is intentional.
    Map<String, Double> loglik = testloglik(conf, z, testPath, 1, ignoreValue);

    AvroHdfsFileWriter<GenericRecord> writer =
        new AvroHdfsFileWriter<GenericRecord>(conf, outBasePath
            + "/sample-test-loglik/iteration-" + niter + ".avro", SampleTestLoglik.SCHEMA$);
    DataFileWriter<GenericRecord> testRecordWriter = writer.get();
    try
    {
      for (String k : z.keySet())
      {
        GenericData.Record valuemap = new GenericData.Record(SampleTestLoglik.SCHEMA$);
        valuemap.put("iter", niter);
        valuemap.put("testLoglik", loglik.get(k).floatValue());
        valuemap.put("lambda", k);
        testRecordWriter.append(valuemap);
        _logger.info("Sample test loglik for lambda=" + k + " is: "
            + String.valueOf(loglik.get(k)));

        // Output the best model so far; iteration 0 (initialization) is never recorded as best
        if (loglik.get(k) > bestTestLoglik.floatValue() && niter > 0)
        {
          String bestModelPath = outBasePath + "/best-model/best-iteration-" + niter + ".avro";
          FileSystem fs = FileSystem.get(conf);
          fs.delete(new Path(outBasePath + "/best-model"), true);
          LinearModelUtils.writeLinearModel(conf, bestModelPath, z.get(k), k);
          bestTestLoglik.setValue(loglik.get(k).floatValue());
        }
      }
    }
    finally
    {
      testRecordWriter.close();
    }
  }
}