/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.hdl.tensorflow.yarn.client;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.ClassUtil;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.YarnApplicationState;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.apache.hadoop.yarn.client.api.YarnClientApplication;
import org.hdl.tensorflow.yarn.appmaster.ApplicationMaster;
import org.hdl.tensorflow.yarn.appmaster.ClusterSpec;
import org.hdl.tensorflow.yarn.util.Constants;
import org.hdl.tensorflow.yarn.util.Utils;

import java.io.File;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Launches a TensorFlow cluster on YARN and waits until its {@link ClusterSpec}
 * lists the requested number of worker and parameter-server (ps) tasks.
 *
 * <p>The command reads its settings from the parsed command line, stages the
 * application jar and the TensorFlow native library on the configured
 * {@link FileSystem}, submits an {@link ApplicationMaster} to YARN, and then polls
 * the application until the cluster is ready or the application terminates.
 */
public class LaunchCluster implements Client.Command {
  private static final Log LOG = LogFactory.getLog(LaunchCluster.class);

  /** Delay between two consecutive polls of the application state, in milliseconds. */
  private static final long POLL_INTERVAL_MS = 1000L;

  private final Configuration conf;
  private final YarnClient yarnClient;
  private final String appName;
  private final Integer amMemory;
  private final Integer amVCores;
  private final String amQueue;
  private final Integer containerMemory;
  private final Integer containerVCores;
  private final String tfLib;
  private final String tfJar;
  private final Integer workerNum;
  private final Integer psNum;

  /**
   * Builds the command from the parsed command line, falling back to the defaults
   * declared in {@link Constants} for every option that was not supplied.
   *
   * <p>When no jar option is given, the jar containing this class is used. When no
   * library option is given, the library is looked up next to that jar under the
   * name {@link Constants#TF_LIB_NAME}.
   *
   * @param conf       Hadoop configuration used to obtain the target file system
   * @param yarnClient client used to create, submit and monitor the application
   * @param cliParser  parsed command-line options
   * @throws NumberFormatException    if a numeric option cannot be parsed as an integer
   * @throws IllegalArgumentException if the worker count is not positive or the ps
   *                                  count is negative
   */
  public LaunchCluster(Configuration conf, YarnClient yarnClient, CommandLine cliParser) {
    this.conf = conf;
    this.yarnClient = yarnClient;
    appName = cliParser.getOptionValue(
        Constants.OPT_TF_APP_NAME, Constants.DEFAULT_APP_NAME);

    amMemory = Integer.parseInt(cliParser.getOptionValue(
        Constants.OPT_TF_APP_MASTER_MEMORY, Constants.DEFAULT_APP_MASTER_MEMORY));
    amVCores = Integer.parseInt(cliParser.getOptionValue(
        Constants.OPT_TF_APP_MASTER_VCORES, Constants.DEFAULT_APP_MASTER_VCORES));
    amQueue = cliParser.getOptionValue(
        Constants.OPT_TF_APP_MASTER_QUEUE, Constants.DEFAULT_APP_MASTER_QUEUE);
    containerMemory = Integer.parseInt(cliParser.getOptionValue(
        Constants.OPT_TF_CONTAINER_MEMORY, Constants.DEFAULT_CONTAINER_MEMORY));
    containerVCores = Integer.parseInt(cliParser.getOptionValue(
        Constants.OPT_TF_CONTAINER_VCORES, Constants.DEFAULT_CONTAINER_VCORES));

    if (cliParser.hasOption(Constants.OPT_TF_JAR)) {
      tfJar = cliParser.getOptionValue(Constants.OPT_TF_JAR);
    } else {
      // Default to the jar this class was loaded from.
      tfJar = ClassUtil.findContainingJar(getClass());
    }

    if (cliParser.hasOption(Constants.OPT_TF_LIB)) {
      tfLib = cliParser.getOptionValue(Constants.OPT_TF_LIB);
    } else {
      // Default to the native library sitting in the same directory as the jar.
      tfLib = Utils.getParentDir(tfJar) + File.separator + Constants.TF_LIB_NAME;
    }

    workerNum = Integer.parseInt(
        cliParser.getOptionValue(Constants.OPT_TF_WORKER_NUM, Constants.DEFAULT_TF_WORKER_NUM));

    if (workerNum <= 0) {
      throw new IllegalArgumentException(
          "Illegal number of TensorFlow worker task specified: " + workerNum);
    }

    psNum = Integer.parseInt(
        cliParser.getOptionValue(Constants.OPT_TF_PS_NUM, Constants.DEFAULT_TF_PS_NUM));

    if (psNum < 0) {
      throw new IllegalArgumentException(
          "Illegal number of TensorFlow ps task specified: " + psNum);
    }
  }

  /**
   * Stages the application files, submits the ApplicationMaster and waits for the
   * cluster to come up.
   *
   * @return {@code true} once the cluster spec lists the requested number of worker
   *     and ps tasks; {@code false} if the application reaches a terminal state first
   * @throws Exception if staging, submission or monitoring fails
   */
  public boolean run() throws Exception {
    YarnClientApplication app = createApplication();
    ApplicationId appId = app.getNewApplicationResponse().getApplicationId();

    // Copy the application jar and the TensorFlow library to the filesystem.
    FileSystem fs = FileSystem.get(conf);
    String appIdStr = appId.toString();
    Path dstJarPath = Utils.copyLocalFileToDfs(fs, appIdStr, new Path(tfJar), Constants.TF_JAR_NAME);
    Path dstLibPath = Utils.copyLocalFileToDfs(fs, appIdStr, new Path(tfLib),
        Constants.TF_LIB_NAME);

    // Only the jar is registered as a local resource of the AM container; the
    // library location is handed to the ApplicationMaster on its command line.
    Map<String, Path> files = new HashMap<>();
    files.put(Constants.TF_JAR_NAME, dstJarPath);
    Map<String, LocalResource> localResources = Utils.makeLocalResources(fs, files);
    Map<String, String> javaEnv = Utils.setJavaEnv(conf);

    String command = makeAppMasterCommand(dstLibPath.toString(), dstJarPath.toString());
    LOG.info("Make ApplicationMaster command: " + command);
    ContainerLaunchContext launchContext = ContainerLaunchContext.newInstance(
        localResources, javaEnv, Lists.newArrayList(command), null, null, null);
    Resource resource = Resource.newInstance(amMemory, amVCores);
    submitApplication(app, appName, launchContext, resource, amQueue);
    return awaitApplication(appId);
  }

