/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.submarine.client.cli.runjob.tensorflow;

import org.apache.commons.cli.ParseException;
import org.apache.submarine.client.cli.CliConstants;
import org.apache.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.submarine.client.cli.runjob.RunJobCli;
import org.apache.submarine.commons.runtime.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.apache.submarine.client.cli.runjob.RunJobCliParsingCommonTest;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

/**
 * Verifies how {@link RunJobCli} parses command lines that select the
 * TensorFlow framework.
 *
 * <p>Every test creates a fresh {@link RunJobCli} from the mock client
 * context supplied by {@link RunJobCliParsingCommonTest} and feeds it a
 * literal argument array through {@code run}.
 */
public class RunJobCliParsingTensorFlowTest {

  /** Message fragment expected when a scheduler option is given to a TensorFlow job. */
  private static final String SCHEDULER_OPTION_REJECTED =
      "cannot be defined for TensorFlow jobs";

  /**
   * Command line of a single-worker TensorFlow job. It is the exact argument
   * list used by {@link #testBasicRunJobForSingleNodeTraining()} and serves as
   * the common prefix of the scheduler-option tests. Never handed to the CLI
   * directly; callers pass a copy so the shared array cannot be modified.
   */
  private static final String[] SINGLE_WORKER_ARGS = {
      "--framework", "tensorflow",
      "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
      "--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
      "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
      "--worker_resources", "memory=4g,vcores=2", "--tensorboard", "true",
      "--verbose", "--wait_job_finish"};

  @Rule
  public ExpectedException expectedException = ExpectedException.none();

  /**
   * Switches verbose logging off before each test. Several tests pass
   * {@code --verbose} and assert on {@link SubmarineLogs#isVerbose()}, so the
   * flag has to start from a known state.
   */
  @Before
  public void before() {
    SubmarineLogs.verboseOff();
  }

  /**
   * Omitting the input path for a job that runs a worker must be rejected with
   * a {@link ParseException} whose message names the missing option.
   */
  @Test
  public void testNoInputPathOptionSpecified() throws Exception {
    RunJobCli runJobCli = newRunJobCli();
    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH +
        "\" is absent";

    ParseException thrown = null;
    try {
      runJobCli.run(
          new String[]{"--framework", "tensorflow",
              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
              "--checkpoint_path", "hdfs://output",
              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
              "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
              "true", "--verbose", "--wait_job_finish"});
    } catch (ParseException e) {
      // Expected: keep the exception so its message can be compared exactly.
      thrown = e;
    }

    assertTrue("Expected a ParseException for the missing input path option",
        thrown != null);
    assertEquals(expectedErrorMessage, thrown.getMessage());
  }

  /**
   * A distributed job (workers plus parameter servers, with keytab options)
   * must expose every parsed value through the TensorFlow parameters object.
   */
  @Test
  public void testBasicRunJobForDistributedTraining() throws Exception {
    RunJobCli runJobCli = newRunJobCli();

    assertFalse(SubmarineLogs.isVerbose());

    // NOTE(review): this literal looks like residue of an e-mail obfuscation
    // filter rather than a real Kerberos principal - confirm the intended
    // value. The test only needs the input and the expectation to match.
    String principal = "user/[email protected]";

    runJobCli.run(
        new String[] { "--framework", "tensorflow",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path", "hdfs://output",
            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
            "--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
            "--principal", principal, "--distribute_keytab",
            "--verbose" });

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    TensorFlowRunJobParameters tensorFlowParams =
        asTensorFlowParameters(jobRunParameters);

    assertEquals("hdfs://input", jobRunParameters.getInputPath());
    assertEquals("hdfs://output", jobRunParameters.getCheckpointPath());
    assertEquals(2, tensorFlowParams.getNumPS());
    assertEquals("python run-ps.py", tensorFlowParams.getPSLaunchCmd());
    assertEquals(Resources.createResource(4096, 4),
        tensorFlowParams.getPsResource());
    assertEquals("python run-job.py", tensorFlowParams.getWorkerLaunchCmd());
    assertEquals(Resources.createResource(2048, 2),
        tensorFlowParams.getWorkerResource());
    assertEquals("tf-docker:1.1.0", jobRunParameters.getDockerImageName());
    assertEquals("/keytab/path", jobRunParameters.getKeytab());
    assertEquals(principal, jobRunParameters.getPrincipal());
    assertTrue(jobRunParameters.isDistributeKeytab());
    assertTrue(SubmarineLogs.isVerbose());
  }

