Churn Prediction With Apache Spark Machine Learning


Learn how to get started using Apache Spark’s machine learning decision trees and machine learning pipelines for classification.


Churn prediction is big business. It helps minimize customer defection by predicting which customers are likely to cancel a subscription to a service. Though originally used within the telecommunications industry, it has become common practice across banks, ISPs, insurance firms, and other verticals.

The prediction process is heavily data-driven and often utilizes advanced machine learning techniques. In this post, we'll take a look at what types of customer data are typically used, do some preliminary analysis of the data, and generate churn prediction models — all with Spark and its machine learning frameworks.

Using data science to better understand and predict customer behavior is an iterative process involving:

  1. Data discovery and model creation:
    • Analyzing historical data.
    • Identifying new data sources that traditional analytics or databases aren't using due to the format, size, or structure.
    • Collecting, correlating, and analyzing data across multiple data sources.
    • Knowing and applying the right kind of machine learning algorithms to get value out of the data.
  2. Using the model in production to make predictions.
  3. Data discovery and updating the model with new data.


In order to understand the customer, a number of factors can be analyzed, such as:

  • Customer demographic data (age, marital status, etc.).
  • Sentiment analysis of social media.
  • Customer usage patterns and geographical usage trends.
  • Calling-circle data.
  • Browsing behavior from clickstream logs.
  • Support call center statistics.
  • Historical data that show patterns of behavior that suggest churn.

With this analysis, telecom companies can gain insights to predict and enhance the customer experience, prevent churn, and tailor marketing campaigns.


Classification is a family of supervised machine learning algorithms that identify which category an item belongs to (e.g., whether a transaction is fraudulent) based on labeled examples of known items (e.g., transactions known to be fraud or not). Classification takes a set of data with known labels and predetermined features and learns how to label new records based on that information. Features are the “if questions” that you ask. The label is the answer to those questions. For example, if it walks, swims, and quacks like a duck, then the label is “duck.”


Let’s go through an example of telecom customer churn:

  • What are we trying to predict?
    • Whether a customer has a high probability of unsubscribing from the service.
    • Churn is labeled "true" or "false."
  • What are the “if questions” or properties that you can use to make predictions?
    • Call statistics, customer service calls, etc.
    • To build a classifier model, you extract the features of interest that most contribute to the classification.

Decision Trees

Decision trees create a model that predicts the class or label based on several input features. Decision trees work by evaluating an expression containing a feature at every node and selecting a branch to the next node based on the answer. A possible decision tree for predicting credit risk is shown below. The feature questions are the nodes, and the answers “yes” or “no” are the branches in the tree to the child nodes.

  • Q1: Is checking account balance > 200DM?
    • No: go to Q2.
  • Q2: Is length of current employment > 1 year?
    • No: not creditable.
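The credit-risk tree above can be sketched as a small recursive data type in Scala. This is a minimal illustration of how a trained tree turns answers into a prediction, not Spark's implementation; the Applicant record, its field names, and the outcomes of the two "Yes" branches (which the example doesn't show) are assumptions made for the sketch.

```scala
// A decision tree is either an internal node that asks a yes/no question
// about the input, or a leaf that holds the predicted label.
sealed trait Tree[A]
final case class Node[A](question: String, test: A => Boolean,
                         yes: Tree[A], no: Tree[A]) extends Tree[A]
final case class Leaf[A](label: String) extends Tree[A]

// Hypothetical applicant record for the credit-risk example.
final case class Applicant(checkingBalanceDM: Double, yearsEmployed: Double)

// Start at the root and follow the "yes" or "no" branch at each node
// until a leaf is reached; the leaf's label is the prediction.
@annotation.tailrec
def predict[A](tree: Tree[A], x: A): String = tree match {
  case Leaf(label)            => label
  case Node(_, test, yes, no) => predict(if (test(x)) yes else no, x)
}

// The two "No" branches follow the example; both "Yes" leaves are
// assumptions, because the example does not show them.
val creditTree: Tree[Applicant] = Node[Applicant](
  "Is checking account balance > 200DM?", _.checkingBalanceDM > 200,
  yes = Leaf[Applicant]("Creditable"), // assumed
  no = Node[Applicant](
    "Is length of current employment > 1 year?", _.yearsEmployed > 1,
    yes = Leaf[Applicant]("Creditable"), // assumed
    no = Leaf[Applicant]("Not creditable")))
```

For an applicant with a 100DM balance and six months of employment, predict(creditTree, Applicant(100, 0.5)) follows both "No" branches and returns "Not creditable".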


Example Use Case Data Set

For this tutorial, we'll be using the Orange Telecoms churn dataset. It consists of cleaned customer activity data (features) and a churn label specifying whether the customer canceled the subscription. The data can be fetched from BigML's S3 bucket as two files, churn-80 and churn-20. The two sets come from the same batch but have been split by an 80/20 ratio. We'll use the larger set for training and cross-validation and the smaller set for final testing and model performance evaluation. For convenience, the two data sets have been included with the complete code in this repository. The data set has the following schema:

1. State: string
2. Account length: integer
3. Area code: integer
4. International plan: string
5. Voice mail plan: string
6. Number vmail messages: integer
7. Total day minutes: double
8. Total day calls: integer
9. Total day charge: double
10. Total eve minutes: double
11. Total eve calls: integer
12. Total eve charge: double
13. Total night minutes: double
14. Total night calls: integer
15. Total night charge: double
16. Total intl minutes: double
17. Total intl calls: integer
18. Total intl charge: double
19. Customer service calls: integer

Each line of the CSV file holds the values of these fields in the order listed, followed by the churn label (True or False).


The image below shows the first few rows of the data set:



This tutorial will run on Spark 2.0.1 and above.

  • You can download the code and data to run these examples from here.
  • The examples in this post can be run in the Spark shell, after launching with the spark-shell command.
  • You can also run the code as a stand-alone application, as described in the tutorial on getting started with Spark on MapR sandbox. Log into the MapR Sandbox using userid user01, password mapr. Copy the sample data file to your sandbox home directory /user/user01 using scp. Start the Spark shell with:
    $ spark-shell --master local[1]

Load the Data From a CSV File


First, we will import the SQL and machine learning packages.

import org.apache.spark._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.Dataset
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.feature.VectorAssembler

We use a Scala case class and a StructType to define the schema, corresponding to a line in the CSV data file.

// define the Churn Schema
case class Account(state: String, len: Integer, acode: String,
    intlplan: String, vplan: String, numvmail: Double,
    tdmins: Double, tdcalls: Double, tdcharge: Double,
    temins: Double, tecalls: Double, techarge: Double,
    tnmins: Double, tncalls: Double, tncharge: Double,
    timins: Double, ticalls: Double, ticharge: Double,
    numcs: Double, churn: String)

val schema = StructType(Array(
    StructField("state", StringType, true),
    StructField("len", IntegerType, true),
    StructField("acode", StringType, true),
    StructField("intlplan", StringType, true),
    StructField("vplan", StringType, true),
    StructField("numvmail", DoubleType, true),
    StructField("tdmins", DoubleType, true),
    StructField("tdcalls", DoubleType, true),
    StructField("tdcharge", DoubleType, true),
    StructField("temins", DoubleType, true),
    StructField("tecalls", DoubleType, true),
    StructField("techarge", DoubleType, true),
    StructField("tnmins", DoubleType, true),
    StructField("tncalls", DoubleType, true),
    StructField("tncharge", DoubleType, true),
    StructField("timins", DoubleType, true),
    StructField("ticalls", DoubleType, true),
    StructField("ticharge", DoubleType, true),
    StructField("numcs", DoubleType, true),
    StructField("churn", StringType, true)
))

Using Spark 2.0, we specify the data source and schema to load into a dataset. Note that with Spark 2.0, specifying the schema when loading data into a DataFrame will give better performance than schema inference. We cache the datasets for quick, repeated access. We also print the schema of the datasets.

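// .as[Account] needs spark.implicits._ (auto-imported in spark-shell); the paths are placeholders
val train: Dataset[Account] = spark.read.option("inferSchema", "false")
  .schema(schema).csv("/user/user01/churn-80.csv").as[Account].cache()
val test: Dataset[Account] = spark.read.option("inferSchema", "false")
  .schema(schema).csv("/user/user01/churn-20.csv").as[Account].cache()
train.printSchema()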
 |-- state: string (nullable = true)
 |-- len: integer (nullable = true)
 |-- acode: string (nullable = true)
 |-- intlplan: string (nullable = true)
 |-- vplan: string (nullable = true)
 |-- numvmail: double (nullable = true)
 |-- tdmins: double (nullable = true)
 |-- tdcalls: double (nullable = true)
 |-- tdcharge: double (nullable = true)
 |-- temins: double (nullable = true)
 |-- tecalls: double (nullable = true)
 |-- techarge: double (nullable = true)
 |-- tnmins: double (nullable = true)
 |-- tncalls: double (nullable = true)
 |-- tncharge: double (nullable = true)
 |-- timins: double (nullable = true)
 |-- ticalls: double (nullable = true)
 |-- ticharge: double (nullable = true)
 |-- numcs: double (nullable = true)
 |-- churn: string (nullable = true)

// display the first 20 rows
train.show()


Summary Statistics

Spark DataFrames include some built-in functions for statistical processing. The describe() function performs summary statistics calculations on all numeric columns and returns them as a DataFrame.
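To make those statistics concrete, here is a plain-Scala sketch of what describe() reports for one numeric column: count, mean, standard deviation, min, and max. It is an illustration only and assumes the sample (n - 1) form of the standard deviation; in the Spark shell you would simply call something like train.describe("len", "numcs").show().

```scala
// Summary statistics in the spirit of describe(): count, mean,
// sample standard deviation, min, and max of one numeric column.
final case class Summary(count: Long, mean: Double, stddev: Double,
                         min: Double, max: Double)

def describe(xs: Seq[Double]): Summary = {
  require(xs.nonEmpty, "need at least one value")
  val n = xs.size
  val mean = xs.sum / n
  // Sample variance divides by n - 1; this sketch returns NaN for a single value.
  val variance =
    if (n > 1) xs.map(x => (x - mean) * (x - mean)).sum / (n - 1)
    else Double.NaN
  Summary(n.toLong, mean, math.sqrt(variance), xs.min, xs.max)
}

// Toy values for a column such as numcs, not taken from the data set.
val numcsSummary = describe(Seq(1.0, 2.0, 3.0, 4.0))
println(numcsSummary)
```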




Data Exploration

We can use Spark SQL to explore the dataset. Here are some example queries using the Scala DataFrame API:

|False|    3310.0|
| True|     856.0|



Total day minutes and Total day charge are highly correlated fields, as are the evening, night, and international minutes and their charges, since each charge tracks its minutes closely. Such correlated data won't be very beneficial for our model training runs, so we drop one column of each correlated pair (the charge), along with the State, Area code, and Voice mail plan columns, which we also won’t use.

val dtrain = train.drop("state", "acode", "vplan", "tdcharge", "techarge", "tncharge", "ticharge")
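To see why a minutes column and its charge column are redundant, the following plain-Scala sketch computes the Pearson correlation coefficient. The values and the flat per-minute rate are made up for illustration; on the real data, Spark can compute the same coefficient with train.stat.corr("tdmins", "tdcharge").

```scala
// Pearson correlation coefficient of two equal-length numeric columns:
// covariance divided by the product of the standard deviations.
def pearson(xs: Seq[Double], ys: Seq[Double]): Double = {
  require(xs.size == ys.size && xs.size > 1, "need two equal-length columns")
  val mx = xs.sum / xs.size
  val my = ys.sum / ys.size
  val cov = xs.zip(ys).map { case (x, y) => (x - mx) * (y - my) }.sum
  val sx = math.sqrt(xs.map(x => (x - mx) * (x - mx)).sum)
  val sy = math.sqrt(ys.map(y => (y - my) * (y - my)).sum)
  cov / (sx * sy)
}

// Toy values, not from the data set: a charge column that is the minutes
// column times a flat, hypothetical per-minute rate.
val minutes = Seq(100.0, 150.0, 210.5, 75.3)
val charge = minutes.map(_ * 0.17)
println(pearson(minutes, charge)) // 1.0 up to floating-point rounding
```

A coefficient this close to 1 means the charge column adds no information the minutes column doesn't already carry.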


Grouping the data by the churn field and counting the number of instances in each group shows that there are roughly six times as many false churn samples as true churn samples.



dtrain.groupBy("churn").count().show()

+-----+-----+
|churn|count|
+-----+-----+
|False| 2278|
| True|  388|
+-----+-----+

Business decisions will be used to retain the customers most likely to leave, not those who are likely to stay. Thus, we need to ensure that our model is sensitive to the Churn=True samples.

Stratified Sampling

We can put the two sample types on the same footing using stratified sampling. The DataFrame sampleBy() function, available through the stat accessor, does this when provided with the fraction of each sample type to be returned. Here, we're keeping all instances of the Churn=True class, but downsampling the Churn=False class to a fraction of 388/2278 (about 0.17).

val fractions = Map("False" -> .17, "True" -> 1.0)
val strain = dtrain.stat.sampleBy("churn", fractions, 36L)


strain.groupBy("churn").count().show()

+-----+-----+
|churn|count|
+-----+-----+
|False|  379|
| True|  388|
+-----+-----+
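sampleBy() keeps each row independently with its class's probability, which is why the False count comes out at 379 rather than exactly 388. The plain-Scala sketch below mimics that per-row behavior on toy rows that carry only a label. It uses Scala's own random generator, so it will not reproduce Spark's exact count of 379, and the helper name sampleByClass is hypothetical.

```scala
import scala.util.Random

// Keep each row independently with the probability listed for its class;
// classes missing from `fractions` are dropped. Because every row is an
// independent coin flip, the per-class counts are only approximate.
def sampleByClass[A](rows: Seq[A], labelOf: A => String,
                     fractions: Map[String, Double], seed: Long): Seq[A] = {
  val rng = new Random(seed)
  rows.filter(r => rng.nextDouble() < fractions.getOrElse(labelOf(r), 0.0))
}

// Toy rows that carry only a churn label, with the class counts from the text.
val rows = Seq.fill(2278)("False") ++ Seq.fill(388)("True")
val fractions = Map("False" -> 0.17, "True" -> 1.0)
val sampled = sampleByClass[String](rows, r => r, fractions, 36L)
val counts = sampled.groupBy(r => r).map { case (label, rs) => label -> rs.size }
println(counts) // True stays at 388; False lands near 0.17 * 2278, about 387
```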

Features Array

To build a classifier model, you extract the features that most contribute to the classification. The features for each item consist of the fields shown below:

  • Label — churn: True or False
  • Features — {"len", "iplanIndex", "numvmail", "tdmins", "tdcalls", "temins", "tecalls", "tnmins", "tncalls", "timins", "ticalls", "numcs"}

In order for the features to be used by a machine learning algorithm, they are transformed and put into Feature Vectors, which are vectors of numbers representing the value for each feature.

(Figure reference: Learning Spark)

Using the Spark ML Package

The ML package is the newer library of machine learning routines. Spark ML provides a uniform set of high-level APIs built on top of DataFrames.


We will use an ML Pipeline to pass the data through transformers in order to extract the features and an estimator to produce the model.

  • Transformer: An algorithm that transforms one DataFrame into another DataFrame. We will use a transformer to get a DataFrame with a features vector column.
  • Estimator: An algorithm that can be fit on a DataFrame to produce a transformer (for example, training/tuning on a DataFrame and producing a model).
  • Pipeline: Chains multiple transformers and estimators together to specify an ML workflow.

Feature Extraction and Pipelining

The ML package needs data to be put in a (label: Double, features: Vector) DataFrame format with correspondingly named fields. We set up a pipeline to pass the data through three transformers in order to extract the features: two StringIndexers and a VectorAssembler. We use the StringIndexers to convert the string categorical feature intlplan and the churn label into number indices. Indexing categorical features allows decision trees to treat categorical features appropriately, improving performance.


// set up StringIndexer transformers for label and string feature
val ipindexer = new StringIndexer().setInputCol("intlplan").setOutputCol("iplanIndex")
val labelindexer = new StringIndexer().setInputCol("churn").setOutputCol("label")
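StringIndexer assigns indices by label frequency, with the most frequent value receiving 0.0. The plain-Scala sketch below illustrates that rule; the helper name fitIndex and the alphabetical tie-break are assumptions of the sketch, not Spark's API. It also shows a consequence worth keeping in mind when reading predictions: in the stratified training set "True" (388 rows) slightly outnumbers "False" (379 rows), so churn ends up encoded as 0.0.

```scala
// Give each distinct string a Double index by descending frequency, so the
// most frequent value gets 0.0. Ties are broken alphabetically here; that
// tie rule is an assumption of this sketch.
def fitIndex(values: Seq[String]): Map[String, Double] =
  values.groupBy(identity).toSeq
    .sortBy { case (value, occurrences) => (-occurrences.size, value) }
    .zipWithIndex
    .map { case ((value, _), i) => value -> i.toDouble }
    .toMap

// Label counts after stratified sampling: "True" (churn) becomes 0.0.
val labelIndex = fitIndex(Seq.fill(379)("False") ++ Seq.fill(388)("True"))
// Toy intlplan values, not from the data set.
val intlplanIndex = fitIndex(Seq("No", "No", "Yes"))
```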

The VectorAssembler combines a given list of columns into a single feature vector column.

// set up a VectorAssembler transformer
val featureCols = Array("len", "iplanIndex", "numvmail", "tdmins",
     "tdcalls", "temins", "tecalls", "tnmins", "tncalls", "timins",
     "ticalls", "numcs")

val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
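Conceptually, the assembler copies the listed columns, in order, into one vector per row. The sketch below models a row as a Scala Map from column name to value, which is an assumption for illustration only (Spark's VectorAssembler produces an ml.linalg.Vector column). Columns that aren't listed, such as a leftover charge column, never reach the feature vector.

```scala
// Hypothetical row representation for this sketch: column name -> value.
type Row = Map[String, Double]

// Copy the chosen columns, in the given order, into one feature vector.
// Unlisted columns are ignored; a missing listed column is an error.
def assemble(row: Row, cols: Seq[String]): Vector[Double] =
  cols.toVector.map(c => row.getOrElse(c, sys.error(s"missing column $c")))

val featureCols = Seq("len", "iplanIndex", "numvmail", "tdmins",
  "tdcalls", "temins", "tecalls", "tnmins", "tncalls", "timins",
  "ticalls", "numcs")

// Toy row (not a real customer): feature i holds the value i, plus an extra column.
val row: Row = featureCols.zipWithIndex.map { case (c, i) => c -> i.toDouble }.toMap +
  ("tdcharge" -> 99.0)
val features = assemble(row, featureCols)
```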

The final element in our pipeline is an estimator (a decision tree classifier), which trains on the label column and the feature vector column.


// set up a DecisionTreeClassifier estimator
val dTree = new DecisionTreeClassifier().setLabelCol("label")

// Chain the indexers, assembler, and tree in a Pipeline
val pipeline = new Pipeline()
      .setStages(Array(ipindexer, labelindexer, assembler, dTree))

Train the Model


We would like to determine which parameter values of the decision tree produce the best model. A common technique for model selection is k-fold cross validation, where the data is randomly split into k partitions. Each partition is used once as the testing data set, while the rest are used for training. Models are then generated using the training sets and evaluated with the testing sets, resulting in k model performance measurements. The average of the performance scores is often taken to be the overall score of the model, given its build parameters. For model selection we can search through the model parameters, comparing their cross validation performances. The model parameters leading to the highest performance metric produce the best model.
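The k-fold procedure described above can be sketched in plain Scala over row indices. This illustrates the idea rather than Spark's code: it deals shuffled indices into k near-equal folds, whereas Spark assigns rows to folds randomly, so its fold sizes can vary slightly. The score function passed to cvScore is a hypothetical stand-in for "train on these rows, evaluate on those."

```scala
import scala.util.Random

// Shuffle the row indices 0 until n with a fixed seed, deal them into k
// folds, and return one (training, validation) pair per fold: each fold is
// held out once while the other k - 1 folds form the training set.
def kFolds(n: Int, k: Int, seed: Long): Seq[(Seq[Int], Seq[Int])] = {
  require(k >= 2 && n >= k, "need k >= 2 and at least k rows")
  val shuffled = new Random(seed).shuffle((0 until n).toVector)
  val folds = (0 until k).map(f => shuffled.indices.filter(_ % k == f).map(i => shuffled(i)))
  folds.indices.map { f =>
    val training = folds.indices.filter(_ != f).flatMap(j => folds(j))
    (training, folds(f))
  }
}

// The cross-validated score of one parameter setting is the average of the
// k per-fold scores produced by `score` (train on the first argument,
// evaluate on the second).
def cvScore(n: Int, k: Int, seed: Long)(score: (Seq[Int], Seq[Int]) => Double): Double = {
  val scores = kFolds(n, k, seed).map { case (training, validation) => score(training, validation) }
  scores.sum / scores.size
}
```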

Spark ML supports k-fold cross validation over a transformation/estimation pipeline. Combined with grid search, where you set up the parameter values to test, and an evaluator that scores each model, it forms a model selection workflow that tries out different combinations of parameters.

Below, we use a ParamGridBuilder to construct the parameter grid.

// Search through decision tree's maxDepth parameter for best model
val paramGrid = new ParamGridBuilder().addGrid(dTree.maxDepth,
Array(2, 3, 4, 5, 6, 7)).build()

We define a BinaryClassificationEvaluator, which evaluates a model by comparing the label column with the model's raw prediction column (rawPrediction). Its default metric is the area under the ROC curve; the area under the precision-recall curve (areaUnderPR) is the other supported metric.

// Set up Evaluator (prediction, true label)
val evaluator = new BinaryClassificationEvaluator()
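The area under the ROC curve has a handy probabilistic reading: it is the chance that a randomly chosen churner receives a higher score than a randomly chosen non-churner, with ties counting one half. The plain-Scala sketch below computes it that way on toy (score, isChurner) pairs. It is quadratic in the number of rows and meant only to show what the metric measures; Spark builds the curve from sorted scores instead.

```scala
// Area under the ROC curve as the fraction of (positive, negative) pairs in
// which the positive is scored higher; tied pairs count as one half.
def areaUnderROC(scored: Seq[(Double, Boolean)]): Double = {
  val pos = scored.collect { case (s, true) => s }
  val neg = scored.collect { case (s, false) => s }
  require(pos.nonEmpty && neg.nonEmpty, "need both classes")
  val wins = for (p <- pos; n <- neg)
    yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
  wins.sum / (pos.size.toDouble * neg.size)
}

// Toy scores: one churner (0.4) ranks below one non-churner (0.6).
val auc = areaUnderROC(Seq((0.9, true), (0.4, true), (0.6, false), (0.1, false)))
println(auc) // 0.75: three of the four pairs are ordered correctly
```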

We use a CrossValidator for model selection. The CrossValidator takes the estimator pipeline, the parameter grid, and the evaluator. It iterates through the maxDepth values from the ParamGridBuilder and, with three folds, trains and evaluates a model three times per parameter value, averaging the scores for a more reliable comparison.

// Set up 3-fold cross validation
val crossval = new CrossValidator().setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

val cvModel = crossval.fit(strain)
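To get a feel for the cost of this search, the sketch below expands a parameter grid the way ParamGridBuilder does (every combination of the listed values) and counts the resulting fits. Parameter names are plain strings and expandGrid is a hypothetical helper. With six maxDepth values and three folds, cross validation trains 18 models, and CrossValidator then refits the best parameter setting on the whole training set, for 19 fits in total.

```scala
// Expand a parameter grid into every combination of values, one Map per
// combination, as ParamGridBuilder.build() does for real Params.
def expandGrid(grid: Seq[(String, Seq[Int])]): Seq[Map[String, Int]] =
  grid.foldLeft(Seq(Map.empty[String, Int])) { case (combos, (name, values)) =>
    for (combo <- combos; v <- values) yield combo + (name -> v)
  }

val paramMaps = expandGrid(Seq("maxDepth" -> Seq(2, 3, 4, 5, 6, 7)))
val numFolds = 3
// One fit per (parameter map, fold) pair, plus one final refit of the best
// parameter map on the whole training set.
val fits = paramMaps.size * numFolds + 1
println(fits) // 19
```

Adding a second parameter multiplies the grid: two values of another parameter would double the number of fits.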

We get the best decision tree model, in order to print out the decision tree and parameters.

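// Fetch best model; the decision tree is the fourth stage (index 3) of the pipeline
val bestModel = cvModel.bestModel
val treeModel = bestModel.asInstanceOf[org.apache.spark.ml.PipelineModel]
  .stages(3).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)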



//0-11 feature columns: len, iplanIndex, numvmail, tdmins, tdcalls, temins, tecalls, tnmins, tncalls, timins, ticalls, numcs
println( "Feature 11:" +  featureCols(11))
println( "Feature 3:" +  featureCols(3))

Feature 11:numcs
Feature 3:tdmins

We find that the best tree model produced using the cross-validation process is one with a depth of 5. The toDebugString method prints the tree's decision nodes and the final prediction outcomes at the leaves. We can see that features 11 and 3 are used for decision making and should thus be considered as having high predictive power to determine a customer's likelihood to churn. It's not surprising that these feature numbers map to the fields Customer service calls and Total day minutes. Decision trees are often used for feature selection because they provide an automated mechanism for determining the most important features (those closest to the tree root).

Predictions and Model Evaluation


The actual performance of the model can be determined using the test data set that has not been used for any training or cross-validation activities. We'll transform the test set with the model pipeline, which will map the features according to the same recipe.

val predictions = cvModel.transform(test)


The evaluator will provide us with the score of the predictions, and then we'll print them along with their probabilities.

// evaluate() returns the evaluator's metric: area under the ROC curve by default
val accuracy = evaluator.evaluate(predictions)
val result = predictions.select("label", "prediction", "probability")
result.show()


accuracy: Double = 0.8484817813765183
metric name in evaluation (default: areaUnderROC)


In this case, the evaluation returns an area under the ROC curve of about 0.848; despite the variable name accuracy, this is not an accuracy or precision figure. The prediction probabilities can be very useful in ranking customers by their likelihood to defect. This way, the limited resources available to the business for retention can be focused on the appropriate customers.
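Ranking by probability can be sketched in a few lines of plain Scala. The ScoredCustomer type and the budget parameter are hypothetical; in the Spark output the probability column is a vector, and which element holds the churn probability depends on the index StringIndexer assigned to the "True" label.

```scala
// Hypothetical scored customer: an id plus the model's churn probability.
final case class ScoredCustomer(id: String, churnProbability: Double)

// Highest churn risk first, truncated to the number of customers the
// retention budget can cover.
def retentionTargets(customers: Seq[ScoredCustomer], budget: Int): Seq[ScoredCustomer] =
  customers.sortBy(c => -c.churnProbability).take(budget)

// Toy scores, not model output.
val scored = Seq(ScoredCustomer("a", 0.20), ScoredCustomer("b", 0.91), ScoredCustomer("c", 0.64))
println(retentionTargets(scored, 2).map(_.id)) // List(b, c)
```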

Below, we calculate some more metrics. The number of false/true positive and negative predictions is also useful:

  • True positives are how often the model correctly predicted subscription canceling.
  • False positives are how often the model incorrectly predicted subscription canceling.
  • True negatives indicate how often the model correctly predicted no canceling.
  • False negatives indicate how often the model incorrectly predicted no canceling.

// Compare the true labels with the predictions on the test set
val lp = predictions.select("label", "prediction")
val counttotal = predictions.count()
val correct = lp.filter($"label" === $"prediction").count()
val wrong = lp.filter(not($"label" === $"prediction")).count()
val ratioWrong = wrong.toDouble / counttotal.toDouble
val ratioCorrect = correct.toDouble / counttotal.toDouble
// StringIndexer gives 0.0 to the most frequent label. In the stratified training
// set "True" (388 rows) outnumbers "False" (379 rows), so 0.0 means churn:
// a prediction of 0.0 is a positive and a prediction of 1.0 is a negative.
val truep = lp.filter($"prediction" === 0.0)
  .filter($"label" === $"prediction").count() / counttotal.toDouble
val truen = lp.filter($"prediction" === 1.0)
  .filter($"label" === $"prediction").count() / counttotal.toDouble
val falsep = lp.filter($"prediction" === 0.0)
  .filter(not($"label" === $"prediction")).count() / counttotal.toDouble
val falsen = lp.filter($"prediction" === 1.0)
  .filter(not($"label" === $"prediction")).count() / counttotal.toDouble

println("counttotal : " + counttotal)
println("correct : " + correct)
println("wrong: " + wrong)
println("ratio wrong: " + ratioWrong)
println("ratio correct: " + ratioCorrect)
println("ratio true positive : " + truep)
println("ratio false positive : " + falsep)
println("ratio true negative : " + truen)
println("ratio false negative : " + falsen)
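The same counts can be expressed as a small confusion-matrix helper in plain Scala, which also makes precision and recall for the churn class easy to read off. It is a sketch over toy (label, prediction) pairs with hypothetical names; the positive parameter makes explicit which index means churn, since that depends on StringIndexer's frequency ordering rather than on the label's name.

```scala
// Tally true/false positives/negatives from (label, prediction) pairs,
// treating `positive` as the index value that means "churned".
final case class Confusion(tp: Long, fp: Long, tn: Long, fn: Long) {
  def total: Long = tp + fp + tn + fn
  def precision: Double = tp.toDouble / (tp + fp) // share of predicted churners who churned
  def recall: Double = tp.toDouble / (tp + fn)    // share of churners the model caught
}

def confusion(pairs: Seq[(Double, Double)], positive: Double): Confusion = {
  def tally(p: ((Double, Double)) => Boolean): Long = pairs.count(p).toLong
  Confusion(
    tp = tally { case (label, pred) => pred == positive && label == positive },
    fp = tally { case (label, pred) => pred == positive && label != positive },
    tn = tally { case (label, pred) => pred != positive && label != positive },
    fn = tally { case (label, pred) => pred != positive && label == positive })
}

// Toy (label, prediction) pairs, with 0.0 standing for churn.
val c = confusion(Seq((0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)), positive = 0.0)
println(c) // Confusion(2,1,1,1)
```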


Want to Learn More?

In this blog post, we showed you how to get started using Apache Spark’s machine learning decision trees and ML pipelines for classification. If you have any further questions about this tutorial, please ask them in the comments section below.


Published at DZone with permission of Carol McDonald, DZone MVB. See the original article here.

Opinions expressed by DZone contributors are their own.
