Heart Disease Prediction Using Machine Learning and Big Data Stack


Explore how to predict the presence of heart disease using standard ML algorithms and a big data toolset: Apache Spark, Parquet, Spark MLlib, and Spark SQL.


The combination of big data and machine learning can have a great impact on almost any industry when it is used properly. In healthcare, it is useful in cases like early disease detection, finding signs of early outbreaks of epidemics, using clustering to identify regions affected by epidemics (e.g., Zika-prone areas), or finding the zones with the best air quality in countries with high air pollution.

In this article, I explore how to predict the presence of heart disease using standard machine learning algorithms and a big data toolset: Apache Spark, Parquet, Spark MLlib, and Spark SQL.

Source Code

The source code for this article is available on GitHub here. You can also check out the entire Eclipse project from here.

Dataset Used

The heart disease dataset has been studied extensively by machine learning researchers and is freely available from the UCI machine learning dataset repository here. The repository contains four datasets; I have used the Cleveland dataset, which has 14 main attributes (13 features plus the diagnosis). The attributes are:

  • age: age in years
  • sex: sex (1 = male; 0 = female)
  • cp: chest pain type
    value 1: typical angina
    value 2: atypical angina
    value 3: non-anginal pain
    value 4: asymptomatic
  • trestbps: resting blood pressure (in mm Hg on admission to the hospital)
  • chol: serum cholesterol in mg/dl
  • fbs: fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
  • restecg: resting electrocardiographic results
    value 0: normal
    value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)
    value 2: showing probable or definite left ventricular hypertrophy by Estes' criteria
  • thalach: maximum heart rate achieved
  • exang: exercise-induced angina (1 = yes; 0 = no)
  • oldpeak: ST depression induced by exercise relative to rest
  • slope: the slope of the peak exercise ST segment
    value 1: upsloping
    value 2: flat
    value 3: downsloping
  • ca: number of major vessels (0-3) colored by fluoroscopy
  • thal: 3 = normal; 6 = fixed defect; 7 = reversible defect
  • num: diagnosis of heart disease (angiographic disease status)
    value 0: < 50% diameter narrowing (means 'No Disease')
    value 1: > 50% diameter narrowing (means 'Disease is Present')

Technology Used

  • Apache Spark: Apache Spark is part of the big data stack and is essentially the successor to the older MapReduce technology. It is generally much faster than MapReduce and much easier to code against. The RDD (resilient distributed dataset) is the core abstraction of Apache Spark: developers work with it much like an ordinary variable, while behind the scenes it handles all the distributed computing work. Spark comes with other useful packages such as Spark Streaming, Spark SQL (which I use in this article to analyze the dataset), and Spark MLlib (which I use for the machine learning piece). The Spark documentation is excellent and can be found here.
  • Spark SQL: A SQL-like API from Spark that supports DataFrames (similar to the Pandas library in Python, but running over a fully distributed dataset, so it does not offer all of the same functions).
  • Parquet: Parquet is a columnar file format. The raw data files are parsed and stored in Parquet format, which speeds up aggregation queries considerably. A columnar format lets a query read only the columns it needs, which reduces disk I/O.
  • Spark MLlib: The machine learning library from Spark. The algorithms in this library are optimized to run over a distributed dataset. This is the main difference between this library and other popular libraries like scikit-learn, which run in a single process.
  • HDFS: Used for storing the raw files, the generated model, and the results.


Model Generation and Storage Layer

[Figure: model generation and storage layer]

As shown in the image above, the raw files are either pulled into HDFS or pushed directly into HDFS by other programs. The data can also be received via Kafka topics and read using Spark Streaming. For this article and the sample code on GitHub, I assume that the raw files reside in HDFS.

The files are read by a Spark program written in Java (it could be written in Python or Scala too).

The files contain data that has to be adapted into the format that the model requires: every value must be numeric. Some data points are missing (marked '?' in the raw files); these are replaced by a large placeholder value, '99.0', which has no specific meaning and only serves to get the row past parsing and null validation. Also, the last field, 'num', is converted to either '1' or '0' depending on whether the patient has heart disease: any value greater than '0' is converted to '1', meaning that heart disease is present.
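The adaptation rules described above can be tried without a Spark cluster. The following is a minimal plain-Java sketch, not the article's mapper class: the class name `RowAdapterSketch`, its method names, and the sample row are made up for illustration. It assumes each row has 14 comma-separated fields with 'num' last.

```java
import java.util.Arrays;

public class RowAdapterSketch {
    static final int FEATURE_COUNT = 13;       // the first 13 fields are features
    static final double MISSING_VALUE = 99.0;  // placeholder used for missing data points

    /** Parses the 13 feature fields, replacing the missing marker '?' with 99.0. */
    static double[] features(String row) {
        String[] tokens = row.split(",");
        double[] out = new double[FEATURE_COUNT];
        for (int i = 0; i < FEATURE_COUNT; i++) {
            String token = tokens[i].trim();
            out[i] = "?".equals(token) ? MISSING_VALUE : Double.parseDouble(token);
        }
        return out;
    }

    /** Collapses the raw 'num' field to 1.0 (disease present) or 0.0 (no disease). */
    static double label(String row) {
        String[] tokens = row.split(",");
        int num = Integer.parseInt(tokens[FEATURE_COUNT].trim());
        return num > 0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        // Made-up row for illustration: 'ca' (12th field) is missing and 'num' is 2
        String row = "57,1,4,140,192,0,0,148,0,0.4,2,?,6,2";
        System.out.println(Arrays.toString(features(row)) + " -> label " + label(row));
    }
}
```

Note that a placeholder such as 99.0 is treated by the model as a real measurement, so it is worth checking how many rows contain missing values before relying on it.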

The data files are now read into an RDD.

For this dataset, I have used the Naive Bayes algorithm (the same algorithm that is often used in spam filters). Using the machine learning library from Spark (MLlib), the algorithm is trained with the data from the dataset. Note: the Decision Tree algorithm might also give good results in this case.
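To give a feel for what Naive Bayes training computes, here is a small single-process sketch of the multinomial variant with additive smoothing, which, to my understanding, is also the default variant in MLlib. It is not MLlib's implementation: the class `NaiveBayesSketch`, its methods, and the toy data are made up for illustration, and the sketch assumes non-negative feature values and class labels numbered from 0.

```java
public class NaiveBayesSketch {
    final double[] logPrior;    // log P(class)
    final double[][] logTheta;  // log P(feature | class)

    NaiveBayesSketch(double[] logPrior, double[][] logTheta) {
        this.logPrior = logPrior;
        this.logTheta = logTheta;
    }

    /** Trains a multinomial Naive Bayes model with additive smoothing 'lambda'. */
    static NaiveBayesSketch train(double[][] features, int[] labels, int numClasses, double lambda) {
        int numFeatures = features[0].length;
        double[] classCount = new double[numClasses];
        double[][] featureSum = new double[numClasses][numFeatures];
        for (int i = 0; i < features.length; i++) {
            int c = labels[i];
            classCount[c]++;
            for (int j = 0; j < numFeatures; j++) {
                featureSum[c][j] += features[i][j];
            }
        }
        double[] logPrior = new double[numClasses];
        double[][] logTheta = new double[numClasses][numFeatures];
        for (int c = 0; c < numClasses; c++) {
            logPrior[c] = Math.log((classCount[c] + lambda) / (features.length + lambda * numClasses));
            double total = 0.0;
            for (double s : featureSum[c]) {
                total += s;
            }
            for (int j = 0; j < numFeatures; j++) {
                logTheta[c][j] = Math.log((featureSum[c][j] + lambda) / (total + lambda * numFeatures));
            }
        }
        return new NaiveBayesSketch(logPrior, logTheta);
    }

    /** Returns the class with the highest log-posterior score for the feature vector x. */
    int predict(double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < logPrior.length; c++) {
            double score = logPrior[c];
            for (int j = 0; j < x.length; j++) {
                score += x[j] * logTheta[c][j];
            }
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Toy data: class 0 is dominated by the first feature, class 1 by the second
        double[][] x = {{5, 1}, {4, 0}, {6, 1}, {1, 5}, {0, 4}, {1, 6}};
        int[] y = {0, 0, 0, 1, 1, 1};
        NaiveBayesSketch model = train(x, y, 2, 1.0);
        System.out.println("Predicted class: " + model.predict(new double[] {5, 1}));
    }
}
```

The "naive" part is the assumption that features are independent given the class, which is why the score is a simple sum over features.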

After the algorithm is trained, the model is saved to external storage on HDFS so that it can be used later to make predictions on the test data.

Here is a snippet of the code for the steps above:

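// Imports assumed for this snippet (Spark MLlib, RDD-based API)
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.regression.LabeledPoint;

// SparkConfAndCtxBuilder is a helper class from this article's source code, not part of Spark
SparkConfAndCtxBuilder ctxBuilder = new SparkConfAndCtxBuilder();
JavaSparkContext jctx = ctxBuilder.loadSimpleSparkContext("Heart Disease Detection App", "local");

// Read the data into an RDD; each element is one line of the file as a string
JavaRDD<String> dsLines = jctx.textFile(trainDataLoc);

// Parse each line with the adapter class so that the data is in the format the model requires
JavaRDD<LabeledPoint> _modelTrainData = dsLines.map(new DataToModelAdapterMapper());

// Train the model with the data.
// You can replace the line below to try other models (e.g., decision trees)
// and compare the accuracy of the results obtained.
NaiveBayesModel _model = NaiveBayes.train(_modelTrainData.rdd());

// Save the trained model so that it can be loaded later for predictions
_model.save(jctx.sc(), modelStorageLoc);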


Here is a snippet of the mapper class that is applied above to each data row:

import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;

public class DataToModelAdapterMapper implements Function<String, LabeledPoint> {

    @Override
    public LabeledPoint call(String dataRow) throws Exception {
        // Replace missing data points ('?') with a high placeholder value so that parsing does not fail
        String newLine = dataRow.replaceAll("\\?", "99.0");
        String[] tokens = newLine.split(",");

        System.out.println("tokens count : " + tokens.length);

        // The last token (num) is the known diagnosis, used as the label when training the model
        int lastToken = Integer.parseInt(tokens[13]);

        // The first 13 tokens are the features
        double[] featuresDblArr = new double[13];
        for (int i = 0; i < 13; i++) {
            featuresDblArr[i] = Double.parseDouble(tokens[i]);
        }

        // Build the feature vector
        Vector featuresVector = new DenseVector(featuresDblArr);

        // Collapse the label to 1.0 (disease present) or 0.0 (no disease)
        double classValue = lastToken > 0 ? 1.0 : 0.0;

        return new LabeledPoint(classValue, featuresVector);
    }
}

Data Analysis Layer

This layer is used to analyze the training data with queries such as: the minimum age of a person with the disease, the total number of women vs. men with the disease, which parameter is almost always present when the disease occurs, the number of people who have symptoms but not the disease, etc.
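Before running such queries at scale, it can help to prototype them on a handful of rows in memory. The sketch below expresses two of the queries above with Java streams instead of Spark SQL. The class `TrainingDataQueriesSketch`, the `Patient` type, and the sample rows are made up for illustration, and the rows are assumed to be already cleaned.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.stream.Collectors;

public class TrainingDataQueriesSketch {

    /** Minimal view of one data row: only the columns these two queries need. */
    static final class Patient {
        final double age;
        final int sex;  // 1 = male; 0 = female
        final int num;  // > 0 means heart disease is present

        Patient(double age, int sex, int num) {
            this.age = age;
            this.sex = sex;
            this.num = num;
        }
    }

    /** Parses a cleaned, comma-separated row (age first, sex second, num last of 14 fields). */
    static Patient parse(String row) {
        String[] t = row.split(",");
        return new Patient(Double.parseDouble(t[0]),
                (int) Double.parseDouble(t[1]),
                (int) Double.parseDouble(t[13]));
    }

    /** Minimum age among patients who have the disease. */
    static OptionalDouble minAgeWithDisease(List<Patient> patients) {
        return patients.stream().filter(p -> p.num > 0).mapToDouble(p -> p.age).min();
    }

    /** Number of patients with the disease, grouped by sex. */
    static Map<Integer, Long> diseasedCountBySex(List<Patient> patients) {
        return patients.stream()
                .filter(p -> p.num > 0)
                .collect(Collectors.groupingBy(p -> p.sex, Collectors.counting()));
    }

    public static void main(String[] args) {
        // Made-up rows for illustration only
        List<Patient> patients = Arrays.asList(
                parse("63,1,1,145,233,1,2,150,0,2.3,3,0,6,0"),
                parse("41,0,2,130,204,0,2,172,0,1.4,1,0,3,1"),
                parse("67,1,4,160,286,0,2,108,1,1.5,2,3,3,2"),
                parse("35,1,4,120,198,0,0,130,1,1.6,2,0,7,1"));
        System.out.println("Minimum age with disease: " + minAgeWithDisease(patients));
        System.out.println("Diseased count by sex: " + diseasedCountBySex(patients));
    }
}
```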

To run data analysis on the training data, first load the full (clean) data into an RDD using textFile.

Now save this RDD to external storage in Parquet format.

From another program, load the data into a DataFrame from this Parquet storage.

A snippet of the code is shown below. For the full source code, refer to the code here.

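import java.util.ArrayList;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Build the schema: every column is declared as a nullable string
String schemaString = "age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num";
List<StructField> fields = new ArrayList<>();
for (String fieldName : schemaString.split(" ")) {
    fields.add(DataTypes.createStructField(fieldName, DataTypes.StringType, true));
}
StructType schema = DataTypes.createStructType(fields);

// 'rows' is the JavaRDD<String> holding the cleaned data lines
JavaRDD<Row> rowRdd = rows.map(new Function<String, Row>() {
    public Row call(String record) throws Exception {
        String[] values = record.split(",");
        return RowFactory.create(values[0], values[1], values[2], values[3], values[4], values[5], values[6], values[7], values[8], values[9], values[10], values[11], values[12], values[13]);
    }
});

DataFrame df = sqlCtx.createDataFrame(rowRdd, schema);
// Register the DataFrame under the table name used by the query (added here so the snippet is complete)
df.registerTempTable("heartDisData");

// The columns are strings, so cast age to a number before taking the minimum
DataFrame results = sqlCtx.sql("select min(cast(age as double)) from heartDisData");

JavaRDD<String> jrdd = results.javaRDD().map(new Function<Row, String>() {
    public String call(Row arg0) throws Exception {
        return arg0.toString();
    }
});

List<String> rstList = jrdd.collect();
for (String rStr : rstList) {
    System.out.println(" Minimum Age : " + rStr);
}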

Disease Prediction Layer (Refer to the Code on GitHub)

[Figure: disease prediction layer]

Now load the test data into an RDD using Apache Spark.

Clean and adapt the test data to the model.

Load the model from storage using Spark MLlib.

Use the model object to predict the presence of disease. For example:

NaiveBayesModel _model = NaiveBayesModel.load(<Spark Context>, <Model Storage Location>);

A snippet of the code is shown below. Refer to the GitHub source location for the full code.

SparkConfAndCtxBuilder ctxBuilder = new SparkConfAndCtxBuilder();
JavaSparkContext jctx = ctxBuilder.loadSimpleSparkContext("Heart Disease Detection App", "local");

// Read the test data and convert each line into a feature vector
JavaRDD<String> dsLines = jctx.textFile(testDataLoc);
JavaRDD<Vector> fRdd = dsLines.map(new TestDataToFeatureVectorMapper());

// Load the model that was saved by the training step
NaiveBayesModel _model = NaiveBayesModel.load(jctx.sc(), modelStorageLoc);

// Predict a class for every feature vector (1.0 = disease present, 0.0 = no disease)
JavaRDD<Double> predictedResults = _model.predict(fRdd);
List<Double> prl = predictedResults.collect();
for (Double pr : prl) {
    System.out.println("Predicted Value : " + pr);
}

Problem With the Above Design

The most important issue with any disease prediction system is accuracy. A false negative is a particularly dangerous result, because it can let a disease go unnoticed.
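One way to make this concern measurable is to count false negatives separately instead of looking at overall accuracy alone. The following is a minimal sketch that compares known labels with predicted labels. The class `FalseNegativeSketch`, its methods, and the label arrays are made up for illustration and are not results from the model in this article.

```java
public class FalseNegativeSketch {

    /** Returns {TP, FP, TN, FN}, treating 1.0 as "disease present". */
    static int[] confusion(double[] actual, double[] predicted) {
        int tp = 0, fp = 0, tn = 0, fn = 0;
        for (int i = 0; i < actual.length; i++) {
            boolean isSick = actual[i] > 0.5;
            boolean saidSick = predicted[i] > 0.5;
            if (isSick && saidSick) {
                tp++;
            } else if (!isSick && saidSick) {
                fp++;
            } else if (!isSick) {
                tn++;
            } else {
                fn++;  // sick patient predicted as healthy
            }
        }
        return new int[] {tp, fp, tn, fn};
    }

    /** Fraction of all predictions that are correct. */
    static double accuracy(int[] c) {
        return (double) (c[0] + c[2]) / (c[0] + c[1] + c[2] + c[3]);
    }

    /** Fraction of sick patients that the model actually catches (NaN if there are none). */
    static double recall(int[] c) {
        int sick = c[0] + c[3];
        return sick == 0 ? Double.NaN : (double) c[0] / sick;
    }

    public static void main(String[] args) {
        // Made-up labels: 3 of 8 patients are sick, and the model catches only one of them
        double[] actual = {1, 1, 1, 0, 0, 0, 0, 0};
        double[] predicted = {1, 0, 0, 0, 0, 0, 0, 0};
        int[] c = confusion(actual, predicted);
        System.out.println("accuracy = " + accuracy(c) + ", recall = " + recall(c));
    }
}
```

In this made-up example, accuracy is 0.75 while recall is only about 0.33, which shows how a reasonable-looking accuracy figure can hide missed diagnoses.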

Deep learning can give better predictions than regular machine learning algorithms on some problems. In a future article, I will try the same disease prediction with deep learning neural networks.


Using tools like Apache Spark and its machine learning library, we were able to easily load a heart disease dataset (from UCI) and train a regular machine learning model. This model was later used to predict the existence of heart disease on test samples of data.

The source code for this is available on GitHub here.


Published at DZone with permission of Rajat Mehta. See the original article here.
