/*
 * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *   http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package com.amazonaws.services.sagemaker.sparksdk.transformation

import java.io.{File, FileWriter}

import collection.JavaConverters._

import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.sql.SparkSession

import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.LibSVMResponseRowDeserializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.LibSVMRequestRowSerializer

/**
  * Functional tests that run the LibSVM request serializer and response deserializer against
  * a local Spark session, using Spark's own "libsvm" data source as the reference.
  *
  * Each test gets a freshly written temporary file holding [[libsvmdata]]; the file is removed
  * again once the test has finished.
  */
class LibSVMTransformationLocalFunctionalTests extends FlatSpec with Matchers with MockitoSugar
  with BeforeAndAfter {

  /** Local Spark session used by every test in this suite. */
  val spark: SparkSession = SparkSession.builder
    .master("local")
    .appName("spark session")
    .getOrCreate()

  /** Temporary file holding [[libsvmdata]]; (re)created by the `before` hook for each test. */
  var libsvmDataFile : File = _

  /** LibSVM-formatted fixture: four records, each a label followed by index:value pairs. */
  val libsvmdata =
    "1.0 1:1.5 2:3.0 28:-39.935 55:0.01\n" +
    "0.0 2:3.0 28:-39.935 55:0.01\n" +
    "-1.0 23:-39.935 55:0.01\n" +
    "3.0 1:1.5 2:3.0"

  before {
    libsvmDataFile = File.createTempFile("temp", "temp")
    val fw = new FileWriter(libsvmDataFile)
    // Close the writer even when the write fails so the file handle is never leaked.
    try {
      fw.write(libsvmdata)
    } finally {
      fw.close()
    }
  }

  after {
    // Remove the per-test temp file so repeated runs do not litter the temp directory.
    // The Option guard covers the case where `before` failed before assigning the file.
    Option(libsvmDataFile).foreach(_.delete())
  }

  "LibSVMSerialization" should "serialize Spark loaded libsvm file to same contents" in {
    import spark.implicits._

    val df = spark.read.format("libsvm").load(libsvmDataFile.getPath)
    val libsvmSerializer = new LibSVMRequestRowSerializer(Some(df.schema))

    // Serialize every row back to text and concatenate; the result should round-trip
    // to the fixture, ignoring leading/trailing whitespace.
    val result = df.map(row => new String(libsvmSerializer.serializeRow(row))).collect().mkString

    assert (libsvmdata.trim == result.trim)
  }

  "LibSVMDeserialization" should "deserialize serialized lib svm records" in {
    // NOTE(review): 55 is presumably the feature dimension (the largest index in the
    // fixture) - confirm against LibSVMResponseRowDeserializer.
    val libsvmDeserializer = new LibSVMResponseRowDeserializer (55)

    // Deserialize the same fixture that the `before` hook wrote to the temp file.
    val rowList = libsvmDeserializer.deserializeResponse(libsvmdata.getBytes).toBuffer.asJava
    val deserializedDataFrame = spark.createDataFrame(rowList, libsvmDeserializer.schema)

    // Reference rows: the same data as parsed by Spark's libsvm data source.
    val sparkProducedDataFrame = spark.read.format("libsvm").load(libsvmDataFile.getPath)

    val deserializedRows = deserializedDataFrame.collectAsList()
    val sparkRows = sparkProducedDataFrame.collectAsList()

    assert (deserializedRows == sparkRows)
  }
}