  /**
   * A single-worker job must parse successfully and expose the worker
   * settings, the verbose flag and the wait-for-finish flag.
   */
  @Test
  public void testBasicRunJobForSingleNodeTraining() throws Exception {
    RunJobCli runJobCli = newRunJobCli();
    assertFalse(SubmarineLogs.isVerbose());

    runJobCli.run(SINGLE_WORKER_ARGS.clone());

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    TensorFlowRunJobParameters tensorFlowParams =
        asTensorFlowParameters(jobRunParameters);

    assertEquals("hdfs://input", jobRunParameters.getInputPath());
    assertEquals("hdfs://output", jobRunParameters.getCheckpointPath());
    assertEquals(1, tensorFlowParams.getNumWorkers());
    assertEquals("python run-job.py", tensorFlowParams.getWorkerLaunchCmd());
    assertEquals(Resources.createResource(4096, 2),
        tensorFlowParams.getWorkerResource());
    assertTrue(SubmarineLogs.isVerbose());
    assertTrue(jobRunParameters.isWaitJobFinish());
  }

  /**
   * When only TensorBoard is run (zero workers), the input path is not
   * needed, so parsing must complete without a {@link ParseException}.
   * Any exception propagates and fails the test with its full stack trace.
   */
  @Test
  public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
    RunJobCli runJobCli = newRunJobCli();

    runJobCli.run(
        new String[]{"--framework", "tensorflow",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--num_workers", "0", "--tensorboard", "--verbose",
            "--tensorboard_resources", "memory=2G,vcores=2",
            "--tensorboard_docker_image", "tb_docker_image:001"});
  }

  @Test
  public void testNumSchedulerCannotBeDefined() throws Exception {
    assertSchedulerOptionRejected("--num_schedulers", "1");
  }

  @Test
  public void testSchedulerResourcesCannotBeDefined() throws Exception {
    assertSchedulerOptionRejected(
        "--scheduler_resources", "memory=2048M,vcores=2");
  }

  @Test
  public void testSchedulerDockerImageCannotBeDefined() throws Exception {
    assertSchedulerOptionRejected(
        "--scheduler_docker_image", "schedulerDockerImage");
  }

  @Test
  public void testSchedulerLaunchCommandCannotBeDefined() throws Exception {
    assertSchedulerOptionRejected(
        "--scheduler_launch_cmd", "schedulerLaunchCommand");
  }

  /** Creates a CLI instance backed by the shared mock client context. */
  private static RunJobCli newRunJobCli() throws Exception {
    return new RunJobCli(RunJobCliParsingCommonTest.getMockClientContext());
  }

  /**
   * Asserts that the parsed parameters are TensorFlow parameters and returns
   * them with that type.
   */
  private static TensorFlowRunJobParameters asTensorFlowParameters(
      RunJobParameters parameters) {
    assertTrue(RunJobParameters.class +
            " must be an instance of " +
            TensorFlowRunJobParameters.class,
        parameters instanceof TensorFlowRunJobParameters);
    return (TensorFlowRunJobParameters) parameters;
  }

  /**
   * Appends {@code schedulerArgs} to the single-worker command line and
   * expects parsing to fail with a {@link ParseException} stating that the
   * option cannot be defined for TensorFlow jobs.
   *
   * @param schedulerArgs the scheduler option and its value, in command-line
   *                      order
   */
  private void assertSchedulerOptionRejected(String... schedulerArgs)
      throws Exception {
    RunJobCli runJobCli = newRunJobCli();
    assertFalse(SubmarineLogs.isVerbose());

    int prefixLength = SINGLE_WORKER_ARGS.length;
    String[] args = new String[prefixLength + schedulerArgs.length];
    System.arraycopy(SINGLE_WORKER_ARGS, 0, args, 0, prefixLength);
    System.arraycopy(schedulerArgs, 0, args, prefixLength,
        schedulerArgs.length);

    // The expectation is registered last so that only run() may satisfy it.
    expectedException.expect(ParseException.class);
    expectedException.expectMessage(SCHEDULER_OPTION_REJECTED);
    runJobCli.run(args);
  }
}