/*
 * Copyright 2015-2020 Snowflake Computing
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.snowflake.spark.snowflake

import java.lang.reflect.InvocationTargetException

import net.snowflake.spark.snowflake.testsuite.ClusterTestSuiteBase
import org.slf4j.{Logger, LoggerFactory}

import org.apache.spark.sql.SparkSession

/** Driver that runs connector test suites against a Spark session.
  *
  * Each requested suite class is loaded reflectively, run with a shared SparkSession, and its
  * outcome is recorded through a ClusterTestResultBuilder whose result is written to Snowflake.
  */
object ClusterTest {
  val log: Logger = LoggerFactory.getLogger(getClass)

  val RemoteMode = "remote"
  val LocalMode = "local"

  val TestSuiteSeparator = ";"

  /** Driver function to run the test suites.
    *
    * @param args args(0) is the run mode ("local" sets spark.master to local; any other value
    *             leaves spark.master unset here). args(1) is a list of fully-qualified
    *             ClusterTestSuiteBase class names separated by [[TestSuiteSeparator]].
    * @throws Exception if fewer than two arguments are supplied.
    */
  def main(args: Array[String]): Unit = {
    log.info(s"Test Spark Connector: ${net.snowflake.spark.snowflake.Utils.VERSION}")

    val usage = s"""Two parameters are needed: [local | remote] and
                   | testClassNames (using ';' to separate multiple classes)
                   |""".stripMargin
    log.info(usage)
    if (args.length < 2) {
      throw new Exception(s"At least two parameters are needed. Usage: $usage")
    }

    // Setup Spark session. Local mode is introduced for debugging purposes.
    val runMode = args(0)
    val baseBuilder = SparkSession
      .builder()
      .appName("Spark SQL basic example")
      .config("spark.some.config.option", "some-value")
    val sparkSessionBuilder =
      if (runMode.equalsIgnoreCase(LocalMode)) baseBuilder.config("spark.master", "local")
      else baseBuilder
    val spark = sparkSessionBuilder.getOrCreate()

    // The commit ID comes from the environment and is identical for every suite in this run.
    val commitID = scala.util.Properties
      .envOrElse(TestUtils.GITHUB_SHA, "commit id not set")

    // Trim each name: "a.Suite; b.Suite" must resolve "b.Suite", not " b.Suite".
    val testSuiteNames =
      args(1).split(TestSuiteSeparator).map(_.trim).filter(_.nonEmpty)

    try {
      testSuiteNames.foreach(name => runTestSuite(spark, commitID, name))
    } finally {
      // Release the session even if recording a result throws.
      spark.stop()
    }
  }

  /** Runs a single test suite and writes its result to Snowflake.
    *
    * Anything thrown while loading, instantiating, or running the suite is recorded with
    * TestUtils.TEST_RESULT_STATUS_EXCEPTION instead of propagating. The end time is always set
    * and the result is always written, even when the suite fails.
    *
    * @param spark         session passed to the suite's run method
    * @param commitID      commit identifier attached to the result
    * @param testSuiteName fully-qualified class name of a ClusterTestSuiteBase subclass
    */
  private def runTestSuite(spark: SparkSession, commitID: String, testSuiteName: String): Unit = {
    val resultBuilder = new ClusterTestResultBuilder()
      .withTestType("Scala")
      .withTestCaseName(testSuiteName)
      .withCommitID(commitID)
      .withTestStatus(TestUtils.TEST_RESULT_STATUS_INIT)
      .withStartTimeInMill(System.currentTimeMillis())
      .withGithubRunId(TestUtils.githubRunId)

    try {
      Class
        .forName(testSuiteName)
        .getDeclaredConstructor()
        .newInstance()
        .asInstanceOf[ClusterTestSuiteBase]
        .run(spark, resultBuilder)
    } catch {
      // Deliberately broad rather than NonFatal: a suite whose class initialization fails
      // (ExceptionInInitializerError is a LinkageError) must still be reported as failed.
      case t: Throwable =>
        val cause = unwrapReflection(t)
        log.error(s"Test suite $testSuiteName failed", cause)
        resultBuilder
          .withTestStatus(TestUtils.TEST_RESULT_STATUS_EXCEPTION)
          // getMessage may be null (e.g. a bare NullPointerException); fall back to toString.
          .withReason(Option(cause.getMessage).getOrElse(cause.toString))
    } finally {
      // Set test end time and write the test result.
      resultBuilder.withEndTimeInMill(System.currentTimeMillis())
      resultBuilder.build().writeToSnowflake()
    }
  }

  /** Constructor.newInstance wraps constructor failures in InvocationTargetException; this
    * returns the underlying cause so the logged and recorded reason is meaningful.
    */
  private def unwrapReflection(t: Throwable): Throwable = t match {
    case ite: InvocationTargetException if ite.getCause != null => ite.getCause
    case other => other
  }
}