
Using Artificial Neural Networks to Predict Emergency Department Deaths

How to use Apache Spark and machine learning to predict hospital fatalities due to heart disease.


Introduction

Apache Spark is a cluster-based, open-source computing system mainly useful for working with very large data sets. Parallel computing and fault tolerance are built into the Spark architecture. Spark Core is the main component of Spark and provides general-purpose data processing functions over a cluster of machines. Additional components built on top of Spark Core bring more functionality, such as machine learning. Comprehensive documentation for Apache Spark is available; see the official Apache Spark documentation, Introduction to Apache Spark, Big Data Processing in Spark, and Getting Started With Spark Streaming.

This article focuses on the Spark MLlib library, which provides an API for implementing machine learning and statistical computing algorithms. We will discuss an example where predicting emergency department (ED) deaths due to heart disease is formulated as a binary classification problem. We will try to solve that problem using an artificial neural network (ANN) implemented with the Spark MLlib Java API.

In the next section we explain the problem and express it as a binary classification problem. Then we describe how an ANN can be used to solve it and review the performance metrics used to measure the correctness of the outcome. Next, we discuss our approach to selecting the ANN that solves the problem. Finally, we review the Java code and discuss the findings.

Problem Statement

The National Center for Health Statistics, part of the U.S. Department of Health and Human Services, regularly publishes the results of the National Hospital Ambulatory Medical Care Survey (NHAMCS), which includes patient statistics from hospital EDs. We will try to predict patient deaths during ED visits due to heart disease based on various patient characteristics such as age, basic vital measurements, and the presence of myocardial infarction, i.e., heart attack.

Classification Problem

In simple terms, classification is the problem of determining which class, i.e. category, a system's output belongs to for a given set of inputs. An algorithm that solves a classification problem is called a classifier. The particular classification problem we consider here is as follows: given a patient who visits an ED due to a heart problem, predict whether that patient will die or survive while in the hospital (ED or hospital ward).

This can be formulated as a binary classification problem: for a set of input variables there are only two possible outcomes (hence "binary"), either the patient survives or the patient dies in the hospital. Each of those outcomes is a class, and each class is uniquely identified by a label, as summarized below.

Label | Explanation
0 | Patient survives, i.e. does not die in the hospital.
1 | Patient dies in the hospital (either in the ED or in a hospital ward following admission from the ED).

Table 1. Label descriptions.

Each input variable is called a feature. For the problem considered here, features are explained below.

Feature | Name | Explanation
1 | Age Recode | Age group the patient belongs to: 0 = Under 15 years, 1 = 15-24 years, 2 = 25-44 years, 3 = 45-64 years, 4 = 65-74 years, 5 = 75-84 years, 6 = 85-95 years, 7 = Above 95 years
2 | Temperature | Body temperature being in the normal range, defined as 97-99 F: 0 = Normal, 1 = Abnormal
3 | Pulse Oximetry (Percent) | Pulse oximetry being in the normal range, defined as 95-100%: 0 = Normal, 1 = Abnormal
4 | Diastolic Blood Pressure | Diastolic blood pressure being in the normal range, defined as 60-80 mm Hg: 0 = Normal, 1 = Abnormal
5 | Systolic Blood Pressure | Systolic blood pressure being in the normal range, defined as 90-120 mm Hg: 0 = Normal, 1 = Abnormal
6 | Respiratory Rate | Respiratory rate being in the normal range, defined as 12-25 breaths/minute: 0 = Normal, 1 = Abnormal
7 | Pulse | Pulse being in the normal range, defined as 60-100 beats/minute: 0 = Normal, 1 = Abnormal
8 | Presence of Heart Attack | Whether the patient was diagnosed with a heart attack: 0 = Not diagnosed with heart attack, 1 = Diagnosed with heart attack

Table 2. Feature descriptions.

We used the NHAMCS Emergency Department public-use micro-data files for the years 2010, 2011, and 2012, which can be obtained from the official download site. Those are fixed-length ASCII files where each row belongs to a unique patient, and the features above correspond to fixed positions in each row. We enhanced the age recode by adding one more age group for patients older than 95. (In the original definition of the age recode, group 6 covers all patients who are 85 or older.) We eliminated patients who were dead on arrival at the ED. The data file records up to 3 diagnoses for every patient. Because we consider only patients who visited the ED due to a heart problem, we required that at least one of the diagnoses have an ICD-9 code between 410 and 414 (those codes and their extensions cover all diagnoses of coronary artery disease); otherwise, we discarded the patient record. The final data file had 915 patients (rows), of which 888 survived (class 0) and 27 died (class 1).

For Presence of Heart Attack we proceeded as follows: if any of the three diagnoses has ICD-9 code 410 or one of its extensions, i.e. 410.0 through 410.9 (acute myocardial infarction), we set Presence of Heart Attack to 1; otherwise, to 0.
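
To make the data preparation concrete, below is a minimal sketch of how such a fixed-length file could be filtered by diagnosis and the Presence of Heart Attack feature derived. This is not the preprocessing code we used; the file name and the field offsets in DIAG_OFFSETS are hypothetical placeholders that would have to be replaced with the positions given in the NHAMCS file layout documentation.

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;

// Minimal preprocessing sketch (not the original preprocessing code).
public class NhamcsPreprocessSketch {
    // Hypothetical start/end offsets of the three 3-digit diagnosis code prefixes.
    static final int[][] DIAG_OFFSETS = { {549, 552}, {556, 559}, {563, 566} };

    // True if the 3-digit ICD-9 prefix denotes coronary artery disease (410-414).
    static boolean isCoronaryArteryDisease(String prefix) {
        try {
            int code = Integer.parseInt(prefix.trim());
            return code >= 410 && code <= 414;
        } catch (NumberFormatException e) {
            return false;
        }
    }

    public static void main(String[] args) throws IOException {
        List<String> rows = Files.readAllLines(Paths.get("ED2012.txt")); // placeholder file name
        for (String row : rows) {
            boolean cadVisit = false;
            int heartAttack = 0; // becomes feature 8 (Presence of Heart Attack)
            for (int[] offset : DIAG_OFFSETS) {
                String prefix = row.substring(offset[0], offset[1]);
                if (isCoronaryArteryDisease(prefix)) {
                    cadVisit = true;
                }
                if ("410".equals(prefix.trim())) { // 410.x = acute myocardial infarction
                    heartAttack = 1;
                }
            }
            if (!cadVisit) {
                continue; // discard records with no coronary artery disease diagnosis
            }
            // ... extract the remaining features, map them to the 0/1 encoding of Table 2,
            // and write one LIBSVM line per patient ...
        }
    }
}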

Artificial Neural Network

An artificial neural network is a mathematical model with a variety of applications in science and technology. In particular, ANNs can be used to solve the classification problem introduced above. Different types of ANNs exist; the Multilayer Perceptron is one particular type. The Spark MLlib library provides an API for a classifier called the Multilayer Perceptron Classifier (MLPC), built on the Multilayer Perceptron. A schematic representation of an MLPC with multiple inputs and a single output, which will be used in our example, is shown below.

Figure 1. Schematic representation of an MLPC. Classifier's inputs correspond to features and its output corresponds to labels.


Each circle represents a neuron, a computational unit, i.e. a mathematical function, that takes inputs (incoming arrows) and produces an output (outgoing arrow). The mathematical function in each computational unit has a known structure with parameters that are initially undetermined. In our case, the functions are such that for any input the output is either 0 or 1 (subject to an approximation that is practically insignificant). The idea is to 'train' the ANN on a set of known inputs (features) and corresponding outputs (labels) in order to determine the parameters of those functions. Once trained, the ANN should have learned the behavior of the original system, so that when a new input (unused in training) is applied, the ANN produces the same output as the original system.

The so-called 'hidden' layers are named as such because the number of those layers is not directly linked to the number of features or labels. Each layer can have a different number of computational units. As the numbers of layers and computational units increase, the number of parameters to determine via training also increases; with more parameters, a more 'flexible' ANN can be created to better learn the original system. On the other hand, Hastie et al. state that as the number of computational units increases beyond a certain limit, the ANN starts overfitting, i.e. it does not generalize well beyond the training data. The same reference also states that “Choice of the number of hidden layers is guided by background knowledge and experimentation.”

Many studies apply ANNs in medical science for diagnostic and prognostic purposes, e.g. stroke diagnosis and lung cancer detection. Beyond medical science, ANNs have many other application areas, such as general decision-making functions.

Performance Measurements

After a model is trained, we should be able to measure its performance quantitatively against test data, which is kept separate from the training data. Then, among different models, we select the one with the best performance on the test data. Below we discuss the confusion matrix, precision, and recall as the performance metrics.

Confusion Matrix

In binary classification, the confusion matrix is a 2x2 matrix where each entry is a nonnegative integer. Rows 1 and 2 represent the actual labels 0 and 1, respectively; columns 1 and 2 represent the predicted labels 0 and 1, respectively. In a particular row, the sum across columns is the number of instances of that label in the data set. In a particular column, the sum across rows is the number of times that label was predicted by the model. As an example, consider the confusion matrix below.

69      3
4       1

In the data set there are 72 (= 69 + 3) instances of label 0 and 5 (= 4 + 1) instances of label 1. That is, 72 patients survived whereas 5 patients died. The model correctly predicted 69 surviving patients, however, it wrongly predicted death for 3 of the surviving patients. On the other hand, the model correctly predicted one death whereas it wrongly predicted survival for 4 patients who actually died.

Precision and Recall

The precision of a label is the number of times the label is correctly predicted divided by the number of times any instance is predicted as that label. The recall (a.k.a. sensitivity) of a label is the number of times the label is correctly predicted divided by the number of actual instances of the label. The confusion matrix can be used to calculate precision and recall. In the above example, the precision for label 0 is 69/(69+4) = 0.945 and the recall for label 0 is 69/(69+3) = 0.958.

Both precision and recall are numbers between 0 and 1. As they both approach 1, the model becomes more successful; as either of them approaches 0, the model becomes less successful. In the ideal case, when the model perfectly predicts each label, the confusion matrix has 0 in its off-diagonal entries.
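
To make the arithmetic concrete, here is a small stand-alone sketch (separate from the Spark demo reviewed later) that computes precision and recall for both labels from the example confusion matrix above:

public class PrecisionRecallSketch {
    public static void main(String[] args) {
        // Rows are actual labels, columns are predicted labels (same layout as above).
        int[][] cm = { {69, 3},   // actual label 0: 69 predicted as 0, 3 predicted as 1
                       { 4, 1} }; // actual label 1:  4 predicted as 0, 1 predicted as 1
        for (int label = 0; label < 2; label++) {
            int correct = cm[label][label];                   // correctly predicted
            int predictedTotal = cm[0][label] + cm[1][label]; // column total
            int actualTotal = cm[label][0] + cm[label][1];    // row total
            double precision = predictedTotal == 0 ? 0.0 : (double) correct / predictedTotal;
            double recall = actualTotal == 0 ? 0.0 : (double) correct / actualTotal;
            System.out.printf("Label %d: precision = %.3f, recall = %.3f%n", label, precision, recall);
        }
    }
}

It prints precision 0.945 and recall 0.958 for label 0, matching the hand calculation above, and precision 0.25 and recall 0.2 for label 1.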

Note that binary classification is a special case of multi-class classification, and the definitions of confusion matrix, precision, and recall extend to multi-class classification, where the problem involves more than 2 classes.

Solution Approach

In this section, we summarize our approach to obtaining the best mathematical model underlying the MLPC.

  1. Select a candidate set of features.
  2. Define the number of hidden layers and computational units in each layer. (Start with a simple model.)
  3. Obtain training and test data sets based on the candidate features using the k-fold cross-validation technique. (There will be k such pairs.) For each pair, train one distinct model using the training data set and measure its performance against the test data set.
  4. Compare the models and pick the one with the best performance.
  5. If the result of the best-performing model is satisfactory, then stop. Otherwise:
    • If improvement is observed, go to step 2 to increase the complexity of the model with more computational units and/or layers.
    • If no more improvement is observed, go to step 1 to redefine the features (start all over again).

Figure 2. Process for model selection.


An NHAMCS data file has more than 500 data items, including patient demographics, vital measurements, diagnoses, chronic conditions, family history, and statistics about the particular hospital visited by the patient. After eliminating most of those data items based on domain knowledge, we initially identified a set of candidate features and generated a data file in LIBSVM format, a commonly used format in machine learning applications.
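
As an illustration, each line of a LIBSVM file consists of a label followed by index:value pairs, with zero-valued features typically omitted in this sparse format. Two hypothetical rows, one for a surviving patient in age group 3 whose only abnormal reading is the pulse, and one for a deceased patient with a heart attack diagnosis and several abnormal vitals, might look like this:

0 1:3 7:1
1 1:5 2:1 3:1 6:1 7:1 8:1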

We started with a simple model having 2 hidden layers with 5 computational units in each layer. We applied k-fold cross validation with k = 10 to obtain 10 pairs of training and test data sets. The performance metrics did not indicate successful results for any of the models; in particular, the models failed to predict dead patients, i.e. the recall for label 1 was very close to 0.
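
In terms of the layers array used to configure the MLPC (see the Code Review section below), this initial configuration would correspond to the following; the only assumption here is that the layer sizes mirror the 8 features, the two 5-unit hidden layers, and the 2 output classes:

        // 8 input features, two hidden layers with 5 computational units each, 2 output classes
        int[] layers = new int[] {8, 5, 5, 2};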

We then went back to step 2 to increase model complexity, adding more computational units and one more hidden layer. When the results were still not successful, we concluded that our feature selection was inadequate and went back to step 1 to review the set of features for possible simplification. (While solving classification problems using an ANN, irrelevant features, i.e. redundant data, can result in poor predictive and computational performance, as discussed by O'Dea et al.) We removed some features in step 1 and cycled through steps 2 - 5 again to finalize the feature set in Table 2 above, obtaining an ANN with 2 hidden layers consisting of 28 and 25 computational units.

Code Review

Our demo application, which illustrates how to use the Spark API for the MLPC (i.e. the classifier based on an ANN), starts as follows:

  • Initialize Spark configuration & context.
  • Initialize a SQL context, which provides the basis for structured data consisting of rows and columns and is required to run the MLPC.
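Before looking at the code, note that the snippets below assume the following imports from the Spark 1.6 Java API (plus scala.Tuple2 for the k-fold splits):

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import scala.Tuple2;
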
public class MultilayerPerceptronClassifierDemo {

    public static void main(String[] args) {
        // Set application name
        String appName = "MultilayerPerceptronClassifier";
        // Initialize Spark configuration & context
        SparkConf conf = new SparkConf().setAppName(appName)
                .setMaster("local[1]").set("spark.executor.memory", "1g");
        SparkContext sc = new SparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
  • Next, randomly split the data for 10-fold cross validation.
  • Repeat the following 10 times: (i) obtain the training and test data sets, (ii) train a model and measure its performance.
  • Finally, stop the Spark context. This ends the main program.
        // Load training and test data file from Hadoop and parse.
        String path = "hdfs://localhost:9000/user/konur/ED2010_2011_2012_SVM.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path)
                .toJavaRDD();

        // Obtain 10 sets of training and test data. 12345 is the seed used to randomly split data.
        Tuple2<RDD<LabeledPoint>,RDD<LabeledPoint>>[] myTuple = MLUtils.kFold(data.rdd(), 10, 12345, data.classTag());

        // Train/validate the algorithm once for each set.
        for(int i = 0; i < myTuple.length; i++){
            JavaRDD<LabeledPoint> trainingData = (new JavaRDD<LabeledPoint>(myTuple[i]._1,data.classTag())).cache();
            JavaRDD<LabeledPoint> testData = new JavaRDD<LabeledPoint>(myTuple[i]._2,data.classTag());
            kRun(trainingData,testData,sqlContext);
        }
        sc.stop();
    }

The helper method kRun starts by preparing the data structures for training and test. Then, the structure of the MLPC is defined.

    private static final void kRun(JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, SQLContext sqlContext){
        DataFrame train = sqlContext.createDataFrame(trainingData, LabeledPoint.class);
        DataFrame test = sqlContext.createDataFrame(testData, LabeledPoint.class);
        // Input consists of 8 features; two hidden layers consist of 28, 25 computational units respectively;
        // Output is binary
        int[] layers = new int[] {8,  28, 25, 2};

We then define the trainer and obtain the trained model.

        // Define the trainer.
        MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
          .setLayers(layers)
          .setBlockSize(128)
          .setSeed(1234L)
          .setMaxIter(150);
        // Obtain the trained model
        MultilayerPerceptronClassificationModel model = trainer.fit(train);

At this point, we have obtained our model. Next, we apply the test data to the model and obtain the performance metrics for the test data. This completes the kRun method.

        // Apply test data to model and obtain the output
        DataFrame testResult = model.transform(test);
        // Display performance metrics for the output
        displayConfusionMatrix(testResult.collect());
    }

Let us now review the helper method displayConfusionMatrix that calculates and displays the performance metrics. That method starts with various variable definitions.

    private static final void displayConfusionMatrix(Row[] rows){
        // #times label 0 correctly predicted
        int correctlyPredicted0 = 0;

        // #times label 1 correctly predicted
        int correctlyPredicted1 = 0;

        // #times label 1 wrongly predicted as label 0
        int wronglyPredicted0 = 0;

        // #times label 0 wrongly predicted as label 1
        int wronglyPredicted1 = 0;


Each row of the output from the transform method corresponds to a particular test data row; in our case the column at index 1 holds the actual label and the column at index 2 holds its predicted value (the features column is at index 0). We iterate through all the rows and increment the corresponding counter.


        for(int i=0; i < rows.length; i++){
            Row row = rows[i];
            double label = row.getDouble(1);
            double prediction = row.getDouble(2);

            if(label == 0.0){
                if(prediction == 0.0){
                    correctlyPredicted0++;
                }else{
                    wronglyPredicted1++;
                }
            }else{
                if(prediction == 1.0){
                    correctlyPredicted1++;
                }else{
                    wronglyPredicted0++;
                }
            }
        }

Finally, we display the confusion matrix and calculate the precision and recall for labels 0 and 1.

        float fcorrectlyPredicted0 = correctlyPredicted0 * 1.0f;
        float fcorrectlyPredicted1 = correctlyPredicted1 * 1.0f;
        float fwronglyPredicted0 = wronglyPredicted0 * 1.0f;
        float fwronglyPredicted1 = wronglyPredicted1 * 1.0f;

        System.out.println("************");
        System.out.println(correctlyPredicted0 + "      " + wronglyPredicted1);
        System.out.println(wronglyPredicted0 + "      " + correctlyPredicted1);

        System.out.println("Class 0 precision: " + ((fcorrectlyPredicted0 == 0.0f)?0.0:(fcorrectlyPredicted0 / (fcorrectlyPredicted0 + fwronglyPredicted0))));
        System.out.println("Class 0 recall: " + ((fcorrectlyPredicted0 == 0.0f)?0.0:(fcorrectlyPredicted0 / (fcorrectlyPredicted0 + fwronglyPredicted1))));

        System.out.println("Class 1 precision: " + ((fcorrectlyPredicted1 == 0.0f)?0.0:(fcorrectlyPredicted1 / (fcorrectlyPredicted1 + fwronglyPredicted1))));
        System.out.println("Class 1 recall: " + ((fcorrectlyPredicted1 == 0.0f)?0.0:(fcorrectlyPredicted1 / (fcorrectlyPredicted1 + fwronglyPredicted0))));
        System.out.println("************");
    }
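
As an aside, instead of counting the entries by hand as above, MLlib's MulticlassMetrics class can compute the confusion matrix, precision, and recall directly. The following is a sketch of that alternative, not part of the original demo; it assumes the Spark 1.6 API and additionally requires importing org.apache.spark.api.java.function.Function and org.apache.spark.mllib.evaluation.MulticlassMetrics.

        // Alternative sketch: build (prediction, label) pairs from the transformed test DataFrame.
        JavaRDD<Tuple2<Object, Object>> predictionAndLabels = testResult
                .select("prediction", "label")
                .toJavaRDD()
                .map(new Function<Row, Tuple2<Object, Object>>() {
                    public Tuple2<Object, Object> call(Row row) {
                        return new Tuple2<Object, Object>(row.getDouble(0), row.getDouble(1));
                    }
                });
        MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
        // Print the 2x2 confusion matrix and per-label metrics for the deceased class.
        System.out.println(metrics.confusionMatrix());
        System.out.println("Class 1 precision: " + metrics.precision(1.0));
        System.out.println("Class 1 recall: " + metrics.recall(1.0));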

We ran the above code on a Spark server with a single-node Hadoop installation (version 2.7.1) and Spark API version 1.6.1, which was the latest at the time of writing. The complete Java code can be downloaded from https://github.com/kunyelio/Spark-MLPC.

Discussion of Results

Let us first show the confusion matrix, precision, and recall on test data for a model with two hidden layers, each consisting of 5 computational units.

70      2
4       1
  • Class 0 precision: 0.946
  • Class 0 recall: 0.972
  • Class 1 precision: 0.333
  • Class 1 recall: 0.2

Although the model has reasonable performance for class 0 (patients who survived), it performed poorly for class 1 (patients who died).

Next, let us show the confusion matrix, precision, and recall on test data for the most successful model. It has two hidden layers consisting of 28 and 25 computational units, respectively.

89      0
0       1
  • Class 0 precision: 1.0
  • Class 0 recall: 1.0
  • Class 1 precision: 1.0
  • Class 1 recall: 1.0

The model achieves perfect performance on this test data, predicting all labels correctly. We observe that performance improved as the number of computational units increased.

Conclusions

In this article we used an artificial neural network (ANN) from the Spark machine learning library as a classifier to predict emergency department deaths due to heart disease. We discussed a high-level process for selecting features and choosing the number of hidden layers and computational units of the network. Based on that process, we found a model that achieved very good performance on test data. We observed that the Spark MLlib API is simple and easy to use for training the classifier and calculating its performance metrics. In reference to Hastie et al., we have some final comments.

  • It is recommended that features be balanced in terms of magnitude when using an ANN as a classifier.
    • Indeed, in our case, all features except the age recode were binary, and the age recode took values from a discrete set of 8 values, an acceptable disparity.
  • Typically the number of computational units is between 5 and 100, “... with the number increasing with the number of inputs and number of training cases.”
    • In our case, the number of computational units is 53 for the most successful model.
  • As the number of computational units increases, it also takes more computational time to train the model.
    • In our case, for the initial simple model with 2 hidden layers of 5 computational units each, training the model, i.e.
      MultilayerPerceptronClassificationModel model = trainer.fit(train);
      took approximately 4 seconds on average for each split. With the final model of 2 hidden layers with 28 and 25 computational units, respectively, it took 6 seconds. As expected, we observed increased computational time. (Because we used a Spark server with a single-node Hadoop installation, these computational times should not be generalized to real scenarios. In cluster mode, expect smaller computational times than on a single node.)


Topics:
big data, big data analytics, apache spark, machine learning, medical research, neural networks, classification models, classification
