Building a GBM Model in H2O With Grid Search and Hyperparameters in Scala

Learn how H2O and GBM can be used to perform grid search and optimize hyperparameters and get the code to help you do it!

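H2O is open-source data analytics software, and GBM (gradient boosting machine) is a tree-ensemble algorithm that builds decision trees one after another, each correcting the errors of the ensemble so far, which makes it well suited to accurate predictive analytics. H2O's grid search can train many GBM models over a space of hyperparameter values to find a well-performing combination. Here is the full Scala source code for building a GBM model and running a grid search for hyperparameter optimization with H2O in Sparkling Water (the code is also available on GitHub):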

import org.apache.spark.SparkFiles
import org.apache.spark.h2o._
import org.apache.spark.examples.h2o._
import org.apache.spark.sql.{DataFrame, SQLContext}
import water.Key
import java.io.File

import water.support.SparkContextSupport.addFiles
import water.support.H2OFrameSupport._

// Create SQL support
implicit val sqlContext = spark.sqlContext
import sqlContext.implicits._

// Start H2O services
val h2oContext = H2OContext.getOrCreate(sc)
import h2oContext._
import h2oContext.implicits._

// H2O GBM
import _root_.hex.tree.gbm.GBM
import _root_.hex.tree.gbm.GBMModel.GBMParameters

// H2O Grid Search
import _root_.hex.grid.GridSearch
import _root_.hex.ScoreKeeper

// Implicit Scala-to-Java collection conversions (used for the hyperparameter map)
import scala.collection.JavaConversions._

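Now, register the two data files used below with the SparkContext so that SparkFiles.get can locate them:

// Placeholder paths: point these at your local copies of the two files
addFiles(sc, "/path/to/year2005.csv.gz", "/path/to/Chicago_Ohare_International_Airport.csv")
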
Import the full-year 2005 airline data into H2O as an H2OFrame:

val airlinesData = new H2OFrame(new File(SparkFiles.get("year2005.csv.gz")))

Load the weather data as a Spark text file (with 8 partitions), parse each line, and drop malformed rows:

val wrawdata = sc.textFile(SparkFiles.get("Chicago_Ohare_International_Airport.csv"),8).cache()
val weatherTable = wrawdata.map(_.split(",")).map(row => WeatherParse(row)).filter(!_.isWrongRow())
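The weather step follows a parse-then-drop pattern: split each CSV line, build a typed record, and discard rows that cannot be parsed. As a plain-Scala sketch of that idea (no Spark needed), assuming a hypothetical four-column layout (Year, Month, Day, TmaxF) and a hypothetical WeatherRow class rather than the real WeatherParse/Weather classes from the Sparkling Water examples, returning an Option lets bad rows disappear in a flatMap instead of throwing:

```scala
object WeatherParsing {
  // Hypothetical, simplified record; the real Weather class has more columns.
  final case class WeatherRow(year: Int, month: Int, day: Int, tmaxF: Int)

  // Parse one CSV line. Headers and malformed lines yield None instead of throwing.
  def parse(line: String): Option[WeatherRow] = {
    val cols = line.split(",").map(_.trim)
    if (cols.length < 4) None
    else
      try Some(WeatherRow(cols(0).toInt, cols(1).toInt, cols(2).toInt, cols(3).toInt))
      catch { case _: NumberFormatException => None }
  }

  // Keep only the lines that parsed successfully.
  def parseAll(lines: Seq[String]): Seq[WeatherRow] = lines.flatMap(line => parse(line))
}
```

For example, a header line and a line with a non-numeric day are both dropped, and only well-formed rows survive.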

Convert the H2O frame into a Spark DataFrame, parse each row, and keep only the flights into Chicago O'Hare International Airport (ORD):

val airlinesTable = h2oContext.asDataFrame(airlinesData).map(row => AirlinesParse(row))
val flightsToORD = airlinesTable.filter(f => f.Dest==Some("ORD"))
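The filter compares f.Dest with Some("ORD"), which shows that the destination is an optional field in the parsed airline records. A small plain-Scala sketch of the same test, using a hypothetical trimmed-down Flight class: Option.contains is true only when the value is present and equal, so flights with a missing destination are dropped just as with == Some("ORD").

```scala
object OrdFilter {
  // Hypothetical, trimmed-down flight record; dest is optional because parsed fields can be missing.
  final case class Flight(flightNum: Int, dest: Option[String])

  // Same test as f.Dest == Some("ORD"): keeps a flight only if dest is present and equals "ORD".
  def toOrd(flights: Seq[Flight]): Seq[Flight] = flights.filter(_.dest.contains("ORD"))
}
```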

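Join the flight and weather data on the flight date with Spark SQL:

flightsToORD.toDF.createOrReplaceTempView("FlightsToORD")
weatherTable.toDF.createOrReplaceTempView("WeatherORD")
val bigTable = sqlContext.sql(
 """SELECT f.Year,f.Month,f.DayofMonth,f.CRSDepTime,f.CRSArrTime,f.CRSElapsedTime,f.UniqueCarrier,f.FlightNum,f.TailNum,f.Origin,f.Distance,
 |w.TmaxF,w.TminF,w.TmeanF,w.PrcpIn,w.SnowIn,w.CDD,w.HDD,w.GDD,f.IsDepDelayed
 |FROM FlightsToORD f
 |JOIN WeatherORD w
 |ON f.Year=w.Year AND f.Month=w.Month AND f.DayofMonth=w.Day""".stripMargin)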
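The SQL statement is an inner join keyed on the date (year, month, day of month). To make those semantics concrete, here is a small in-memory sketch with hypothetical trimmed-down Flight and Weather classes (not the classes used above): weather rows are indexed by date, each flight is paired with every matching weather row, and flights without a weather row for their date are dropped, as an inner join does.

```scala
object DateJoin {
  // Hypothetical, trimmed-down records: the join keys plus one payload column each.
  final case class Flight(year: Int, month: Int, dayOfMonth: Int, isDepDelayed: String)
  final case class Weather(year: Int, month: Int, day: Int, tmaxF: Int)

  // Inner join on (year, month, day), mirroring the SQL ON clause:
  // index weather rows by date, then pair every flight with each matching row.
  // Flights with no weather row for their date produce no output.
  def join(flights: Seq[Flight], weather: Seq[Weather]): Seq[(Flight, Weather)] = {
    val byDate: Map[(Int, Int, Int), Seq[Weather]] =
      weather.groupBy(w => (w.year, w.month, w.day))
    for {
      f <- flights
      w <- byDate.getOrElse((f.year, f.month, f.dayOfMonth), Seq.empty)
    } yield (f, w)
  }
}
```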

Convert the IsDepDelayed response column to a categorical column, so that H2O treats the problem as binary classification rather than regression:

val trainFrame:H2OFrame = bigTable
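// Replace the response column in place with its categorical version (looked up by name, not by position)
withLockAndUpdate(trainFrame){ fr => fr.replace(fr.find("IsDepDelayed"), fr.vec("IsDepDelayed").toCategoricalVec)}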

// Display the row and column counts of the training data
println(s"Rows: ${trainFrame.numRows}, Columns: ${trainFrame.numCols}")
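Conceptually, a categorical column stores a domain of levels plus, for each row, the index of its level; with two levels the response becomes a binary class label. The sketch below illustrates that idea in plain Scala and is not H2O's internal Vec representation; sorting the levels alphabetically is an assumption made here for determinism.

```scala
object CategoricalEncoding {
  // Conceptual sketch: collect the distinct values as a domain of levels
  // (sorted here for determinism) and store each value as its level index.
  def encode(values: Seq[String]): (Vector[String], Seq[Int]) = {
    val domain = values.distinct.sorted.toVector
    val levelIndex: Map[String, Int] = domain.zipWithIndex.toMap
    (domain, values.map(levelIndex))
  }
}
```

For instance, encoding Seq("YES", "NO", "YES") gives the domain Vector("NO", "YES") and the codes Seq(1, 0, 1).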

Now configure and train the GBM model, using the Bernoulli distribution for the binary response:

val gbmParams = new GBMParameters()
gbmParams._train = trainFrame
gbmParams._response_column = 'IsDepDelayed

import _root_.hex.genmodel.utils.DistributionFamily
gbmParams._distribution = DistributionFamily.bernoulli

val gbm = new GBM(gbmParams,Key.make("gbmModel.hex"))
val gbmModel = gbm.trainModel.get
// Same as above
// val gbmModel = gbm.trainModel().get()
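The code above treats GBM as a black box. As a rough illustration of what gradient boosting with a Bernoulli distribution does, here is a minimal sketch for a single numeric feature and a 0/1 label: it starts from the overall log-odds, then repeatedly fits a depth-1 tree (a stump) to the negative gradient of the log-loss, y - p, and adds each tree scaled by a learning rate. This is not H2O's implementation, which grows deeper trees over many columns and supports options such as _col_sample_rate; the parameters nTrees and learnRate only loosely mirror _ntrees and _learn_rate.

```scala
object TinyGbm {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // A depth-1 tree (stump): x <= threshold goes left, otherwise right.
  final case class Stump(threshold: Double, left: Double, right: Double) {
    def apply(x: Double): Double = if (x <= threshold) left else right
  }

  // Raw score = initial log-odds + learnRate * sum of stump outputs.
  final case class Model(f0: Double, learnRate: Double, stumps: Vector[Stump]) {
    def rawScore(x: Double): Double = f0 + learnRate * stumps.map(s => s(x)).sum
    def probability(x: Double): Double = sigmoid(rawScore(x))
  }

  // Least-squares choice of split point for the current residuals.
  private def bestThreshold(xs: Vector[Double], r: Vector[Double]): Double = {
    val distinct = xs.distinct.sorted
    require(distinct.size >= 2, "need at least two distinct feature values")
    val candidates = distinct.zip(distinct.tail).map { case (a, b) => (a + b) / 2 }
    def sse(idx: Seq[Int]): Double =
      if (idx.isEmpty) 0.0
      else {
        val mean = idx.map(r).sum / idx.size
        idx.map(i => (r(i) - mean) * (r(i) - mean)).sum
      }
    candidates.minBy { t =>
      val (left, right) = xs.indices.partition(i => xs(i) <= t)
      sse(left) + sse(right)
    }
  }

  // Newton-step leaf value for the Bernoulli log-loss: sum(r) / sum(p * (1 - p)).
  private def leafValue(idx: Seq[Int], r: Vector[Double], p: Vector[Double]): Double = {
    val den = idx.map(i => p(i) * (1 - p(i))).sum
    if (den < 1e-12) 0.0 else idx.map(r).sum / den
  }

  def fit(xs: Vector[Double], ys: Vector[Int], nTrees: Int, learnRate: Double): Model = {
    val rate = ys.sum.toDouble / ys.size
    require(rate > 0 && rate < 1, "need both classes present")
    val f0 = math.log(rate / (1 - rate)) // start from the overall log-odds
    (1 to nTrees).foldLeft(Model(f0, learnRate, Vector.empty[Stump])) { (model, _) =>
      val p = xs.map(model.probability)
      val r = ys.indices.map(i => ys(i) - p(i)).toVector // negative gradient of the log-loss
      val t = bestThreshold(xs, r)
      val (left, right) = xs.indices.partition(i => xs(i) <= t)
      model.copy(stumps = model.stumps :+ Stump(t, leafValue(left, r, p), leafValue(right, r, p)))
    }
  }
}
```

A smaller learning rate makes each tree's contribution smaller, so more trees are needed; that trade-off is one reason _ntrees and _learn_rate are searched together in the grid below.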

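The model is now trained, so we can use it to predict departure delays on the training data. The predict column of the scored frame holds the predicted class (delayed or not) for each row; the class probabilities are in the scored frame's other columns: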

val predGBMH2OFrame = gbmModel.score(trainFrame)('predict)
val predGBMFromModel = asRDD[DoubleHolder](predGBMH2OFrame).collect.map(_.result.getOrElse(Double.NaN))
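Because the model is scored on the same frame it was trained on, any accuracy computed from these predictions is optimistic. The sketch below shows how positive-class probabilities turn into class labels and how accuracy is computed; it assumes you already have probabilities and 0/1 labels as plain Scala sequences (the extraction step is not shown), and the 0.5 cutoff is an assumption, since H2O chooses its own default threshold for the predict column.

```scala
object Thresholding {
  // Turn positive-class probabilities into 0/1 labels with a cutoff.
  def labels(probs: Seq[Double], threshold: Double = 0.5): Seq[Int] =
    probs.map(p => if (p >= threshold) 1 else 0)

  // Fraction of predictions that match the true labels.
  def accuracy(predicted: Seq[Int], actual: Seq[Int]): Double = {
    require(predicted.size == actual.size && actual.nonEmpty)
    predicted.zip(actual).count { case (p, a) => p == a }.toDouble / actual.size
  }
}
```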

Define the hyperparameter space for our GBM model; each key is a GBMParameters field name and each value is the list of candidate values:

import _root_.hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria

// Values are boxed so that each array fits Array[AnyRef]; the Scala Map is
// converted to java.util.Map by the JavaConversions implicits
val gbmHyperSpace: java.util.Map[String, Array[Object]] = Map[String, Array[AnyRef]](
 "_ntrees" -> (1 to 10).map(v => Int.box(100*v)).toArray,
 "_max_depth" -> (2 to 7).map(Int.box).toArray,
 "_learn_rate" -> Array(0.1, 0.01).map(Double.box),
 "_col_sample_rate" -> Array(0.3, 0.7, 1.0).map(Double.box),
 "_learn_rate_annealing" -> Array(0.8, 0.9, 0.95, 1.0).map(Double.box)
)

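This hyperparameter space spans 10 × 6 × 2 × 3 × 4 = 1,440 combinations, which is why a random discrete search (RandomDiscreteValueSearchCriteria), which samples combinations from these lists, is more practical than training every one. The plain-Scala sketch below restates the same value lists, computes the full grid size, and draws one random combination by picking a listed value per hyperparameter; uniform per-parameter sampling is an assumption for illustration, not necessarily how H2O samples internally.

```scala
import scala.util.Random

object HyperSpace {
  // The same candidate values as the gbmHyperSpace map, as plain Scala values.
  val space: Map[String, Seq[Any]] = Map(
    "_ntrees" -> (1 to 10).map(_ * 100),
    "_max_depth" -> (2 to 7),
    "_learn_rate" -> Seq(0.1, 0.01),
    "_col_sample_rate" -> Seq(0.3, 0.7, 1.0),
    "_learn_rate_annealing" -> Seq(0.8, 0.9, 0.95, 1.0)
  )

  // A full Cartesian grid trains one model per combination.
  def gridSize(s: Map[String, Seq[Any]]): Long = s.values.map(_.size.toLong).product

  // Random discrete search: pick one listed value for each hyperparameter.
  def sample(s: Map[String, Seq[Any]], rng: Random): Map[String, Any] =
    s.map { case (name, values) => name -> values(rng.nextInt(values.size)) }
}
```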
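Set up the search criteria, which sample random combinations from the hyperparameter space and stop early when new models stop improving, and then start the actual grid search:

val gbmHyperSpaceCriteria = new RandomDiscreteValueSearchCriteria
gbmHyperSpaceCriteria.set_stopping_rounds(3) // illustrative value; the original setting was not preserved
val gbmGrid = GridSearch.startGridSearch(Key.make("gbmGridModel"), gbmParams, gbmHyperSpace,
 new GridSearch.SimpleParametersBuilderFactory[GBMParameters], gbmHyperSpaceCriteria).get()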
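The early-stopping idea behind a random search can be sketched without H2O. The rule below is a simplified assumption, not H2O's exact formula: each newly built model reports an error (lower is better), and the search stops after maxModels models, or once rounds models in a row fail to beat the best error so far by at least the relative tolerance. The parameters loosely mirror H2O's max_models, stopping_rounds, and stopping_tolerance search options.

```scala
object EarlyStoppingSearch {
  // Simplified stopping rule (an assumption, not H2O's exact formula).
  // `scores` yields the error of each newly built model (lower is better).
  def run(scores: Iterator[Double], maxModels: Int, rounds: Int, tolerance: Double): Vector[Double] = {
    require(tolerance >= 0 && tolerance < 1, "tolerance is a relative fraction in [0, 1)")
    val built = Vector.newBuilder[Double]
    var count = 0
    var best = Double.PositiveInfinity
    var roundsWithoutGain = 0
    while (count < maxModels && roundsWithoutGain < rounds && scores.hasNext) {
      val score = scores.next()
      built += score
      count += 1
      // An improvement must beat the best error by the relative tolerance.
      if (score < best * (1 - tolerance)) { best = score; roundsWithoutGain = 0 }
      else roundsWithoutGain += 1
    }
    built.result()
  }
}
```

With errors 0.30, 0.25, 0.249, 0.26, 0.27, 0.10, rounds = 3 and tolerance = 0.01, the search stops after the fifth model, so the much better sixth model is never built; that is the trade-off early stopping makes for a shorter search.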

Here are some of the statistics you can get from the GBM grid we have just created:

// Training frame info
gbmGrid.getTrainingFrame

// Looking at grid models by keys
val mKeys = gbmGrid.getModelKeys()
gbmGrid.createSummaryTable(mKeys, "mse", true)
gbmGrid.createSummaryTable(mKeys, "rmse", true)

// Model count
gbmGrid.getModelCount

// All models (each name refers to a different model in the grid)
val ms = gbmGrid.getModels()
val firstModel = ms(0)
val secondModel = ms(1)
val thirdModel = ms(2)

// All hyperparameter names
gbmGrid.getHyperNames
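The summary tables rank the grid's models by MSE and by RMSE. As a reference, here is a sketch of the standard definitions, applied to predicted probabilities against 0/1 labels; this is an illustration, and the article does not show exactly how H2O computes these metrics per model. Because RMSE is the square root of MSE (a monotonic transform), sorting the models by either metric gives the same order.

```scala
object ErrorMetrics {
  // Mean squared error between predictions and actual values.
  def mse(predicted: Seq[Double], actual: Seq[Double]): Double = {
    require(predicted.size == actual.size && actual.nonEmpty)
    predicted.zip(actual).map { case (p, a) => (p - a) * (p - a) }.sum / actual.size
  }

  // Root mean squared error: the square root of MSE, in the units of the target.
  def rmse(predicted: Seq[Double], actual: Seq[Double]): Double =
    math.sqrt(mse(predicted, actual))
}
```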

That's it. Enjoy!

Published at DZone with permission of Avkash Chauhan, DZone MVB. See the original article here.
