/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapred;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.mapred.protocal.FairSchedulerProtocol;
import org.apache.hadoop.mapreduce.TaskType;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.util.StringUtils;


/**
 * Moves slots between two MapReduce clusters that run TaskTrackers on the
 * same set of machines.
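 *
 * <p>Minimal usage sketch (illustrative; assumes a configuration that
 * includes hour-glass.xml):
 * <pre>
 *   Configuration conf = new Configuration();
 *   HourGlass hourGlass = new HourGlass(conf);
 *   new Thread(hourGlass).start();
 * </pre>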
 */
public class HourGlass implements Runnable {

  static {
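    // Register hour-glass.xml as a default resource so that its properties
    // are loaded into every Configuration instance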
    Configuration.addDefaultResource("hour-glass.xml");
  }

  public final static String SERVERS_KEY = "mapred.hourglass.fairscheduler.servers";
  public final static String WEIGHTS_KEY = "mapred.hourglass.fairscheduler.weights";
  public final static String MAX_MAP_KEY = "mapred.hourglass.map.tasks.maximum";
  public final static String MAX_REDUCE_KEY = "mapred.hourglass.reduce.tasks.maximum";
  public final static String CPU_MAP_KEY = "mapred.hourglass.cpus.to.maptasks";
  public final static String CPU_REDUCE_KEY = "mapred.hourglass.cpus.to.reducetasks";
  public final static String INTERVAL_KEY = "mapred.hourglass.update.interval";
  public final static String SHARE_THRESHOLD_KEY = "mapred.hourglass.share.threshold";
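
  // Illustrative hour-glass.xml values for the keys above (hosts, ports and
  // numbers are hypothetical):
  //   mapred.hourglass.fairscheduler.servers = fs-host1:50080,fs-host2:50080
  //   mapred.hourglass.fairscheduler.weights = 2.0,1.0
  //   mapred.hourglass.cpus.to.maptasks      = 8:6,16:12
  //   mapred.hourglass.cpus.to.reducetasks   = 8:3,16:6
  //   mapred.hourglass.update.interval       = 10000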

  // If the share is lower than this threshold, the cluster gets 0 slots
  float shareThreshold = 0.01F;

  public static final Log LOG = LogFactory.getLog(HourGlass.class);
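  // Interval between slot rebalancing rounds, configurable via
  // mapred.hourglass.update.interval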
  long updateInterval = 10000L;
  volatile boolean running = true;
  Configuration conf;
  Cluster[] clusters = new Cluster[2];

  // Stores the initial maximum slot limit loaded from the conf
  int defaultMaxMapSlots;
  int defaultMaxReduceSlots;

  // Stores the initial mapping from #CPUs to the maximum slot limit,
  // loaded from the conf
  Map<Integer, Integer> defaultCpuToMaxMapSlots = null;
  Map<Integer, Integer> defaultCpuToMaxReduceSlots = null;

  final static TaskType[] MAP_AND_REDUCE =
      new TaskType[] {TaskType.MAP, TaskType.REDUCE};

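  /**
   * Creates an HourGlass that balances slots between the two clusters
   * configured under {@link #SERVERS_KEY} and {@link #WEIGHTS_KEY}. Both keys
   * hold comma-separated pairs, e.g. "fs-host1:50080,fs-host2:50080" and
   * "2.0,1.0" (values illustrative).
   *
   * @param conf configuration, typically including hour-glass.xml
   * @throws IOException if the addresses or weights are missing or invalid
   */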
  public HourGlass(Configuration conf) throws IOException {
    this.conf = conf;
    defaultMaxMapSlots = conf.getInt(MAX_MAP_KEY, Integer.MAX_VALUE);
    defaultMaxReduceSlots = conf.getInt(MAX_REDUCE_KEY, Integer.MAX_VALUE);
    defaultCpuToMaxMapSlots = loadCpuToMaxSlots(conf, TaskType.MAP);
    defaultCpuToMaxReduceSlots = loadCpuToMaxSlots(conf, TaskType.REDUCE);
    shareThreshold = conf.getFloat(SHARE_THRESHOLD_KEY, shareThreshold);
    try {
      String config;
      config = conf.get(SERVERS_KEY);
      String addresses[] = config.replaceAll("\\s", "").split(",");
      config = conf.get(WEIGHTS_KEY);
      double weights[] = new double[2];
      String str[] = config.replaceAll("\\s", "").split(",");
      weights[0] = Double.parseDouble(str[0]);
      weights[1] = Double.parseDouble(str[1]);
      if (weights[0] < 0 || weights[1] < 0 ||
          (weights[0] == 0 && weights[1] == 0)) {
        throw new IOException();
      }
      clusters[0] = new Cluster(addresses[0], weights[0], conf);
      clusters[1] = new Cluster(addresses[1], weights[1], conf);
    } catch (Exception e) {
      String msg = "Must assign exactly two server addresses and " +
          "the corresponding positive weights in hour-glass.xml";
      LOG.error(msg, e);
      throw new IOException(msg, e);
    }
    updateInterval = conf.getLong(INTERVAL_KEY, updateInterval);
  }

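  /**
   * Loads the mapping from the number of CPUs to the maximum slot limit for
   * the given task type ({@link #CPU_MAP_KEY} or {@link #CPU_REDUCE_KEY}).
   * The value is a comma-separated list of "cpus:slots" pairs,
   * e.g. "8:6,16:12" (numbers illustrative).
   *
   * @param conf configuration to read the mapping from
   * @param type task type whose mapping is loaded
   * @return mapping from #CPUs to maximum slots; empty if not configured
   */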
  public Map<Integer, Integer> loadCpuToMaxSlots(
      Configuration conf, TaskType type) {
    String config = type == TaskType.MAP ?
        conf.get(CPU_MAP_KEY) : conf.get(CPU_REDUCE_KEY);
    Map<Integer, Integer> defaultCpuToMaxSlots =
        new HashMap<Integer, Integer>();
    if (config != null) {
      for (String s : config.replaceAll("\\s", "").split(",")) {
        String pair[] = s.split(":");
        int cpus = Integer.parseInt(pair[0]);
        int tasks = Integer.parseInt(pair[1]);
        LOG.info(String.format(
            "CPU to max slots mapping: %d CPUs -> %d maximum %s tasks",
            cpus, tasks, type));
        defaultCpuToMaxSlots.put(cpus, tasks);
      }
    }
    return defaultCpuToMaxSlots;
  }

  /**
   * Holds the state of one MapReduce cluster
   */
  static class Cluster {
    FairSchedulerProtocol client;
    Map<String, TaskTrackerStatus> taskTrackers =
        new HashMap<String, TaskTrackerStatus>();
    String address;
    double weight;               // Higher weight will get more share
    int runnableMaps;            // Runnable maps on the cluster
    int runnableReduces;         // Runnable reduces on the cluster
    double targetMapShare;       // The share of maps to achieve
    double targetReduceShare;    // The share of reduces to achieve

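    /**
     * @param address address of the cluster's fair scheduler RPC server
     * @param weight  relative weight; a higher weight receives a larger share
     * @param conf    configuration used to create the RPC client
     */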
    Cluster(String address, double weight, Configuration conf)
        throws IOException {
      this.client = createClient(address, conf);
      this.weight = weight;
      this.address = address;
    }

    /**
     * Obtains the cluster information via RPC and refreshes the cached
     * TaskTracker statuses
     * @throws IOException
     */
    void updateClusterStatus() throws IOException {
      taskTrackers.clear();
      for (TaskTrackerStatus