org.apache.spark.mllib.tree.impurity.ImpurityCalculator Scala Examples

The following examples show how to use org.apache.spark.mllib.tree.impurity.ImpurityCalculator. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source file: SparkNodeWrapper.scala, from the mleap project (Apache License 2.0; 5 votes)
package org.apache.spark.ml.bundle.tree.decision

import ml.bundle.dtree.Node
import ml.bundle.dtree.Node.{InternalNode, LeafNode}
import ml.bundle.dtree.Split
import ml.bundle.dtree.Split.{CategoricalSplit, ContinuousSplit}
import ml.combust.bundle.tree.decision.NodeWrapper
import org.apache.spark.ml.tree
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator


/** Converts Spark ML decision-tree nodes ([[tree.Node]]) to and from the MLeap bundle
  * representation ([[Node]] / [[Split]]).
  */
object SparkNodeWrapper extends NodeWrapper[tree.Node] {

  /** Converts a single Spark tree node into a bundle node.
    *
    * Internal nodes are written with their split only; children are not visited here.
    * Leaf nodes are written with the leaf's impurity statistics when `withImpurities`
    * is true, otherwise with the single prediction value.
    *
    * @param node           Spark node to convert
    * @param withImpurities whether leaf impurity statistics should be serialized
    * @return the equivalent bundle node
    */
  override def node(node: tree.Node, withImpurities: Boolean): Node = node match {
    case node: tree.InternalNode =>
      val split = node.split match {
        case split: tree.CategoricalSplit =>
          // Only the smaller category set is stored; `isLeft` records which side it is,
          // and `internal` rebuilds the other side from `numCategories`.
          val left = split.leftCategories
          val right = split.rightCategories
          val (isLeft, categories) = if (left.length < right.length) {
            (true, left)
          } else {
            (false, right)
          }
          Split(Split.S.Categorical(CategoricalSplit(featureIndex = split.featureIndex,
            isLeft = isLeft,
            numCategories = split.numCategories,
            categories = categories)))
        case split: tree.ContinuousSplit =>
          Split(Split.S.Continuous(ContinuousSplit(featureIndex = split.featureIndex,
            threshold = split.threshold)))
      }
      Node(Node.N.Internal(Node.InternalNode(Some(split))))
    case node: tree.LeafNode =>
      val values = if (withImpurities) {
        node.impurityStats.stats.toSeq
      } else {
        Seq(node.prediction)
      }
      Node(Node.N.Leaf(Node.LeafNode(values)))
  }

  /** @return true when `node` is a Spark internal (split) node. */
  override def isInternal(node: tree.Node): Boolean = node.isInstanceOf[tree.InternalNode]

  /** Rebuilds a Spark leaf node from a bundle leaf.
    *
    * When `withImpurities` is true, `node.values` are the impurity statistics written by
    * `node`. They are wrapped in a "gini" calculator, and the prediction is that
    * calculator's predicted label (the index of the largest statistic), not the largest
    * raw value. Otherwise `node.values` holds the single prediction value.
    *
    * @param node           bundle leaf to convert
    * @param withImpurities whether `node.values` contains impurity statistics
    * @return the equivalent Spark leaf node (impurity is not restored and is set to 0.0)
    * @throws IllegalArgumentException if the leaf has no values
    */
  override def leaf(node: LeafNode, withImpurities: Boolean): tree.Node = {
    require(node.values.nonEmpty, "leaf node has no values")
    val (prediction, calc): (Double, ImpurityCalculator) = if (withImpurities) {
      // NOTE(review): always builds a "gini" (classification) calculator; leaves whose
      // stats come from a regression tree would need a different calculator — confirm.
      val impurityCalc = ImpurityCalculator.getCalculator("gini", node.values.toArray)
      (impurityCalc.predict, impurityCalc)
    } else {
      // No statistics were serialized, so the leaf keeps no impurity calculator.
      (node.values.head, null)
    }
    new tree.LeafNode(prediction = prediction,
      impurity = 0.0,
      impurityStats = calc)
  }

  /** Rebuilds a Spark internal node from a bundle internal node and its converted children.
    *
    * The prediction, gain and impurity of the rebuilt node are set to 0.0, and no impurity
    * statistics are attached.
    *
    * @param node  bundle internal node holding the split
    * @param left  already-converted left child
    * @param right already-converted right child
    * @return the equivalent Spark internal node
    * @throws IllegalArgumentException if the split is missing or is neither categorical
    *                                  nor continuous
    */
  override def internal(node: InternalNode,
                        left: tree.Node,
                        right: tree.Node): tree.Node = {
    val bundleSplit = node.split.getOrElse(
      throw new IllegalArgumentException("internal node is missing a split"))
    val split: tree.Split = bundleSplit.s match {
      case Split.S.Categorical(s) =>
        // Spark needs the left categories; if the right side was stored instead,
        // take the complement over 0 until numCategories.
        val leftCategories = if (s.isLeft) {
          s.categories.toArray
        } else {
          ((0 until s.numCategories).map(_.toDouble).toSet -- s.categories).toArray
        }
        new tree.CategoricalSplit(featureIndex = s.featureIndex,
          numCategories = s.numCategories,
          _leftCategories = leftCategories)
      case Split.S.Continuous(s) =>
        new tree.ContinuousSplit(featureIndex = s.featureIndex,
          threshold = s.threshold)
      case _ => throw new IllegalArgumentException("invalid split")
    }

    new tree.InternalNode(split = split,
      leftChild = left,
      rightChild = right,
      prediction = 0.0,
      gain = 0.0,
      impurity = 0.0,
      impurityStats = null)
  }

  /** @return the left child of an internal node.
    * @throws IllegalArgumentException if `node` is not an internal node
    */
  override def left(node: tree.Node): tree.Node = node match {
    case node: tree.InternalNode => node.leftChild
    case _ => throw new IllegalArgumentException("not an internal node")
  }

  /** @return the right child of an internal node.
    * @throws IllegalArgumentException if `node` is not an internal node
    */
  override def right(node: tree.Node): tree.Node = node match {
    case node: tree.InternalNode => node.rightChild
    case _ => throw new IllegalArgumentException("not an internal node")
  }
}