  /**
   * Asks YARN for a new application.
   *
   * @return the newly created client-side application handle
   * @throws Exception if the YARN client fails to create the application
   */
  YarnClientApplication createApplication() throws Exception {
    return yarnClient.createApplication();
  }

  /**
   * Fills in the submission context of {@code app} and submits it to YARN.
   *
   * @param app           application handle obtained from {@link #createApplication()}
   * @param appName       name to register the application under
   * @param launchContext launch context of the ApplicationMaster container
   * @param resource      resources requested for the ApplicationMaster container
   * @param queue         YARN queue to submit to
   * @return the id returned by the YARN client for the submitted application
   * @throws Exception if the submission fails
   */
  ApplicationId submitApplication(
      YarnClientApplication app,
      String appName,
      ContainerLaunchContext launchContext,
      Resource resource,
      String queue) throws Exception {
    ApplicationSubmissionContext appContext = app.getApplicationSubmissionContext();
    appContext.setApplicationName(appName);
    appContext.setApplicationTags(new HashSet<>());
    appContext.setAMContainerSpec(launchContext);
    appContext.setResource(resource);
    appContext.setQueue(queue);

    return yarnClient.submitApplication(appContext);
  }

  /**
   * Polls the application until the cluster is ready or the application terminates.
   *
   * <p>The calling thread sleeps {@link #POLL_INTERVAL_MS} milliseconds after every
   * unsuccessful poll. That includes the case where the application is already
   * {@code RUNNING} but not all tasks are listed in the cluster spec yet, so the
   * loop never spins without a pause.
   *
   * <p>On success the cluster spec is printed to standard output as JSON.
   *
   * @param appId id of the submitted application
   * @return {@code true} if the cluster spec became complete while the application
   *     was running; {@code false} if the application ended up {@code FAILED},
   *     {@code FINISHED} or {@code KILLED}
   * @throws Exception if querying YARN or the cluster spec fails, or if the thread
   *                   is interrupted while sleeping
   */
  boolean awaitApplication(ApplicationId appId) throws Exception {
    Set<YarnApplicationState> terminated = Sets.newHashSet(
        YarnApplicationState.FAILED,
        YarnApplicationState.FINISHED,
        YarnApplicationState.KILLED);
    while (true) {
      ApplicationReport report = yarnClient.getApplicationReport(appId);
      YarnApplicationState state = report.getYarnApplicationState();

      if (terminated.contains(state)) {
        return false;
      }

      if (YarnApplicationState.RUNNING.equals(state)) {
        ClusterSpec clusterSpec = Client.getClusterSpec(yarnClient, appId);
        if (isClusterSpecSatisfied(clusterSpec)) {
          System.out.println("ClusterSpec: " + Utils.toJsonString(clusterSpec.getCluster()));
          return true;
        }
      }

      // Not ready yet (still starting, or running with tasks missing): back off
      // before asking again instead of polling in a tight loop.
      Thread.sleep(POLL_INTERVAL_MS);
    }
  }

  /**
   * Assembles the shell command line that starts the {@link ApplicationMaster}.
   *
   * <p>The JVM heap limit is set to the configured AM memory, and stdout/stderr are
   * redirected into the container log directory.
   *
   * @param dfsLibPath location of the staged TensorFlow library
   * @param dfsJarPath location of the staged application jar
   * @return the command tokens joined with single spaces
   */
  private String makeAppMasterCommand(String dfsLibPath, String dfsJarPath) {
    String[] commands = new String[]{
        ApplicationConstants.Environment.JAVA_HOME.$$() + "/bin/java",
        // Set Xmx based on am memory size
        "-Xmx" + amMemory + "m",
        // Set class name
        ApplicationMaster.class.getName(),
        Utils.mkOption(Constants.OPT_TF_CONTAINER_MEMORY, containerMemory),
        Utils.mkOption(Constants.OPT_TF_CONTAINER_VCORES, containerVCores),
        Utils.mkOption(Constants.OPT_TF_WORKER_NUM, workerNum),
        Utils.mkOption(Constants.OPT_TF_PS_NUM, psNum),
        Utils.mkOption(Constants.OPT_TF_LIB, dfsLibPath),
        Utils.mkOption(Constants.OPT_TF_JAR, dfsJarPath),
        "1>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/AppMaster.stdout",
        "2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/AppMaster.stderr"
    };
    return Utils.mkString(commands, " ");
  }

  /**
   * Tells whether the cluster spec lists exactly the requested number of tasks.
   *
   * @param clusterSpec spec reported for the application; a {@code null} spec is
   *                    treated as "not ready yet"
   * @return {@code true} if both the worker and the ps lists are present and their
   *     sizes equal the requested worker and ps counts
   */
  private boolean isClusterSpecSatisfied(ClusterSpec clusterSpec) {
    if (clusterSpec == null) {
      return false;
    }
    List<String> worker = clusterSpec.getWorker();
    List<String> ps = clusterSpec.getPs();

    return worker != null && worker.size() == workerNum &&
        ps != null && ps.size() == psNum;
  }
}