/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.submarine.client.cli.param.runjob;

import com.google.common.collect.Lists;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.submarine.client.cli.CliConstants;
import org.apache.submarine.client.cli.CliUtils;
import org.apache.submarine.client.cli.runjob.RoleParameters;
import org.apache.submarine.commons.runtime.ClientContext;
import org.apache.submarine.commons.runtime.param.Parameter;
import org.apache.submarine.commons.runtime.api.TensorFlowRole;
import org.apache.submarine.commons.runtime.resource.ResourceUtils;

import java.io.IOException;
import java.util.List;

/**
 * Parameters for TensorFlow job.
 *
 * <p>On top of what {@link RunJobParameters} handles, this class keeps the
 * parameters of the parameter-server (PS) role and of the optional
 * TensorBoard role. Both are filled in from the command line by
 * {@link #updateParameters(Parameter, ClientContext)}.
 */
public class TensorFlowRunJobParameters extends RunJobParameters {
  private static final String CANNOT_BE_DEFINED_FOR_TF =
      "cannot be defined for TensorFlow jobs!";

  private boolean tensorboardEnabled;
  private RoleParameters psParameters =
      RoleParameters.createEmpty(TensorFlowRole.PS);
  private RoleParameters tensorBoardParameters =
      RoleParameters.createEmpty(TensorFlowRole.TENSORBOARD);

