Java Email Spam Classifier Application With Spark


Learn how to develop and use your own machine learning-based email spam classification system — because who likes having spam?


In this post, we are going to develop an application that detects spam emails. The algorithm we will use is logistic regression, as implemented in Spark MLlib. No deep knowledge of the field is required, since the topics are described from as high-level a perspective as possible. The full working code is provided, together with a runnable application you can use to experiment with emails of your choice.

Logistic Regression

Logistic regression is an algorithm used for classification problems. In a classification problem, we are given a lot of labeled data (for example, spam and not spam), and when a new example arrives, we want to know which category it belongs to. Because it is a machine learning algorithm, logistic regression is trained on labeled data, and based on that training, it makes predictions about new examples.


In general, when a lot of data is available and we need to determine which category an example belongs to, logistic regression can be used (even if the results aren't always satisfactory).


Logistic regression can be used, for example, when analyzing millions of patients' health conditions to predict whether a patient will have a myocardial infarction. The same logic can be applied to predict whether a patient will develop a particular cancer, be affected by depression, and so on. In this kind of application, we have a considerable amount of data, so logistic regression usually gives good hints.

Image Categorization

Based on the color intensities of an image's pixels, we can determine whether, let's say, an image contains a human or a car. Since this is also a categorization problem, we may use logistic regression to detect whether a picture contains characters, or even to recognize handwriting.

Message and Email Spam Classification

One of the most common applications of logistic regression is email spam classification. Here, the algorithm determines whether an incoming email or message is spam. Building a non-personalized filter requires a lot of data. Personalized filters usually perform better, because what counts as spam depends to a certain degree on the person's interests and background.

How It Works

We have a lot of labeled examples and want to train our algorithm to be smart enough to say which of the two categories a new example belongs to. For simplicity, we will first consider only binary classification (1 or 0). The algorithm also extends easily to multiclass classification.


Usually, we have multidimensional data, that is, data with many features. Each of these features contributes in some way to the final decision about which category a new example belongs to. For example, in a cancer classification problem, we can have features like age, whether the person smokes, weight, height, family genome, and so on. Features do not contribute equally; they have different impacts on the final decision. For example, weight may have a lower impact than family genome in cancer prediction. In logistic regression, that is exactly what we are trying to find out: the weights (impact) of the features of our data. Given a lot of data examples, we can determine the weight of each feature, and when new examples arrive, we use these weights to decide how each example is categorized. In the cancer prediction example, we can write this as:

$$\text{cancer} = \theta_1 \cdot \text{age} + \theta_2 \cdot \text{smoking} + \theta_3 \cdot \text{weight} + \theta_4 \cdot \text{height} + \theta_5 \cdot \text{family genome} + \dots$$

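More formally, for the i-th example:

$$z^{(i)} = \theta_0 + \sum_{j=1}^{k} \theta_j X_j^{(i)}$$

where:

n = number of examples

k = number of features

θj = weight for feature j (θ0 is a constant, or intercept, term)

Xj(i) = value of feature j in the i-th example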

Model Representation

To sort our data into categories, we need a function (the hypothesis) that, given an example's feature values and the weights, puts the example into one of the two categories. The function we use for this is the sigmoid (logistic) function, an S-shaped curve that maps any real number to a value between 0 and 1.

When the input z (X-axis) is large and positive, the sigmoid value tends toward 1, and when z is large and negative, it tends toward 0. This gives us a model for representing two categories. Mathematically, the hypothesis is:

$$h_\theta(x) = g(z) = \frac{1}{1 + e^{-z}}$$

Here, z is the weighted sum of the example's features, z = θ0 + θ1X1 + ... + θkXk.

To get discrete values (1 or 0), we classify an example as 1 when the function value (Y-axis) is greater than 0.5, and as 0 when it is smaller than 0.5. Because the sigmoid equals exactly 0.5 at z = 0, this is the same as checking the sign of z:

  • Y > 0.5 → 1 (spam/cancer)

  • Y < 0.5 → 0 (not spam/not cancer)

  • z > 0 → 1 (spam/cancer)

  • z < 0 → 0 (not spam/not cancer)
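To make the decision rule concrete, here is a minimal plain-Java sketch (no Spark) of the hypothesis: it computes z as the weighted sum of the features, passes it through the sigmoid, and thresholds the result at 0.5. The class and method names (SigmoidHypothesis, z, sigmoid, classify) and the weights in main are made up for illustration, and theta[0] is assumed to be the intercept term.

```java
// Illustrative sketch: hypothesis h(x) = sigmoid(z), z = theta[0] + sum_j theta[j+1] * x[j],
// followed by the 0.5 threshold. Names and weights are hypothetical.
public class SigmoidHypothesis {

    // z: weighted sum of the features; theta[0] is the intercept term
    static double z(double[] theta, double[] x) {
        double sum = theta[0];
        for (int j = 0; j < x.length; j++) {
            sum += theta[j + 1] * x[j];
        }
        return sum;
    }

    // The sigmoid maps any real z into the open interval (0, 1)
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // 1 = spam/cancer, 0 = not spam/not cancer; sigmoid(z) > 0.5 exactly when z > 0
    static int classify(double[] theta, double[] x) {
        return sigmoid(z(theta, x)) > 0.5 ? 1 : 0;
    }

    public static void main(String[] args) {
        double[] theta = {-1.0, 2.0, 0.5}; // hypothetical weights
        double[] x = {1.0, 0.0};           // hypothetical feature values
        System.out.println(z(theta, x));        // 1.0
        System.out.println(classify(theta, x)); // 1
    }
}
```

With these hypothetical weights, z = 1.0, so the sigmoid is about 0.73 and the example is classified as 1.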

Cost Function

We don't want just any weights, but the best weights we can find for the actual data. To find them, we need another function that measures how good a solution is for a particular set of weights. With this function, we can compare solutions with different weights and pick the best one. This function is called the cost function. It compares the hypothesis (sigmoid) value with the real data value: since the training data are labeled (spam or not spam), we compare the hypothesis prediction with the actual value, which we know for sure. We want the difference between the hypothesis and the real value to be as small as possible; ideally, we want the cost function to be zero. More formally, the cost function is defined as:

$$J(\theta) = \frac{1}{n} \sum_{i=1}^{n} \text{Cost}\left(h_\theta(x^{(i)}), y^{(i)}\right)$$

...where y(i) is the real value/category of the i-th example (spam/not spam, or 1/0), and h(x) is the hypothesis.

Basically, this equation calculates how good, on average, our predictions are compared to the real labeled data (y). Because we have two cases (1 and 0), the cost has two parts, which we will call h1 and h0. We apply log to the hypothesis so that the cost function is convex, which makes it safer to find the global minimum.

Let's look at h1, the cost for an example whose real category is 1:

$$\text{Cost}\left(h_\theta(x), y\right) = -\log\left(h_\theta(x)\right) \quad \text{if } y = 1$$

We apply log to the hypothesis, instead of using it directly, because we want the following relationship: when the hypothesis is close to 1, the cost goes to zero. Remember that we want our cost function to be zero, so that there is no difference between the hypothesis prediction and the labeled data. If the hypothesis predicts a value close to 0, the cost grows very large, telling us the prediction is wrong for an example that belongs to category 1; if the hypothesis predicts a value close to 1, the cost goes to 0, signaling a correct prediction.

Now let's look at h0, the cost for an example whose real category is 0:

$$\text{Cost}\left(h_\theta(x), y\right) = -\log\left(1 - h_\theta(x)\right) \quad \text{if } y = 0$$

In this case, we apply log again, but in a way that makes the cost go to zero when the hypothesis also predicts zero. If the hypothesis predicts a value close to 1, the cost grows very large, penalizing the prediction for an example that belongs to category 0; if the hypothesis predicts a value close to 0, the cost goes to 0, signaling a correct prediction.

Now we have two cost functions, and we need to combine them into one. The resulting equation looks a bit messy, but in principle, it's just a merge of the two cost functions explained above:

$$J(\theta) = -\frac{1}{n} \sum_{i=1}^{n} \left[ y^{(i)} \log h_\theta(x^{(i)}) + \left(1 - y^{(i)}\right) \log\left(1 - h_\theta(x^{(i)})\right) \right]$$

Notice that the first term is the cost for h1 and the second term is the cost for h0. So, if y(i) = 1, the second term is eliminated, and if y(i) = 0, the first term is eliminated.
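The combined cost can be sketched in plain Java as follows, assuming the hypothesis values h(x(i)) have already been computed by the sigmoid. The class name LogisticCost is hypothetical. For each example, the sketch keeps only the term that survives for its label. This is equivalent to the combined formula, but it avoids evaluating 0 × log(0) when a prediction is exactly 0 or 1.

```java
// Illustrative sketch of the logistic-regression cost J(theta) over n labeled examples.
public class LogisticCost {

    // h[i] = predicted probability h(x^(i)), y[i] = real label (0 or 1)
    static double cost(double[] h, int[] y) {
        double sum = 0.0;
        for (int i = 0; i < h.length; i++) {
            // y = 1 keeps only -log(h); y = 0 keeps only -log(1 - h)
            sum += (y[i] == 1) ? -Math.log(h[i]) : -Math.log(1 - h[i]);
        }
        return sum / h.length; // average over the n examples
    }

    public static void main(String[] args) {
        int[] y = {1, 0};
        System.out.println(cost(new double[]{0.9, 0.1}, y)); // small: predictions match the labels
        System.out.println(cost(new double[]{0.1, 0.9}, y)); // large: predictions contradict the labels
    }
}
```

Confident, correct predictions give a cost near zero. Confident, wrong predictions push it up sharply, which is the behavior described for h1 and h0.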

Minimize Cost Function

As we saw above, we want the cost function to be as close to zero as possible, so that our predictions are as close as possible to the real (labeled) values. Fortunately, there is already an algorithm for minimizing the cost function: gradient descent. Once we have the cost function (which compares our hypothesis to the real values), we can adjust our weights (θ) to lower the cost as much as possible. First, we pick random values of θ just to have a starting point. Then, we calculate the cost function. Depending on the result, we decrease or increase the θ values so that the cost goes down. We repeat this until the cost is almost zero (for example, 0.0001) or stops improving much from one iteration to the next.

Gradient descent does exactly this; it just uses the derivative of the cost function to decide whether to decrease or increase each θ value. It also uses a coefficient α (the learning rate) to define how much to change the θ values. Changing θ too much (a large α) can make gradient descent fail to minimize the cost function, since a big step may overshoot the minimum or even move farther away from it. A small change of θ (a small α) is safe, but the algorithm then needs a lot of time to reach the minimum of the cost function, because it progresses too slowly toward it. More formally, each iteration updates every weight as:

$$\theta_j := \theta_j - \alpha \frac{1}{n} \sum_{i=1}^{n} \left( h_\theta(x^{(i)}) - y^{(i)} \right) X_j^{(i)}$$

The term multiplied by α is the derivative of the cost function with respect to θj; note that each error is multiplied by the value of feature j, Xj(i). Since our data are multidimensional (k features), we do this for each feature weight θj, updating all of them simultaneously.
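The update rule can be sketched as plain batch gradient descent in Java. This is an illustration on a tiny made-up dataset, not how Spark MLlib implements its SGD or L-BFGS optimizers. The names GradientDescent and train, the learning rate 0.5, and the 5,000-iteration budget are assumptions. Unlike the description above, the sketch starts from all-zero weights instead of random ones. That is a common choice for logistic regression, whose cost function is convex. Each example includes a leading 1.0 so that theta[0] acts as an intercept term.

```java
import java.util.Arrays;

// Illustrative sketch: batch gradient descent for logistic regression.
public class GradientDescent {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // x[i] starts with 1.0 so that theta[0] acts as the intercept term
    static double[] train(double[][] x, int[] y, double alpha, int iterations) {
        int n = x.length;
        int k = x[0].length;
        double[] theta = new double[k]; // all weights start at 0
        for (int it = 0; it < iterations; it++) {
            double[] gradient = new double[k];
            for (int i = 0; i < n; i++) {
                double z = 0.0;
                for (int j = 0; j < k; j++) {
                    z += theta[j] * x[i][j];
                }
                double error = sigmoid(z) - y[i]; // h(x^(i)) - y^(i)
                for (int j = 0; j < k; j++) {
                    gradient[j] += error * x[i][j]; // error times feature j
                }
            }
            // simultaneous update of all weights using the averaged derivative
            for (int j = 0; j < k; j++) {
                theta[j] -= alpha * gradient[j] / n;
            }
        }
        return theta;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 0}, {1, 1}, {1, 2}, {1, 3}};
        int[] y = {0, 0, 1, 1};
        double[] theta = train(x, y, 0.5, 5000);
        System.out.println(Arrays.toString(theta));
    }
}
```

With this data, the learned weights classify the first two examples as 0 and the last two as 1.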

Algorithm Execution

Let's look at preparing the data, transforming the data, the execution, and the results.

Prepare Data

Before running the algorithm, we need to preprocess the data to remove information that isn't useful. The preprocessing steps are based on a Coursera assignment. We do the following:

  • Lower-casing: The entire email is converted into lower case so that capitalization is ignored (i.e., IndIcaTE is treated the same as Indicate).
  • Stripping HTML: All HTML tags are removed from the emails. Many emails often come with HTML formatting; we remove all the HTML tags so that only the content remains.
  • Normalizing URLs: All URLs are replaced with the text “XURLX”.
  • Normalizing email addresses: All email addresses are replaced with the text “XEMAILX”.
  • Normalizing numbers: All numbers are replaced with the text “XNUMBERX”.
  • Normalizing dollars: All dollar signs ($) are replaced with the text “XMONEYX”.
  • Word stemming: Words are reduced to their stemmed form. For example, “discount,” “discounts,” “discounted,” and “discounting” are all replaced with “discount.” Sometimes, the stemmer actually strips additional characters from the end, so “include,” “includes,” “included,” and “including” are all replaced with “includ.”
  • Removal of non-words: Non-words and punctuation are removed. All whitespace (e.g., tabs, newlines, spaces) is collapsed into a single space character.

The code implementation will look like this:

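import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.tartarus.snowball.ext.PorterStemmer; // assumed: the Snowball Porter stemmer

private List<String> filesToWords(String fileName) throws Exception {
    URI uri = this.getClass().getResource("/" + fileName).toURI();
    Path start = getPath(uri); // helper defined elsewhere in the project
    List<String> emails;
    try (Stream<Path> paths = Files.walk(start)) {
        emails = paths.parallel()
                .filter(Files::isRegularFile) // skip directories
                .flatMap(file -> {
                    try {
                        // lower-case the whole email so capitalization is ignored
                        return Stream.of(new String(Files.readAllBytes(file)).toLowerCase());
                    } catch (IOException e) {
                        return Stream.<String>empty(); // skip unreadable files
                    }
                })
                .collect(Collectors.toList());
    }
    return emails.stream().parallel()
            .flatMap(e -> tokenizeIntoWords(prepareEmail(e)).stream())
            .collect(Collectors.toList());
}

private String prepareEmail(String email) {
    // drop the header, i.e. everything before the first blank line
    int beginIndex = email.indexOf("\n\n");
    String withoutHeader = beginIndex > 0 ? email.substring(beginIndex) : email;
    String tagsRemoved = withoutHeader.replaceAll("<[^<>]+>", "");
    // URLs and email addresses are replaced before numbers, so digits inside them do not split them
    String urlReplaced = tagsRemoved.replaceAll("(http|https)://[^\\s]*", "XURLX ");
    String emailReplaced = urlReplaced.replaceAll("[^\\s]+@[^\\s]+", "XEMAILX ");
    String numberReplaced = emailReplaced.replaceAll("[0-9]+", "XNUMBERX ");
    return numberReplaced.replaceAll("[$]+", "XMONEYX ");
}

private List<String> tokenizeIntoWords(String text) {
    String delim = "[' @$/#.-:&*+=[]?!(){},''\\\">_<;%'\t\n\r\f";
    StringTokenizer stringTokenizer = new StringTokenizer(text, delim);
    List<String> wordsList = new ArrayList<>();
    while (stringTokenizer.hasMoreTokens()) {
        String word = stringTokenizer.nextToken();
        String nonAlphaNumericRemoved = word.replaceAll("[^a-zA-Z0-9]", "");
        if (nonAlphaNumericRemoved.isEmpty()) {
            continue; // nothing left to stem
        }
        PorterStemmer stemmer = new PorterStemmer();
        stemmer.setCurrent(nonAlphaNumericRemoved);
        stemmer.stem();
        wordsList.add(stemmer.getCurrent());
    }
    return wordsList;
}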

Transform Data

Once the emails are prepared, we need to transform the data into a structure the algorithm understands, such as feature vectors and matrices.

The first step is to build a "spam vocabulary" by reading all the words of the spam emails and counting them. For example, we count how many times "transaction," "XMONEYX," "finance," "win," and "free" are used. Then, we pick the 10,000 (featureSize) most frequent words. At this point, we have a map of size 10,000 (featureSize) in which the key is the word and the value is an index from 0 to 9,999. This serves as a reference for possible spam words. See the code below:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public Map<String, Integer> createVocabulary() throws Exception {
    String first = "allInOneSpamBase/spam";
    String second = "allInOneSpamBase/spam_2";
    List<String> collect1 = filesToWords(first);
    List<String> collect2 = filesToWords(second);

    // combine the words of both spam folders
    ArrayList<String> all = new ArrayList<>(collect1);
    all.addAll(collect2);
    HashMap<String, Integer> countWords = countWords(all);

    // sort by frequency, most frequent word first
    List<Map.Entry<String, Integer>> sortedVocabulary = countWords.entrySet().stream()
            .sorted(Map.Entry.<String, Integer>comparingByValue().reversed())
            .collect(Collectors.toList());
    // keep the featureSIze most frequent words, indexed 0 .. featureSIze - 1
    final int[] index = {0};
    return sortedVocabulary.stream().limit(featureSIze)
            .collect(Collectors.toMap(Map.Entry::getKey, e -> index[0]++));
}

HashMap<String, Integer> countWords(List<String> all) {
    HashMap<String, Integer> countWords = new HashMap<>();
    for (String s : all) {
        countWords.merge(s, 1, Integer::sum); // first occurrence -> 1, otherwise +1
    }
    return countWords;
}

The next step is to count, for each spam and non-spam email, how often each of its words occurs. Then, we look up each of those words in the spam vocabulary. If a word is there (meaning the email contains a possible spam word), we store its frequency at the index that the spam vocabulary assigns to that word. In the end, we build an N×10,000 matrix, where N is the number of emails considered and each row of 10,000 values holds the frequencies of the spam vocabulary words in one email (if a spam word does not occur in the email, we put 0).

For example, let's say we have a spam vocabulary like below:

  • aa

  • how

  • abil

  • anyon

  • know

  • zero

  • zip

And an email like the one below, in preprocessed form:

anyon know how much it cost to host a web portal well it depend on how mani visitor your expect thi can be anywher from less than number buck a month to a coupl of dollarnumb you should checkout XURLX or perhap amazon ecnumb if your run someth big to unsubscrib yourself from thi mail list send an email to XEMAILX

After the transformation, we will have:

0 2 0 1 1 0 0

So we have 0 aa, 2 how, 0 abil, 1 anyon, 1 know, 0 zero, and 0 zip. This is a 1×7 matrix, since we have one email and a spam vocabulary of 7 words. The code looks like below:

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

private Vector transformToFeatureVector(Email email, Map<String, Integer> vocabulary) {
    List<String> words = email.getWords();
    HashMap<String, Integer> countWords = prepareData.countWords(words);
    double[] features = new double[featureSIze]; // featureSIze == 10,000
    for (Map.Entry<String, Integer> word : countWords.entrySet()) {
        Integer index = vocabulary.get(word.getKey()); // is the word in the spam vocabulary?
        if (index != null) {
            // put the frequency at the same index as in the vocabulary
            features[index] = word.getValue();
        }
    }
    return Vectors.dense(features);
}
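The transformation can also be checked without Spark. The following plain-Java sketch reproduces the worked example above. The class name FeatureVectorExample is hypothetical. The sketch counts word frequencies and stores each count at the word's vocabulary index. Unlike the project code, it sizes the vector by the vocabulary instead of featureSIze. It also splits the already preprocessed email on spaces instead of running the full preprocessing.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative plain-Java sketch of the word-frequency feature vector (no Spark involved).
public class FeatureVectorExample {

    // Counts each word and stores the count at the word's index in the spam vocabulary.
    static double[] toFeatureVector(List<String> words, Map<String, Integer> vocabulary) {
        Map<String, Integer> counts = new HashMap<>();
        for (String word : words) {
            counts.merge(word, 1, Integer::sum);
        }
        double[] features = new double[vocabulary.size()]; // words not found stay 0
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            Integer index = vocabulary.get(entry.getKey());
            if (index != null) {
                features[index] = entry.getValue();
            }
        }
        return features;
    }

    public static void main(String[] args) {
        String[] vocab = {"aa", "how", "abil", "anyon", "know", "zero", "zip"};
        Map<String, Integer> vocabulary = new HashMap<>();
        for (int i = 0; i < vocab.length; i++) {
            vocabulary.put(vocab[i], i);
        }
        String email = "anyon know how much it cost to host a web portal well it depend on how mani "
                + "visitor your expect thi can be anywher from less than number buck a month to a coupl "
                + "of dollarnumb you should checkout XURLX or perhap amazon ecnumb if your run someth big "
                + "to unsubscrib yourself from thi mail list send an email to XEMAILX";
        double[] features = toFeatureVector(Arrays.asList(email.split(" ")), vocabulary);
        System.out.println(Arrays.toString(features)); // [0.0, 2.0, 0.0, 1.0, 1.0, 0.0, 0.0]
    }
}
```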

Execute and Results

The application can be downloaded and run without any knowledge of Java, though Java has to be installed on your computer. Feel free to test the algorithm with your own emails.

We can run the application from source by simply executing the RUN class or, if you do not want to open it in an IDE, by running mvn clean install exec:java.

After that, the application window should open.

First, train the algorithm by clicking Train with LR SGD or Train with LR LBFGS. This may take one to two minutes. When training finishes, a pop-up shows the precision achieved. Don't worry about SGD versus LBFGS: they are just different ways of minimizing the cost function and give almost the same results. Then, copy and paste an email of your choice into the white area and click Test. A pop-up window will show the algorithm's prediction.

The precision achieved during my execution was approximately 97%, using a random 80% of the data for training and 20% for testing. No cross-validation was used; this example uses just a training set and a test set (for measuring accuracy).
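A seeded random 80/20 split can be sketched in plain Java by shuffling with a fixed seed and cutting the list at 80%. This is a different technique from Spark's randomSplit. randomSplit samples elements independently, so its split sizes are only approximately 80/20, whereas this sketch produces exact sizes. The class and method names are hypothetical. The seed 11 simply mirrors the one passed to randomSplit in the training code.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Illustrative sketch: reproducible train/test split by seeded shuffling.
public class TrainTestSplit {

    // Returns [training set, test set]
    static <T> List<List<T>> split(List<T> data, double trainFraction, long seed) {
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed)); // same seed -> same split
        int cut = (int) Math.round(shuffled.size() * trainFraction);
        List<List<T>> result = new ArrayList<>();
        result.add(new ArrayList<>(shuffled.subList(0, cut)));               // training set
        result.add(new ArrayList<>(shuffled.subList(cut, shuffled.size()))); // test set
        return result;
    }

    public static void main(String[] args) {
        List<Integer> emails = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            emails.add(i);
        }
        List<List<Integer>> parts = split(emails, 0.8, 11L);
        System.out.println(parts.get(0).size() + " training / " + parts.get(1).size() + " test");
        // 80 training / 20 test
    }
}
```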

The code for training the algorithm is fairly simple:

import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;

public MulticlassMetrics execute() throws Exception {
    vocabulary = prepareData.createVocabulary();
    List<LabeledPoint> labeledPoints = convertToLabelPoints();
    sparkContext = createSparkContext();
    JavaRDD<LabeledPoint> labeledPointJavaRDD = sparkContext.parallelize(labeledPoints);
    // random 80/20 split; 11L is the seed, so the split is reproducible
    JavaRDD<LabeledPoint>[] splits = labeledPointJavaRDD.randomSplit(new double[]{0.8, 0.2}, 11L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> test = splits[1];

    linearModel = model.run(training.rdd()); // train with 80% of the data

    // test with the remaining 20% of the data: pair each prediction with the real label
    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
            (Function<LabeledPoint, Tuple2<Object, Object>>) p -> {
                Double prediction = linearModel.predict(p.features());
                return new Tuple2<>(prediction, p.label());
            });

    return new MulticlassMetrics(predictionAndLabels.rdd());
}
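MulticlassMetrics derives its figures from (prediction, label) pairs like the ones built above. The following plain-Java sketch computes two such figures by hand, without Spark. Accuracy is the fraction of test emails classified correctly. Precision for the spam label (1.0) is the fraction of emails predicted as spam that really are spam. The class name SpamMetrics and the sample arrays are made up for illustration.

```java
// Illustrative sketch: metrics computed directly from prediction/label pairs.
public class SpamMetrics {

    // fraction of examples whose predicted label equals the real label
    static double accuracy(double[] predictions, double[] labels) {
        int correct = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == labels[i]) {
                correct++;
            }
        }
        return (double) correct / predictions.length;
    }

    // of the examples predicted as `label`, the fraction whose real label is `label`
    static double precision(double[] predictions, double[] labels, double label) {
        int predicted = 0;
        int truePositives = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == label) {
                predicted++;
                if (labels[i] == label) {
                    truePositives++;
                }
            }
        }
        return predicted == 0 ? 0.0 : (double) truePositives / predicted;
    }

    public static void main(String[] args) {
        double[] predictions = {1, 1, 0, 0, 1};
        double[] labels      = {1, 0, 0, 0, 1};
        System.out.println(accuracy(predictions, labels));       // 0.8
        System.out.println(precision(predictions, labels, 1.0)); // 0.6666666666666666
    }
}
```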

And that's it!


Published at DZone with permission of

