
Decision Trees in Apache Spark


Learn about decision trees, which are models used to map decisions and their possible consequences, and how they can be used in Apache Spark.


A decision tree is a powerful method for classification, prediction, and facilitating decision-making in sequential decision problems. A decision tree is made up of two components:

  1. Decision
  2. Outcome

A decision tree also includes three types of nodes:

  1. Root node: The top node of the tree comprising all the data.
  2. Splitting node: A node that assigns data to a subgroup.
  3. Terminal node: Final decision (outcome).

To reach an outcome, the process starts at the root node. Based on the decision made at the root node, a splitting node is selected. Based on the decision made at that splitting node, one of its child splitting nodes is selected. This continues until a terminal node is reached, and the value of that terminal node is the outcome.
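The walk from the root to a terminal node can be sketched in plain Scala. The `Node`, `Split`, and `Leaf` types and the weather example are hypothetical illustrations, not Spark's classes; the sketch assumes each splitting node compares one numeric feature against a threshold and sends the record left when the value is less than or equal to it.

```scala
import scala.annotation.tailrec

// Hypothetical tree model used only to illustrate root-to-leaf traversal.
sealed trait Node
// Terminal node: holds the final decision (outcome).
final case class Leaf(outcome: String) extends Node
// Splitting node: go left if features(featureIndex) <= threshold, otherwise go right.
final case class Split(featureIndex: Int, threshold: Double, left: Node, right: Node) extends Node

object TreeWalk {
  // Start at the root and follow one decision per level until a terminal node is reached.
  @tailrec
  def predict(node: Node, features: Array[Double]): String = node match {
    case Leaf(outcome) => outcome
    case Split(i, t, left, right) =>
      if (features(i) <= t) predict(left, features) else predict(right, features)
  }

  // Hypothetical example: feature 0 = temperature, feature 1 = raining (0.0 or 1.0).
  val weatherTree: Node =
    Split(0, 20.0,
      Leaf("stay in"),
      Split(1, 0.5, Leaf("go outside"), Leaf("stay in")))
}
```

For example, `TreeWalk.predict(TreeWalk.weatherTree, Array(25.0, 0.0))` passes the root's test to the right, then the rain test to the left, and returns `"go outside"`.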

Decision Trees in Apache Spark

It might sound strange, but Apache Spark does not have a separate implementation of the decision tree. What it does have is an implementation of the random forest algorithm, in which the user specifies the number of trees. Under the hood, Spark's decision tree simply calls that random forest implementation with a single tree.

In Apache Spark, the decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature space. The tree predicts the same label for each bottom-most (leaf) partition. Each partition is chosen greedily by selecting the best split from a set of possible splits in order to maximize the information gain at a tree node.
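The information gain being maximized can be written out explicitly. Spark's MLlib documentation defines it for a split $s$ that partitions a data set $D$ of size $N$ into $D_{\mathrm{left}}$ and $D_{\mathrm{right}}$ (of sizes $N_{\mathrm{left}}$ and $N_{\mathrm{right}}$) as the parent impurity minus the size-weighted impurities of the two children:

```latex
IG(D, s) = \mathrm{Impurity}(D)
         - \frac{N_{\mathrm{left}}}{N}\,\mathrm{Impurity}(D_{\mathrm{left}})
         - \frac{N_{\mathrm{right}}}{N}\,\mathrm{Impurity}(D_{\mathrm{right}})
```

At each node, the greedy step keeps the candidate split $s$ with the largest $IG(D, s)$.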

Node impurity is a measure of the homogeneity of the labels at the node. Spark provides two impurity measures for classification (Gini impurity and entropy) and one for regression (variance).
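To make the two classification impurities and the greedy split concrete, here is a small plain-Scala sketch (not Spark code). It computes Gini impurity as the sum of f_i * (1 - f_i) and entropy as the sum of -f_i * log2(f_i), where f_i is the frequency of class i at the node, then scans every distinct value of one continuous feature as a candidate threshold. The `SplitSearch` object and its methods are hypothetical; Spark itself evaluates only a limited set of candidate thresholds per feature (bounded by `maxBins`) rather than every distinct value.

```scala
object SplitSearch {
  // Gini impurity: sum over classes of f_i * (1 - f_i), with f_i the class frequency.
  def gini(counts: Seq[Long]): Double = {
    val total = counts.sum.toDouble
    if (total == 0) 0.0 else counts.map { c => val f = c / total; f * (1 - f) }.sum
  }

  // Entropy: sum over classes of -f_i * log2(f_i); empty classes contribute nothing.
  def entropy(counts: Seq[Long]): Double = {
    val total = counts.sum.toDouble
    if (total == 0) 0.0
    else counts.filter(_ > 0).map { c => val f = c / total; -f * math.log(f) / math.log(2) }.sum
  }

  // Per-class counts for integer labels 0 until numClasses.
  def classCounts(labels: Seq[Int], numClasses: Int): Seq[Long] =
    (0 until numClasses).map(k => labels.count(_ == k).toLong)

  // Information gain: parent impurity minus the size-weighted impurities of the children.
  def infoGain(parent: Seq[Int], left: Seq[Int], right: Seq[Int], numClasses: Int,
               impurity: Seq[Long] => Double): Double = {
    val n = parent.size.toDouble
    def imp(xs: Seq[Int]): Double = impurity(classCounts(xs, numClasses))
    imp(parent) - (left.size / n) * imp(left) - (right.size / n) * imp(right)
  }

  // Greedy step for one feature: try each distinct value t as "value <= t goes left"
  // and return the (threshold, gain) pair with the highest information gain.
  def bestThreshold(values: Seq[Double], labels: Seq[Int], numClasses: Int,
                    impurity: Seq[Long] => Double): Option[(Double, Double)] = {
    val data = values.zip(labels)
    // The largest value is skipped because it would leave the right child empty.
    val candidates = values.distinct.sorted.dropRight(1)
    val scored = candidates.map { t =>
      val (l, r) = data.partition(_._1 <= t)
      (t, infoGain(labels, l.map(_._2), r.map(_._2), numClasses, impurity))
    }
    if (scored.isEmpty) None else Some(scored.maxBy(_._2))
  }
}
```

With feature values `1.0, 2.0, 3.0, 4.0` and labels `0, 0, 1, 1`, the threshold `2.0` separates the classes perfectly, so it has the highest gain under either impurity measure.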


Stopping Rule

Recursive tree construction is stopped at a node when one of the following conditions is met:

  1. The node depth is equal to the maxDepth training parameter.
  2. No split candidate leads to an information gain greater than minInfoGain.
  3. No split candidate produces child nodes that each have at least minInstancesPerNode training instances.
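The three conditions can be combined into a single check. In this hypothetical sketch (`StopParams`, `Candidate`, and `Stopping.shouldStop` are illustrative names, not Spark's API), rules 2 and 3 are read together: a candidate split is usable only if it passes both the minInfoGain test and the minInstancesPerNode test, so recursion stops when the depth limit is hit or no usable candidate remains.

```scala
// Hypothetical container for the three stopping-rule parameters.
final case class StopParams(maxDepth: Int, minInfoGain: Double, minInstancesPerNode: Int)

// A candidate split summarized by its information gain and the sizes of its two children.
final case class Candidate(gain: Double, leftCount: Long, rightCount: Long)

object Stopping {
  def shouldStop(depth: Int, candidates: Seq[Candidate], p: StopParams): Boolean = {
    // Rule 1: the node is already at the maximum depth.
    val depthReached = depth >= p.maxDepth
    // Rules 2 and 3: keep candidates whose gain exceeds minInfoGain
    // and whose children both have at least minInstancesPerNode instances.
    val usable = candidates.filter { c =>
      c.gain > p.minInfoGain &&
        c.leftCount >= p.minInstancesPerNode &&
        c.rightCount >= p.minInstancesPerNode
    }
    depthReached || usable.isEmpty
  }
}
```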

Useful Parameters

  • algo: Either classification or regression.
  • numClasses: Number of classes (classification only).
  • maxDepth: Maximum depth of the tree. Depth 0 means a single leaf node; depth 1 means one internal node plus two leaf nodes.
  • minInstancesPerNode: For a node to be split further, each of its children must receive at least this number of training instances.
  • minInfoGain: For a node to be split further, the split must yield at least this much information gain.
  • maxBins: Number of bins used when discretizing continuous features.
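Because continuous features are discretized, maxBins bounds how many candidate thresholds are considered for each of them. The sketch below shows one simple way to derive at most maxBins - 1 thresholds, by taking roughly equal-frequency cut points from the sorted values. It is an illustrative approximation (the `Binning` object is hypothetical), and Spark's own split-finding differs in detail.

```scala
object Binning {
  // Choose at most (maxBins - 1) thresholds for one continuous feature so that the
  // sorted values fall into roughly equal-sized bins ("value <= t" goes left).
  def thresholds(values: Seq[Double], maxBins: Int): Seq[Double] = {
    require(maxBins >= 2, "need at least two bins to form a split")
    val sorted = values.sorted
    val n = sorted.size
    if (n == 0) Seq.empty
    else {
      // Cut point k sits at index k * n / maxBins, which is always < n for k < maxBins.
      val cuts = (1 until maxBins).map(k => sorted(((k.toLong * n) / maxBins).toInt))
      // Drop duplicates, and drop the maximum value since it would leave the right bin empty.
      cuts.distinct.filter(_ < sorted.last)
    }
  }
}
```

With the values 1.0 to 10.0 and `maxBins = 4`, the thresholds are 3.0, 6.0, and 8.0. When a feature has fewer distinct values than bins, every distinct value except the largest becomes a threshold.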

Preparing Training Data for Decision Trees

You cannot feed arbitrary data directly to the decision tree; it expects a specific format. In Spark MLlib, the training input is an RDD of LabeledPoint objects, each pairing a numeric label with a feature vector. For text data, you can use HashingTF to turn each document's terms into a term-frequency feature vector and then pair that vector with the document's label. In this article, this conversion step is referred to as standardizing the data.
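The hashing idea behind HashingTF can be sketched in plain Scala: each term is hashed to a column index in a fixed-size vector, and the vector counts how often each index occurs. `LabeledExample` is a hypothetical stand-in for MLlib's LabeledPoint, and the sketch uses `String.hashCode`, which is not necessarily the hash function Spark's HashingTF applies.

```scala
// Hypothetical stand-in for MLlib's LabeledPoint: a numeric label plus a feature vector.
final case class LabeledExample(label: Double, features: Array[Double])

object HashingSketch {
  // Map a term to a column index in [0, numFeatures) with a non-negative modulo of its hash.
  def indexOf(term: String, numFeatures: Int): Int = {
    val raw = term.hashCode % numFeatures
    if (raw < 0) raw + numFeatures else raw
  }

  // Term-frequency vector: count how many terms land on each hashed index.
  // Different terms may collide on the same index; that is inherent to the hashing trick.
  def termFrequencies(terms: Seq[String], numFeatures: Int): Array[Double] = {
    val tf = new Array[Double](numFeatures)
    terms.foreach(t => tf(indexOf(t, numFeatures)) += 1.0)
    tf
  }

  // Pair a document's label with its hashed term-frequency vector.
  def toLabeled(label: Double, terms: Seq[String], numFeatures: Int): LabeledExample =
    LabeledExample(label, termFrequencies(terms, numFeatures))
}
```

A larger `numFeatures` reduces collisions at the cost of longer (though sparse, in Spark) vectors.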

Feeding and Obtaining Results

Once the data has been standardized, you can feed it to the decision tree algorithm for classification. Before that, though, you need to split the data into training and test sets: to measure accuracy, you must hold out part of the data for testing. You can feed the data like this:

import org.apache.spark.mllib.tree.DecisionTree

// data: RDD[LabeledPoint], the standardized input prepared in the previous step.
// Hold out roughly 30% of it for testing.
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

Here, data is the standardized input, which is split in a 7:3 ratio into training and test sets. The model uses Gini impurity with a maximum depth of 5.
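Note that randomSplit gives only an approximate 7:3 split: it normalizes the weights and assigns each record to one of the splits at random, so the resulting sizes vary slightly from run to run. The hypothetical helper below mimics that behavior on an in-memory `Seq` with a seeded `Random`; it illustrates the idea rather than reproducing Spark's sampling code.

```scala
import scala.util.Random

object SplitSketch {
  // Hypothetical in-memory analogue of a weighted random split: weights are normalized,
  // and each element is sent to exactly one bucket by a uniform random draw.
  def randomSplit[A](data: Seq[A], weights: Seq[Double], seed: Long): Seq[Seq[A]] = {
    require(weights.nonEmpty && weights.forall(_ >= 0) && weights.sum > 0,
      "weights must be non-negative with a positive sum")
    val total = weights.sum
    // Cumulative upper bounds, e.g. (0.7, 0.3) -> (0.7, 1.0).
    val bounds = weights.scanLeft(0.0)(_ + _).tail.map(_ / total)
    val rng = new Random(seed)
    val tagged = data.map { x =>
      val u = rng.nextDouble()
      val i = bounds.indexWhere(u < _)
      // Guard against floating-point rounding at the top edge.
      (if (i == -1) bounds.size - 1 else i, x)
    }
    weights.indices.map(b => tagged.filter(_._1 == b).map(_._2))
  }
}
```

Every element ends up in exactly one split, which is what keeps the test data unseen during training.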

Once the model is trained, you can use it to predict the class of new data. Before that, though, you should validate the model's accuracy. You can do this by computing the test error: the fraction of test instances that the model misclassifies.

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
// Fraction of test points whose predicted label differs from the true label
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println("Test Error = " + testErr)
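The test error computed above is simply the share of (label, prediction) pairs that disagree. Without Spark, the same calculation on an in-memory collection looks like this (`Evaluation.testError` is a hypothetical helper):

```scala
object Evaluation {
  // Fraction of (label, prediction) pairs in which the prediction is wrong.
  def testError(labelAndPreds: Seq[(Double, Double)]): Double = {
    require(labelAndPreds.nonEmpty, "need at least one test example")
    val wrong = labelAndPreds.count { case (label, prediction) => label != prediction }
    wrong.toDouble / labelAndPreds.size
  }
}
```

For four pairs with one wrong prediction, the test error is 0.25, so the accuracy (1 minus the test error) is 0.75.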

And that's it! You can take a look at a running example here.



Published at DZone with permission of