  /**
   * Validates the given options and populates worker, PS and (when
   * requested) TensorBoard parameters from them.
   *
   * @param parametersHolder source of the option values
   * @param clientContext client context handed to the superclass and to the
   *                      post operations
   * @throws ParseException if a scheduler-only option is defined, if the PS
   *                        count is not a non-negative integer, if PS
   *                        replicas are requested without a PS resource, or
   *                        if PS replicas are requested with fewer than two
   *                        workers
   * @throws IOException propagated from the superclass or post operations
   * @throws YarnException propagated from option lookup or the superclass
   */
  @Override
  public void updateParameters(Parameter parametersHolder, ClientContext clientContext)
      throws ParseException, IOException, YarnException {
    // Reject options that do not apply to TensorFlow before anything else.
    checkArguments(parametersHolder);
    super.updateParameters(parametersHolder, clientContext);

    String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
    this.workerParameters =
        generateWorkerParameters(clientContext, parametersHolder, input);
    this.psParameters = getPSParameters(clientContext, parametersHolder);
    this.distributed = determineIfDistributed(workerParameters.getReplicas(),
        psParameters.getReplicas());

    if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
      this.tensorboardEnabled = true;
      this.tensorBoardParameters =
          getTensorBoardParameters(parametersHolder, clientContext);
    }
    executePostOperations(clientContext);
  }

  /**
   * Applies the default directories and then substitutes patterns in the
   * PS and worker launch commands.
   *
   * @param clientContext client context used by both steps
   * @throws IOException propagated from either step
   */
  @Override
  void executePostOperations(ClientContext clientContext) throws IOException {
    // Set default job dir / saved model dir, etc.
    setDefaultDirs(clientContext);
    replacePatternsInParameters(clientContext);
  }

  /**
   * Fails if any scheduler-related option has a value, since those options
   * cannot be defined for TensorFlow jobs. Options are checked in a fixed
   * order and the first offending one is reported.
   *
   * @param parametersHolder source of the option values
   * @throws ParseException if a scheduler-related option has a value
   * @throws YarnException propagated from option lookup
   */
  private void checkArguments(Parameter parametersHolder)
      throws YarnException, ParseException {
    String[] schedulerOnlyOptions = {
        CliConstants.N_SCHEDULERS,
        CliConstants.SCHEDULER_RES,
        CliConstants.SCHEDULER_DOCKER_IMAGE,
        CliConstants.SCHEDULER_LAUNCH_CMD
    };
    for (String option : schedulerOnlyOptions) {
      if (parametersHolder.getOptionValue(option) != null) {
        throw new ParseException(getParamCannotBeDefinedErrorMessage(option));
      }
    }
  }

  /**
   * Builds the error message used when an unsupported option is defined.
   *
   * @param cliName name of the offending option
   * @return the formatted error message
   */
  private String getParamCannotBeDefinedErrorMessage(String cliName) {
    return String.format(
        "Parameter '%s' " + CANNOT_BE_DEFINED_FOR_TF, cliName);
  }

  /**
   * Runs pattern replacement on the PS and worker launch commands and stores
   * the results back. A launch command that is null or empty is left as is.
   *
   * @param clientContext provides the remote directory manager passed to the
   *                      replacement helper
   * @throws IOException propagated from the replacement helper
   */
  private void replacePatternsInParameters(ClientContext clientContext)
      throws IOException {
    if (StringUtils.isNotEmpty(getPSLaunchCmd())) {
      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
          getPSLaunchCmd(), this, clientContext.getRemoteDirectoryManager());
      setPSLaunchCmd(afterReplace);
    }

    if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
      String afterReplace =
          CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
              clientContext.getRemoteDirectoryManager());
      setWorkerLaunchCmd(afterReplace);
    }
  }

  /**
   * Returns the worker launch command followed by the PS launch command.
   * Each entry is exactly what the corresponding getter returns, so an entry
   * may be unset.
   *
   * @return a new two-element list of launch commands
   */
  @Override
  public List<String> getLaunchCommands() {
    return Lists.newArrayList(getWorkerLaunchCmd(), getPSLaunchCmd());
  }

  /**
   * Decides whether the job is distributed based on the number of workers
   * and parameter servers.
   *
   * @param nWorkers number of worker replicas
   * @param nPS number of PS replicas
   * @return {@code true} if there are at least two workers and at least one
   *         PS; {@code false} if there is no PS
   * @throws ParseException if a PS is requested with fewer than two workers
   */
  private boolean determineIfDistributed(int nWorkers, int nPS)
      throws ParseException {
    // Without any PS the job is never treated as distributed.
    if (nPS <= 0) {
      return false;
    }
    // When distributed training is required
    if (nWorkers >= 2) {
      return true;
    }
    throw new ParseException("Only specified one worker but non-zero PS, "
        + "please double check.");
  }

  /**
   * Reads the PS replica count, resource, docker image and launch command
   * from the options and wraps them in a {@link RoleParameters}.
   *
   * @param clientContext client context forwarded to the resource lookup
   * @param parametersHolder source of the option values
   * @return the PS role parameters
   * @throws YarnException propagated from option lookup
   * @throws IOException declared by the resource lookup
   * @throws ParseException if the PS count is invalid or the PS resource is
   *                        missing while PS replicas are requested
   */
  private RoleParameters getPSParameters(ClientContext clientContext,
      Parameter parametersHolder)
      throws YarnException, IOException, ParseException {
    int nPS = getNumberOfPS(parametersHolder);
    Resource psResource =
        determinePSResource(parametersHolder, nPS, clientContext);
    String psDockerImage =
        parametersHolder.getOptionValue(CliConstants.PS_DOCKER_IMAGE);
    String psLaunchCommand =
        parametersHolder.getOptionValue(CliConstants.PS_LAUNCH_CMD);
    return new RoleParameters(TensorFlowRole.PS, nPS, psLaunchCommand,
        psDockerImage, psResource);
  }

  /**
   * Resolves the resource of a PS replica. The resource option is only read,
   * and only required, when at least one PS replica is requested.
   *
   * @param parametersHolder source of the option values
   * @param nPS number of PS replicas
   * @param clientContext currently not used by this method
   * @return the parsed PS resource, or {@code null} if {@code nPS} is not
   *         positive
   * @throws ParseException if PS replicas are requested but the PS resource
   *                        option is absent
   * @throws YarnException propagated from option lookup
   * @throws IOException declared for callers; not raised by code visible here
   */
  private Resource determinePSResource(Parameter parametersHolder,
      int nPS, ClientContext clientContext)
      throws ParseException, YarnException, IOException {
    if (nPS <= 0) {
      return null;
    }
    String psResourceStr =
        parametersHolder.getOptionValue(CliConstants.PS_RES);
    if (psResourceStr == null) {
      throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
    }
    return ResourceUtils.createResourceFromString(psResourceStr);
  }

  /**
   * Reads the number of PS replicas from the options.
   *
   * @param parametersHolder source of the option values
   * @return the parsed count, or 0 if the option has no value
   * @throws YarnException propagated from option lookup
   * @throws ParseException if the value is not an integer or is negative
   */
  private int getNumberOfPS(Parameter parametersHolder)
      throws YarnException, ParseException {
    String nPSStr = parametersHolder.getOptionValue(CliConstants.N_PS);
    if (nPSStr == null) {
      return 0;
    }

    final int nPS;
    try {
      nPS = Integer.parseInt(nPSStr);
    } catch (NumberFormatException e) {
      // Report bad user input as a parse error, like every other invalid
      // argument handled by this class, instead of leaking an unchecked
      // exception. ParseException has no cause-taking constructor.
      ParseException parseException = new ParseException(
          "Invalid value for --" + CliConstants.N_PS + ": '" + nPSStr
              + "', an integer is expected.");
      parseException.initCause(e);
      throw parseException;
    }

    if (nPS < 0) {
      // A negative count would otherwise end up as the replica count of the
      // PS role without any complaint.
      throw new ParseException("--" + CliConstants.N_PS
          + " must not be negative, but was " + nPS + ".");
    }
    return nPS;
  }

  /**
   * Builds the TensorBoard role parameters from the options: one replica, no
   * launch command, the configured docker image and the configured resource.
   * The default TensorBoard resource is used when the resource option is
   * null or empty.
   *
   * @param parametersHolder source of the option values
   * @param clientContext currently not used by this method
   * @return the TensorBoard role parameters
   * @throws YarnException propagated from option lookup
   * @throws IOException declared for callers; not raised by code visible here
   */
  private RoleParameters getTensorBoardParameters(Parameter parametersHolder,
      ClientContext clientContext) throws YarnException, IOException {
    String tensorboardResourceStr =
        parametersHolder.getOptionValue(CliConstants.TENSORBOARD_RESOURCES);
    if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
      tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
    }
    Resource tensorboardResource = ResourceUtils.createResourceFromString(
            tensorboardResourceStr);
    String tensorboardDockerImage =
        parametersHolder.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE);
    return new RoleParameters(TensorFlowRole.TENSORBOARD, 1, null,
        tensorboardDockerImage, tensorboardResource);
  }

  /** @return the parameters of the PS role */
  public RoleParameters getPsParameters() {
    return psParameters;
  }

  /** @param parameters the new parameters of the PS role */
  public void setPsParameters(RoleParameters parameters) {
    this.psParameters = parameters;
  }

  /** @return the number of PS replicas */
  public int getNumPS() {
    return psParameters.getReplicas();
  }

  /** @param numPS the number of PS replicas */
  public void setNumPS(int numPS) {
    psParameters.setReplicas(numPS);
  }

  /** @return the resource of the PS role */
  public Resource getPsResource() {
    return psParameters.getResource();
  }

  /** @param resource the resource of the PS role */
  public void setPsResource(Resource resource) {
    psParameters.setResource(resource);
  }

  /** @return the docker image of the PS role */
  public String getPsDockerImage() {
    return psParameters.getDockerImage();
  }

  /** @param image the docker image of the PS role */
  public void setPsDockerImage(String image) {
    psParameters.setDockerImage(image);
  }

  /** @return the launch command of the PS role */
  public String getPSLaunchCmd() {
    return psParameters.getLaunchCommand();
  }

  /** @param launchCmd the launch command of the PS role */
  public void setPSLaunchCmd(String launchCmd) {
    psParameters.setLaunchCommand(launchCmd);
  }

  /** @return the parameters of the TensorBoard role */
  public RoleParameters getTensorBoardParameters() {
    return tensorBoardParameters;
  }

  /** @param tensorBoardParameters the new parameters of the TensorBoard role */
  public void setTensorBoardParameters(RoleParameters tensorBoardParameters) {
    this.tensorBoardParameters = tensorBoardParameters;
  }

  /** @return whether TensorBoard is enabled for this job */
  public boolean isTensorboardEnabled() {
    return tensorboardEnabled;
  }

  /** @param tensorboardEnabled whether TensorBoard is enabled for this job */
  public void setTensorboardEnabled(boolean tensorboardEnabled) {
    this.tensorboardEnabled = tensorboardEnabled;
  }

  /** @return the resource of the TensorBoard role */
  public Resource getTensorboardResource() {
    return tensorBoardParameters.getResource();
  }

  /** @param resource the resource of the TensorBoard role */
  public void setTensorboardResource(Resource resource) {
    tensorBoardParameters.setResource(resource);
  }

  /** @return the docker image of the TensorBoard role */
  public String getTensorboardDockerImage() {
    return tensorBoardParameters.getDockerImage();
  }

  /** @param image the docker image of the TensorBoard role */
  public void setTensorboardDockerImage(String image) {
    tensorBoardParameters.setDockerImage(image);
  }

}