/* * Copyright (c) 2017, Salesforce.com, Inc. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * * Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * * Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/

package com.salesforce.op.stages.impl.feature

import com.salesforce.op.UID
import com.salesforce.op.features.FeatureSparkTypes
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.{UnaryEstimator, UnaryModel}
import com.salesforce.op.utils.spark.RichRow._
import org.apache.spark.ml.param.DoubleParam
import org.apache.spark.sql.Dataset

import scala.reflect.runtime.universe._

/**
 * Estimator that learns the mean of a numeric feature and uses it to fill in missing values.
 *
 * During fitting, every input value is turned into an optional double and averaged with Spark's
 * `mean` aggregate. When that aggregate yields no value (for example, when every input is
 * missing), the `defaultValue` param is used as the fill value instead.
 *
 * @param uid unique id of this stage
 * @tparam N raw numeric type carried by the input feature
 * @tparam I input feature type
 */
class FillMissingWithMean[N, I <: OPNumeric[N]]
(
  uid: String = UID[FillMissingWithMean[_, _]]
)(implicit tti: TypeTag[I], ttiv: TypeTag[I#Value])
  extends UnaryEstimator[I, RealNN](operationName = "fillWithMean", uid = uid) {

  /** Fallback fill value used when no mean can be computed; set to 0.0 right below. */
  val defaultValue: DoubleParam =
    new DoubleParam(this, "defaultValue", "default value to replace the missing ones")
  set(defaultValue, 0.0)

  /**
   * Overrides the fallback fill value.
   *
   * @param v value to use when the fitted mean is unavailable
   * @return this estimator, for chaining
   */
  def setDefaultValue(v: Double): this.type = set(defaultValue, v)

  // Encoder derived from the Real feature type; picked up implicitly by `dataset.map` in fitFn.
  // Keep it above fitFn: an implicit without an explicit result type is only eligible after
  // its own definition point.
  private implicit val realValueEncoder = FeatureSparkTypes.featureTypeEncoder[Real]

  /**
   * Computes the fill value over the training data and wraps it in a model.
   *
   * @param dataset raw input values
   * @return a [[FillMissingWithMeanModel]] that substitutes the computed value for missing inputs
   */
  def fitFn(dataset: Dataset[Option[N]]): UnaryModel[I, RealNN] = {
    val doubles = dataset.map(raw => iConvert.ftFactory.newInstance(raw).toDouble)
    // A key-less groupBy() aggregates over the whole Dataset, yielding a single summary row.
    val summaryRow = doubles.groupBy().mean().first()
    val fillValue = summaryRow.getOption[Double](0).getOrElse($(defaultValue))
    new FillMissingWithMeanModel[I](mean = fillValue, operationName = operationName, uid = uid)
  }
}

/**
 * Model produced by [[FillMissingWithMean]]. It emits the input value as a double when present
 * and the fitted `mean` otherwise, so every output carries a value.
 *
 * @param mean          fill value chosen during fitting (the mean, or the estimator's default)
 * @param operationName name of the operation
 * @param uid           unique id of this stage
 * @tparam I input feature type
 */
final class FillMissingWithMeanModel[I <: OPNumeric[_]] private[op]
(
  val mean: Double,
  operationName: String,
  uid: String
)(implicit tti: TypeTag[I])
  extends UnaryModel[I, RealNN](operationName = operationName, uid = uid) {

  def transformFn: I => RealNN = { feature =>
    val filled: Double = feature.toDouble.getOrElse(mean)
    filled.toRealNN
  }
}