package org.apache.spark.ml.bundle.ops.classification import ml.combust.bundle.BundleContext import ml.combust.bundle.op.OpModel import ml.combust.bundle.dsl._ import ml.combust.mleap.tensor.DenseTensor import org.apache.spark.ml.bundle.{ParamSpec, SimpleParamSpec, SimpleSparkOp, SparkBundleContext} import org.apache.spark.ml.classification.LogisticRegressionModel import org.apache.spark.ml.linalg.{Matrices, Vectors} /** * Created by hollinwilkins on 8/21/16. */ class LogisticRegressionOp extends SimpleSparkOp[LogisticRegressionModel] { private final val LOGISTIC_REGRESSION_DEFAULT_THRESHOLD = 0.5 override val Model: OpModel[SparkBundleContext, LogisticRegressionModel] = new OpModel[SparkBundleContext, LogisticRegressionModel] { override val klazz: Class[LogisticRegressionModel] = classOf[LogisticRegressionModel] override def opName: String = Bundle.BuiltinOps.classification.logistic_regression override def store(model: Model, obj: LogisticRegressionModel) (implicit context: BundleContext[SparkBundleContext]): Model = { val m = model.withValue("num_classes", Value.long(obj.numClasses)) if(obj.numClasses > 2) { val cm = obj.coefficientMatrix val thresholds = if(obj.isSet(obj.thresholds)) { Some(obj.getThresholds) } else None m.withValue("coefficient_matrix", Value.tensor[Double](DenseTensor(cm.toArray, Seq(cm.numRows, cm.numCols)))). withValue("intercept_vector", Value.vector(obj.interceptVector.toArray)). withValue("thresholds", thresholds.map(_.toSeq).map(Value.doubleList)) } else { m.withValue("coefficients", Value.vector(obj.coefficients.toArray)). withValue("intercept", Value.double(obj.intercept)). withValue("threshold", Value.double(obj.getThreshold)) } } override def load(model: Model) (implicit context: BundleContext[SparkBundleContext]): LogisticRegressionModel = { val numClasses = model.value("num_classes").getLong val r = if(numClasses > 2) { val cmTensor = model.value("coefficient_matrix").getTensor[Double] val coefficientMatrix = Matrices.dense(cmTensor.dimensions.head, cmTensor.dimensions(1), cmTensor.toArray) val lr = new LogisticRegressionModel(uid = "", coefficientMatrix = coefficientMatrix, interceptVector = Vectors.dense(model.value("intercept_vector").getTensor[Double].toArray), numClasses = numClasses.toInt, isMultinomial = true) model.getValue("thresholds"). map(t => lr.setThresholds(t.getDoubleList.toArray)). getOrElse(lr) } else { val lr = new LogisticRegressionModel(uid = "", coefficients = Vectors.dense(model.value("coefficients").getTensor[Double].toArray), intercept = model.value("intercept").getDouble) // default threshold is 0.5 for both Spark and Scikit-learn val threshold = model.getValue("threshold") .map(value => value.getDouble) .getOrElse(LOGISTIC_REGRESSION_DEFAULT_THRESHOLD) lr.setThreshold(threshold) } r } } override def sparkLoad(uid: String, shape: NodeShape, model: LogisticRegressionModel): LogisticRegressionModel = { val numClasses = model.numClasses val r = if (numClasses > 2) { val lr = new LogisticRegressionModel(uid = uid, coefficientMatrix = model.coefficientMatrix, interceptVector = model.interceptVector, numClasses = numClasses, isMultinomial = true) if(model.isDefined(model.thresholds)) { lr.setThresholds(model.getThresholds) } lr } else { val lr = new LogisticRegressionModel(uid = uid, coefficientMatrix = model.coefficientMatrix, interceptVector = model.interceptVector, numClasses = numClasses, isMultinomial = false) if(model.isDefined(model.threshold)) { lr.setThreshold(model.getThreshold) } lr } r } override def sparkInputs(obj: LogisticRegressionModel): Seq[ParamSpec] = { Seq("features" -> obj.featuresCol) } override def sparkOutputs(obj: LogisticRegressionModel): Seq[SimpleParamSpec] = { Seq("raw_prediction" -> obj.rawPredictionCol, "probability" -> obj.probabilityCol, "prediction" -> obj.predictionCol) } }