
Prediction API: Machine Learning from Google



One of the most exciting of the 50+ APIs offered by Google is the Prediction API. It provides pattern-matching and machine learning capabilities such as recommendations and categorization. The idea is similar to the machine learning capabilities found in other solutions (e.g. Apache Mahout): we train the system with a set of training data, and applications built on the Prediction API can then recommend ("predict") which products a user might like, categorize spam, and so on. In this post we walk through an example of how to categorize SMS messages as either spam or legitimate texts ("ham").

Using Prediction API

To use the Prediction API, the service first needs to be enabled in the Google API console. The Prediction API also requires Google Cloud Storage, which is where the training data is uploaded. The dataset used in this post comes from the UCI Machine Learning Repository, which has 235 publicly available datasets; this post uses the SMS Spam Collection dataset.

To upload the training data, we first need to create a bucket in Google Cloud Storage. In the Google API console, click Google Cloud Storage and then Google Cloud Storage Manager. This opens a web page where we can create new buckets and upload or delete files.

The UCI SMS Spam Collection file cannot be used by the Prediction API as is; it needs to be converted into the following CSV format, in which both the category (ham or spam) and the SMS text are quoted:

"ham","Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat..."

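The conversion itself is mechanical. As a minimal sketch, assuming the UCI file stores one message per line as a label, a tab, and the message text, the following Python function rewrites it as comma-separated CSV with every field quoted (embedded double quotes are escaped by doubling them, as CSV requires). The function name is hypothetical and not part of any Google tooling.

```python
import csv
import io


def convert_to_prediction_csv(src_lines, out):
    """Rewrite 'label<TAB>text' lines as fully quoted CSV rows on `out`.

    Returns the number of rows written. Blank lines are skipped.
    """
    writer = csv.writer(out, quoting=csv.QUOTE_ALL, lineterminator="\n")
    rows = 0
    for line in src_lines:
        line = line.rstrip("\r\n")
        if not line:
            continue
        # Split only on the first tab: the label comes first, and any
        # further tabs stay inside the message text.
        label, text = line.split("\t", 1)
        writer.writerow([label, text])
        rows += 1
    return rows


if __name__ == "__main__":
    sample = io.StringIO(
        "ham\tGo until jurong point, crazy..\n"
        'spam\tFree entry "now"\n'
    )
    out = io.StringIO()
    convert_to_prediction_csv(sample, out)
    print(out.getvalue(), end="")
```

For the real dataset, pass in the opened UCI file and an output file opened with `newline=""`; the resulting CSV is what gets uploaded to the Cloud Storage bucket.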

The Google Prediction API offers a handful of commands that can be invoked via a REST interface. The simplest way to test the Prediction API is the Prediction API explorer.

Once the training data is available in Google Cloud Storage, we can start training the machine learning system behind the Prediction API. To train our model, we run prediction.trainedmodels.insert. All commands require authentication, which is based on the OAuth 2.0 standard.


In the insert menu we specify the fields that we want included in the response. In the request body we define an id (used to refer to the model in later commands), a storageDataLocation giving the Google Cloud Storage path of the uploaded training data, and the modelType (either regression or classification; spam filtering is a classification task):
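Outside the explorer, the same request body is just a small JSON document. The sketch below builds it in Python; the helper name and the bucket/object path are hypothetical, and the "bucket/object" form of storageDataLocation is an assumption about how the Cloud Storage path is written.

```python
import json

VALID_MODEL_TYPES = ("classification", "regression")


def build_insert_body(model_id, storage_data_location, model_type="classification"):
    """Return the request body for prediction.trainedmodels.insert."""
    if model_type not in VALID_MODEL_TYPES:
        raise ValueError("modelType must be one of %s" % (VALID_MODEL_TYPES,))
    return {
        "id": model_id,  # referenced later by the get and predict commands
        "storageDataLocation": storage_data_location,  # Cloud Storage path
        "modelType": model_type,
    }


if __name__ == "__main__":
    # "my-bucket/sms_training.csv" is a hypothetical path.
    body = build_insert_body("bighadoop-00001", "my-bucket/sms_training.csv")
    print(json.dumps(body, indent=2))
```

Validating modelType up front catches a typo locally instead of after a round trip to the API.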


Training takes a while; we can check its progress with the prediction.trainedmodels.get command. The status field reads RUNNING while training is in progress and changes to DONE once training has finished.
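Rather than re-running the get command by hand, a client can poll until training finishes. The following is a minimal sketch: get_status is a hypothetical callable that you would implement on top of prediction.trainedmodels.get so that it returns the status string, and the poll interval and timeout are arbitrary choices.

```python
import time


def wait_for_training(get_status, poll_seconds=10.0, timeout_seconds=3600.0,
                      sleep=time.sleep):
    """Call get_status() until it reports DONE.

    get_status is any zero-argument callable returning the model's training
    status string (hypothetical wrapper around prediction.trainedmodels.get).
    """
    waited = 0.0
    while True:
        status = get_status()
        if status == "DONE":
            return status
        if status != "RUNNING":
            # Only RUNNING and DONE are expected; treat anything else as a
            # failure rather than polling forever.
            raise RuntimeError("unexpected training status: %r" % (status,))
        if waited >= timeout_seconds:
            raise TimeoutError("training still RUNNING after %.0f s" % waited)
        sleep(poll_seconds)
        waited += poll_seconds


if __name__ == "__main__":
    fake_statuses = iter(["RUNNING", "RUNNING", "DONE"])
    print(wait_for_training(lambda: next(fake_statuses), sleep=lambda s: None))
```

Passing sleep in as a parameter keeps the loop testable without actually waiting.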


Now we are ready to test the machine learning system by asking it to classify a given text as spam or ham. The Prediction API command for this is prediction.trainedmodels.predict. In the id field we refer to the id that we defined for the prediction.trainedmodels.insert command (bighadoop-00001), and in the request body we set input to a csvInstance containing the text that we want categorized (e.g. "Free entry").
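The request body described above can be built as follows. This is a sketch with a hypothetical helper name; it uses the same {'input': {'csvInstance': [...]}} shape that the Try-Prediction Python code later in this post sends.

```python
import json


def build_predict_body(*values):
    """Return the request body for prediction.trainedmodels.predict.

    csvInstance holds one entry per input feature, in the same order as
    the feature columns of the training data. The SMS model has a single
    text feature, so the list contains just the message.
    """
    return {"input": {"csvInstance": list(values)}}


if __name__ == "__main__":
    print(json.dumps(build_predict_body("Free entry")))
```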


The system returns the category (spam) together with the scores (0.822158 for spam, 0.177842 for ham):
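To use the result programmatically, the client reads the label and scores out of the JSON response. The sketch below assumes the response carries the winning label in outputLabel and the per-category scores in outputMulti as a list of label/score objects; those field names and the sample response built from the numbers above are assumptions for illustration. Scores go through float() in case they arrive as strings.

```python
def read_prediction(output):
    """Return (best_label, {label: score}) from a classification response."""
    scores = {item["label"]: float(item["score"])
              for item in output.get("outputMulti", [])}
    if not scores:
        # Fall back to the single label if no per-category scores are present.
        return output["outputLabel"], scores
    best = max(scores, key=scores.get)
    return best, scores


if __name__ == "__main__":
    # Hypothetical response using the scores reported above.
    sample = {
        "outputLabel": "spam",
        "outputMulti": [
            {"label": "spam", "score": "0.822158"},
            {"label": "ham", "score": "0.177842"},
        ],
    }
    label, scores = read_prediction(sample)
    print(label, scores)
```

For a two-category model like this one, the two scores add up to 1, so the spam score alone already tells how confident the classification is.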


Google Prediction API libraries

Google also offers a featured sample application, Try-Prediction, that includes all the code required to run it on Google App Engine; the code is available in both Python and Java. The application can be tested at http://try-prediction.appspot.com. For instance, if we enter a quote from Niels Bohr for the Language Detection model, "Prediction is very difficult, especially if it's about the future.", it reports that the text is most likely English (54.4%).

The key part of the Python code is in predict.py: 

import json
import logging

import httplib2
from apiclient.discovery import build
from google.appengine.ext import webapp
from oauth2client.appengine import CredentialsModel, StorageByKeyName

# USER_AGENT, MODELS_FILE, ERR_TAG, ERR_END and parse_json_file are
# defined elsewhere in the Try-Prediction sample.

class PredictAPI(webapp.RequestHandler):
  '''This class handles Ajax prediction requests, i.e. not user initiated
     web sessions but remote procedure calls initiated from the Javascript
     client code running in the browser.
  '''

  def get(self):
    try:
      # Read server-side OAuth 2.0 credentials from datastore and
      # raise an exception if credentials not found.
      # (Property name 'credentials' assumed; the original line is truncated.)
      credentials = StorageByKeyName(CredentialsModel, USER_AGENT,
                                     'credentials').get()
      if not credentials or credentials.invalid:
        raise Exception('missing OAuth 2.0 credentials')

      # Authorize HTTP session with server credentials and obtain
      # access to prediction API client library.
      http = credentials.authorize(httplib2.Http())
      service = build('prediction', 'v1.4', http=http)
      papi = service.trainedmodels()

      # Read and parse JSON model description data.
      models = parse_json_file(MODELS_FILE)

      # Get reference to user's selected model.
      model_name = self.request.get('model')
      model = models[model_name]

      # Build prediction data (csvInstance) dynamically based on form input.
      vals = []
      for field in model['fields']:
        label = field['label']
        val = str(self.request.get(label))
        vals.append(val)
      body = {'input': {'csvInstance': vals}}
      logging.info('model:' + model_name + ' body:' + str(body))

      # Make a prediction and return JSON results to Javascript client.
      ret = papi.predict(id=model['model_id'], body=body).execute()
      self.response.out.write(json.dumps(ret))

    except Exception as err:
      # Capture any API errors here and pass response from API back to
      # Javascript client embedded in a special error indication tag.
      err_str = str(err)
      if err_str[0:len(ERR_TAG)] != ERR_TAG:
        err_str = ERR_TAG + err_str + ERR_END
      self.response.out.write(err_str)
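The handler never hard-codes the model's inputs: each entry in the models file lists its fields, and the handler reads one form value per field label, in order, to build csvInstance. The sketch below isolates that step outside App Engine. The models structure (keys model_id, fields, label) mirrors what the handler reads, but the concrete entry and the plain dict standing in for the request are hypothetical.

```python
def build_csv_instance(model, form):
    """Collect one value per model field, preserving the field order.

    model: a models-file entry with a "fields" list of {"label": ...} dicts.
    form:  a mapping of submitted form values (stands in for self.request).
    """
    # csvInstance is positional, so values must follow the order of the
    # feature columns the model was trained on. Missing values become "".
    return [str(form.get(field["label"], "")) for field in model["fields"]]


if __name__ == "__main__":
    # Hypothetical models-file entry for the SMS spam model.
    models = {
        "sms_spam": {
            "model_id": "bighadoop-00001",
            "fields": [{"label": "text"}],
        }
    }
    model = models["sms_spam"]
    body = {"input": {"csvInstance": build_csv_instance(model, {"text": "Free entry"})}}
    print(model["model_id"], body)
```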

The Java version of the Prediction web application is as follows:

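// JDK, servlet and App Engine imports. The Google API client classes
// (AccessTokenResponse, GoogleAccessProtectedResource, HttpTransport,
// NetHttpTransport, JsonFactory, JacksonFactory, Prediction, Input,
// InputInput, Output) come from the google-api-java-client version
// bundled with the sample; their packages differ between versions.
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.google.appengine.api.datastore.DatastoreService;
import com.google.appengine.api.datastore.DatastoreServiceFactory;
import com.google.appengine.api.datastore.Entity;
import com.google.appengine.api.datastore.EntityNotFoundException;
import com.google.appengine.api.datastore.Key;
import com.google.appengine.api.datastore.KeyFactory;

public class PredictServlet extends HttpServlet {

  @Override
  @SuppressWarnings("unchecked")
  protected void doGet(HttpServletRequest request,
                       HttpServletResponse response) throws ServletException,
                                                            IOException {
    Entity credentials = null;
    try {
      // Retrieve server credentials from app engine datastore.
      DatastoreService datastore =
        DatastoreServiceFactory.getDatastoreService();
      Key credsKey = KeyFactory.createKey("Credentials", "Credentials");
      credentials = datastore.get(credsKey);
    } catch (EntityNotFoundException ex) {
      // If can't obtain credentials, send exception back to Javascript client
      // and stop: the code below needs the credentials.
      response.getWriter().println("exception: " + ex.getMessage());
      return;
    }

    // Extract tokens from retrieved credentials.
    AccessTokenResponse tokens = new AccessTokenResponse();
    tokens.accessToken = (String) credentials.getProperty("accessToken");
    tokens.expiresIn = (Long) credentials.getProperty("expiresIn");
    tokens.refreshToken = (String) credentials.getProperty("refreshToken");
    String clientId = (String) credentials.getProperty("clientId");
    String clientSecret = (String) credentials.getProperty("clientSecret");
    tokens.scope = IndexServlet.scope;

    // Set up the HTTP transport and JSON factory
    HttpTransport httpTransport = new NetHttpTransport();
    JsonFactory jsonFactory = new JacksonFactory();

    // Get user requested model, if specified.
    String model_name = request.getParameter("model");

    // Parse model descriptions from models.json file.
    // (Helper and field names assumed; the original line is truncated.)
    Map<String, Object> models =
      IndexServlet.parseJsonFile(IndexServlet.modelsFile);

    // Setup reference to user specified model description.
    Map<String, Object> selectedModel =
      (Map<String, Object>) models.get(model_name);
    // Obtain model id (the name under which model was trained),
    // and iterate over the model fields, building a list of Strings
    // to pass into the prediction request.
    String modelId = (String) selectedModel.get("model_id");
    List<Object> params = new ArrayList<Object>();
    List<Map<String, String>> fields =
      (List<Map<String, String>>) selectedModel.get("fields");
    for (Map<String, String> field : fields) {
      // This loop is populating the input csv values for the prediction call.
      String label = field.get("label");
      String value = request.getParameter(label);
      params.add(value);
    }

    // Set up OAuth 2.0 access of protected resources using the retrieved
    // refresh and access tokens, automatically refreshing the access token
    // whenever it expires.
    GoogleAccessProtectedResource requestInitializer =
      new GoogleAccessProtectedResource(tokens.accessToken, httpTransport,
                                        jsonFactory, clientId, clientSecret,
                                        tokens.refreshToken);

    // Now populate the prediction data, issue the API call and return the
    // JSON results to the Javascript AJAX client.
    Prediction prediction = new Prediction(httpTransport, requestInitializer,
                                           jsonFactory);
    Input input = new Input();
    InputInput inputInput = new InputInput();
    inputInput.setCsvInstance(params);
    input.setInput(inputInput);
    Output output =
      prediction.trainedmodels().predict(modelId, input).execute();
    response.getWriter().println(jsonFactory.toString(output));
  }
}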

Besides Python and Java, Google also offers Prediction API client libraries for .NET, Objective-C, Ruby, Go, JavaScript, PHP and other languages.
