package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.SchemaBuilder
import org.apache.spark.ml.feature.StandardScalerModel

// Serializable model data stored in the PFA cell: the per-feature mean and
// standard deviation extracted from the fitted Spark model.
@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class StandardScalerModelData(mean: Seq[Double], std: Seq[Double]) extends WithSchema {
  def schema = AvroSchema[this.type]
}

// Exports a fitted Spark StandardScalerModel as a PFA document. The generated
// action zips the input feature array with the stored mean/std arrays and
// applies the appropriate element-wise scaling function.
class PFAStandardScalerModel(override val sparkTransformer: StandardScalerModel)
  extends PFAModel[StandardScalerModelData] {

  import com.ibm.aardpfark.pfa.dsl._
  import com.ibm.aardpfark.pfa.dsl.core._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  // references to cell variables
  private val meanRef = modelCell.ref("mean")
  private val stdRef = modelCell.ref("std")

  // Input is a record with a single array-of-double field named after the
  // Spark input column; the output record mirrors it under the output column.
  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def cell = {
    val scalerData = StandardScalerModelData(sparkTransformer.mean.toArray,
      sparkTransformer.std.toArray)
    Cell(scalerData)
  }

  // Helper to define a named PFA function over doubles with the given
  // parameter names and body expression.
  def partFn(name: String, p: Seq[String], e: PFAExpression) = {
    NamedFunctionDef(name, FunctionDef[Double, Double](p, Seq(e)))
  }

  // Select the scaling function and the expression that applies it, based on
  // the model's withMean / withStd flags.
  val (scaleFnDef, scaleFnRef) = if (sparkTransformer.getWithMean) {
    if (sparkTransformer.getWithStd) {
      // center and scale: (i - m) / s
      val meanStdScale = partFn("meanStdScale", Seq("i", "m", "s"), div(minus("i", "m"), "s"))
      (Some(meanStdScale), a.zipmap(inputExpr, meanRef, stdRef, meanStdScale.ref))
    } else {
      // center only: i - m
      val meanScale = partFn("meanScale", Seq("i", "m"), minus("i", "m"))
      (Some(meanScale), a.zipmap(inputExpr, meanRef, meanScale.ref))
    }
  } else {
    if (sparkTransformer.getWithStd) {
      // scale only: i / s
      val stdScale = partFn("stdScale", Seq("i", "s"), div("i", "s"))
      (Some(stdScale), a.zipmap(inputExpr, stdRef, stdScale.ref))
    } else {
      // neither centering nor scaling: pass the input through unchanged
      (None, inputExpr)
    }
  }

  override def action: PFAExpression = {
    NewRecord(outputSchema, Map(outputCol -> scaleFnRef))
  }

  override def pfa: PFADocument = {
    val builder = PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withAction(action)
    // the scaling function is only emitted when one was defined above
    scaleFnDef.foreach(fnDef => builder.withFunction(fnDef))
    builder.pfa
  }
}
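
// ---------------------------------------------------------------------------
// Usage sketch (illustrative; not part of the exporter itself): fit a Spark
// StandardScalerModel on a toy DataFrame, then export it to PFA through the
// class above. The local SparkSession setup, column names, and the
// toJSON(pretty) serializer on PFADocument are assumptions for illustration;
// aardpfark's SparkSupport.toPFA offers a similar conversion entry point.
// ---------------------------------------------------------------------------
object PFAStandardScalerModelExample {

  import org.apache.spark.ml.feature.StandardScaler
  import org.apache.spark.ml.linalg.Vectors
  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("pfa-scaler-example")
      .getOrCreate()
    import spark.implicits._

    // Toy feature vectors; StandardScaler expects a Vector input column.
    val df = Seq(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
      .map(Tuple1.apply)
      .toDF("features")

    val model = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaled")
      .setWithMean(true)
      .setWithStd(true)
      .fit(df)

    // Wrap the fitted model and emit the PFA document as JSON.
    val pfaDoc = new PFAStandardScalerModel(model).pfa
    println(pfaDoc.toJSON(true))

    spark.stop()
  }
}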