How to Use Spark MLlib With Splice Machine
Using Spark MLlib with Splice Machine involves creating a class with an API, writing a custom procedure, and using the Spark library with the resulting RDD.
One of the great features of Spark is that a large number of libraries have been (and continue to be) developed for use with Spark. This topic provides an example of interfacing to the Spark Machine Learning library (MLlib).
You can follow a similar path to interface with other Spark libraries, which involves these steps:
- Create a class with an API that leverages functionality in the Spark library you want to use.
- Write a custom procedure in your Splice Machine database that converts a Splice Machine result set into a Spark Resilient Distributed Dataset (RDD).
- Use the Spark library with the RDD.
Using Spark MLlib With Splice Machine Statistics
This section presents the sample code for interfacing Splice Machine with the Spark Machine Learning Library (MLlib), in these subsections:
- About the Splice Machine SparkMLibUtils Class API describes the SparkMLibUtils class that Splice Machine provides for interfacing with this library.
- Creating Our SparkStatistics Example Class summarizes the SparkStatistics Java class that we created for this example.
- Run a Sample Program to Use Our Class shows you how to define a custom procedure in your database to interface to the SparkStatistics class.
About the Splice Machine SparkMLibUtils Class API
Our example makes use of the Splice Machine com.splicemachine.example.SparkMLibUtils class, which you can use to interface between your Splice Machine database and the Spark Machine Learning library.
Here are the public methods from the SparkMLibUtils class:
public static JavaRDD<LocatedRow> resultSetToRDD(ResultSet rs) throws StandardException;
public static JavaRDD<Vector> locatedRowRDDToVectorRDD(JavaRDD<LocatedRow> locatedRowJavaRDD, int[] fieldsToConvert) throws StandardException;
public static Vector convertExecRowToVector(ExecRow execRow,int[] fieldsToConvert) throws StandardException;
public static Vector convertExecRowToVector(ExecRow execRow) throws StandardException;
- resultSetToRDD converts a Splice Machine result set into a Spark Resilient Distributed Dataset (RDD) object.
- locatedRowRDDToVectorRDD transforms an RDD of rows into an RDD of vectors for use with the Machine Learning library. The fieldsToConvert parameter specifies which column positions to include in each vector.
- convertExecRowToVector converts a Splice Machine ExecRow into a vector. The optional fieldsToConvert parameter specifies which column positions to include in the vector.
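As a quick illustration of how these calls chain together, here's a minimal sketch; it is not from the Splice Machine sources, but mirrors the ResultSetToRDD helper shown in the Appendix. The class name, method name, and the column positions {1, 2} are illustrative:
import com.splicemachine.db.iapi.error.StandardException;
import com.splicemachine.db.impl.jdbc.EmbedResultSet40;
import com.splicemachine.derby.impl.sql.execute.operations.LocatedRow;
import com.splicemachine.example.SparkMLibUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vector;

public class MLibUtilsSketch {
    // Convert an embedded JDBC result set into an MLlib-ready vector RDD
    public static JavaRDD<Vector> toVectorRDD(java.sql.ResultSet jdbcResultSet)
            throws StandardException {
        // Unwrap the embedded JDBC result set to reach the underlying Splice result set
        EmbedResultSet40 ers = (EmbedResultSet40) jdbcResultSet;
        com.splicemachine.db.iapi.sql.ResultSet rs = ers.getUnderlyingResultSet();
        // Result set -> RDD of rows
        JavaRDD<LocatedRow> rowRDD = SparkMLibUtils.resultSetToRDD(rs);
        // RDD of rows -> RDD of vectors built from column positions 1 and 2
        return SparkMLibUtils.locatedRowRDDToVectorRDD(rowRDD, new int[]{1, 2});
    }
}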
Creating Our SparkStatistics Example Class
For this example, we define a Java class named SparkStatistics that can query a Splice Machine table, convert the results into a Spark JavaRDD, and then use Spark MLlib to calculate statistics.
Our class, SparkStatistics, defines one public method:
public class SparkStatistics {
    public static void getStatementStatistics(String statement, ResultSet[] resultSets) throws SQLException {
        try {
            // Run the SQL statement
            Connection con = DriverManager.getConnection("jdbc:default:connection");
            PreparedStatement ps = con.prepareStatement(statement);
            ResultSet rs = ps.executeQuery();
            // Convert the result set to a Java RDD
            JavaRDD<LocatedRow> resultSetRDD = ResultSetToRDD(rs);
            // Collect column statistics
            int[] fieldsToConvert = getFieldsToConvert(ps);
            MultivariateStatisticalSummary summary = getColumnStatisticsSummary(resultSetRDD, fieldsToConvert);
            // Wrap the statistics as a result set that the stored procedure returns
            IteratorNoPutResultSet resultsToWrap = wrapResults((EmbedConnection) con, getColumnStatistics(ps, summary, fieldsToConvert));
            resultSets[0] = new EmbedResultSet40((EmbedConnection) con, resultsToWrap, false, null, true);
        } catch (StandardException e) {
            throw new SQLException(Throwables.getRootCause(e));
        }
    }
    // ... private helper methods are shown in the Appendix ...
}
We call getStatementStatistics from a custom procedure in our database, passing it an SQL query. getStatementStatistics performs the following operations:
- Query your database. Connect to your database and run the query:
Connection con = DriverManager.getConnection("jdbc:default:connection");
PreparedStatement ps = con.prepareStatement(statement);
ResultSet rs = ps.executeQuery();
- Convert the query results into a Spark RDD:
JavaRDD<LocatedRow> resultSetRDD = ResultSetToRDD(rs);
- Calculate statistics. Use Spark to collect statistics for the query using private methods in the SparkStatistics class:
int[] fieldsToConvert = getFieldsToConvert(ps);
MultivariateStatisticalSummary summary = getColumnStatisticsSummary(resultSetRDD, fieldsToConvert);
You can view the implementations of the getFieldsToConvert and getColumnStatisticsSummary methods in the Appendix: The SparkStatistics Class section at the end of this topic.
- Return the results:
IteratorNoPutResultSet resultsToWrap = wrapResults((EmbedConnection) con, getColumnStatistics(ps, summary, fieldsToConvert));
resultSets[0] = new EmbedResultSet40((EmbedConnection) con, resultsToWrap, false, null, true);
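To make the statistics step concrete, here's a small, self-contained sketch (our own illustration, not part of the Splice Machine code) that runs Spark MLlib's Statistics.colStats directly on vectors mirroring the sample table t created in the next section; the class name and local master setting are ours:
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary;
import org.apache.spark.mllib.stat.Statistics;
import java.util.Arrays;

public class ColStatsSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("ColStatsSketch").setMaster("local[*]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // One dense vector per row, mirroring table t(col1, col2) below
        JavaRDD<Vector> vectors = sc.parallelize(Arrays.asList(
                Vectors.dense(1.0, 10.0),
                Vectors.dense(2.0, 20.0),
                Vectors.dense(3.0, 30.0),
                Vectors.dense(4.0, 40.0)));
        // colStats computes per-column min, max, mean, variance, L1/L2 norms, and counts
        MultivariateStatisticalSummary summary = Statistics.colStats(vectors.rdd());
        System.out.println("mean:     " + summary.mean());
        System.out.println("variance: " + summary.variance());
        System.out.println("count:    " + summary.count());
        sc.stop();
    }
}
For these four rows, the per-column means are [2.5, 25.0] and the variances (Spark reports the unbiased sample variance) are roughly [1.667, 166.667], with a count of 4.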
Run a Sample Program to Use Our Class
Follow these steps to run a simple example program that uses the Spark MLlib library to calculate statistics for an SQL statement.
- Create Your API Class. The first step is to create a Java class that uses Spark to generate and analyze statistics, as shown in the previous section.
- Create your custom procedure. Create a procedure in our database that references the getStatementStatistics method in our API; it takes an SQL query as its input and uses Spark MLlib to calculate statistics for the query:
CREATE PROCEDURE getStatementStatistics(statement varchar(1024))
PARAMETER STYLE JAVA
LANGUAGE JAVA
READS SQL DATA
DYNAMIC RESULT SETS 1
EXTERNAL NAME 'com.splicemachine.example.SparkStatistics.getStatementStatistics';
- Create a very simple table to illustrate use of our procedure:
create table t (col1 int, col2 double);
insert into t values (1, 10);
insert into t values (2, 20);
insert into t values (3, 30);
insert into t values (4, 40);
- Call your custom procedure to get statistics. Calling your custom procedure sends an SQL statement to the SparkStatistics class we created, which generates a result set:
call splice.getStatementStatistics('select * from t');
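If you'd rather invoke the procedure from a Java client than from the command line, a hedged sketch follows; the JDBC URL and credentials are illustrative and depend on your deployment, while the column labels come from the STATEMENT_STATS_OUTPUT_COLUMNS descriptor in the Appendix:
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;

public class CallStatsSketch {
    public static void main(String[] args) throws SQLException {
        // Illustrative URL; substitute your own Splice Machine host, port, and credentials
        String url = "jdbc:splice://localhost:1527/splicedb;user=splice;password=admin";
        try (Connection conn = DriverManager.getConnection(url);
             CallableStatement cs = conn.prepareCall("call splice.getStatementStatistics(?)")) {
            cs.setString(1, "select * from t");
            try (ResultSet rs = cs.executeQuery()) {
                // One output row per column of the query: two rows for 'select * from t'
                while (rs.next()) {
                    // Column labels match STATEMENT_STATS_OUTPUT_COLUMNS in the Appendix
                    System.out.printf("%s min=%.2f max=%.2f mean=%.2f count=%d%n",
                            rs.getString("COLUMN_NAME"),
                            rs.getDouble("MIN"),
                            rs.getDouble("MAX"),
                            rs.getDouble("MEAN"),
                            rs.getLong("COUNT"));
                }
            }
        }
    }
}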
Appendix: The SparkStatistics Class
Here’s the full code for our SparkStatistics class:
package com.splicemachine.example;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import com.splicemachine.db.iapi.error.StandardException;
import com.splicemachine.db.iapi.sql.Activation;
import com.splicemachine.db.iapi.sql.ResultColumnDescriptor;
import com.splicemachine.db.iapi.sql.execute.ExecRow;
import com.splicemachine.db.iapi.types.DataTypeDescriptor;
import com.splicemachine.db.iapi.types.SQLDouble;
import com.splicemachine.db.iapi.types.SQLLongint;
import com.splicemachine.db.iapi.types.SQLVarchar;
import com.splicemachine.db.impl.jdbc.EmbedConnection;
import com.splicemachine.db.impl.jdbc.EmbedResultSet40;
import com.splicemachine.db.impl.sql.GenericColumnDescriptor;
import com.splicemachine.db.impl.sql.execute.IteratorNoPutResultSet;
import com.splicemachine.db.impl.sql.execute.ValueRow;
import com.splicemachine.derby.impl.sql.execute.operations.LocatedRow;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary;
import org.apache.spark.mllib.stat.Statistics;
import java.sql.*;
import java.util.List;
public class SparkStatistics {

    private static final ResultColumnDescriptor[] STATEMENT_STATS_OUTPUT_COLUMNS = new GenericColumnDescriptor[]{
            new GenericColumnDescriptor("COLUMN_NAME", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.VARCHAR)),
            new GenericColumnDescriptor("MIN", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.DOUBLE)),
            new GenericColumnDescriptor("MAX", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.DOUBLE)),
            new GenericColumnDescriptor("NUM_NONZEROS", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.DOUBLE)),
            new GenericColumnDescriptor("VARIANCE", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.DOUBLE)),
            new GenericColumnDescriptor("MEAN", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.DOUBLE)),
            new GenericColumnDescriptor("NORML1", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.DOUBLE)),
            new GenericColumnDescriptor("NORML2", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.DOUBLE)),
            new GenericColumnDescriptor("COUNT", DataTypeDescriptor.getBuiltInDataTypeDescriptor(Types.BIGINT)),
    };

    public static void getStatementStatistics(String statement, ResultSet[] resultSets) throws SQLException {
        try {
            // Run the SQL statement
            Connection con = DriverManager.getConnection("jdbc:default:connection");
            PreparedStatement ps = con.prepareStatement(statement);
            ResultSet rs = ps.executeQuery();
            // Convert the result set to a Java RDD
            JavaRDD<LocatedRow> resultSetRDD = ResultSetToRDD(rs);
            // Collect column statistics
            int[] fieldsToConvert = getFieldsToConvert(ps);
            MultivariateStatisticalSummary summary = getColumnStatisticsSummary(resultSetRDD, fieldsToConvert);
            IteratorNoPutResultSet resultsToWrap = wrapResults((EmbedConnection) con, getColumnStatistics(ps, summary, fieldsToConvert));
            resultSets[0] = new EmbedResultSet40((EmbedConnection) con, resultsToWrap, false, null, true);
        } catch (StandardException e) {
            throw new SQLException(Throwables.getRootCause(e));
        }
    }

    private static MultivariateStatisticalSummary getColumnStatisticsSummary(JavaRDD<LocatedRow> resultSetRDD,
                                                                             int[] fieldsToConvert) throws StandardException {
        JavaRDD<Vector> vectorJavaRDD = SparkMLibUtils.locatedRowRDDToVectorRDD(resultSetRDD, fieldsToConvert);
        return Statistics.colStats(vectorJavaRDD.rdd());
    }

    /*
     * Convert a ResultSet to a JavaRDD
     */
    private static JavaRDD<LocatedRow> ResultSetToRDD(ResultSet resultSet) throws StandardException {
        EmbedResultSet40 ers = (EmbedResultSet40) resultSet;
        com.splicemachine.db.iapi.sql.ResultSet rs = ers.getUnderlyingResultSet();
        return SparkMLibUtils.resultSetToRDD(rs);
    }

    private static int[] getFieldsToConvert(PreparedStatement ps) throws SQLException {
        ResultSetMetaData metaData = ps.getMetaData();
        int columnCount = metaData.getColumnCount();
        int[] fieldsToConvert = new int[columnCount];
        for (int i = 0; i < columnCount; ++i) {
            fieldsToConvert[i] = i + 1;
        }
        return fieldsToConvert;
    }

    /*
     * Convert column statistics to an iterable row source
     */
    private static Iterable<ExecRow> getColumnStatistics(PreparedStatement ps,
                                                         MultivariateStatisticalSummary summary,
                                                         int[] fieldsToConvert) throws StandardException {
        try {
            List<ExecRow> rows = Lists.newArrayList();
            ResultSetMetaData metaData = ps.getMetaData();
            double[] min = summary.min().toArray();
            double[] max = summary.max().toArray();
            double[] mean = summary.mean().toArray();
            double[] nonZeros = summary.numNonzeros().toArray();
            double[] variance = summary.variance().toArray();
            double[] normL1 = summary.normL1().toArray();
            double[] normL2 = summary.normL2().toArray();
            long count = summary.count();
            for (int i = 0; i < fieldsToConvert.length; ++i) {
                int columnPosition = fieldsToConvert[i];
                String columnName = metaData.getColumnName(columnPosition);
                ExecRow row = new ValueRow(9);
                row.setColumn(1, new SQLVarchar(columnName));
                row.setColumn(2, new SQLDouble(min[columnPosition - 1]));
                row.setColumn(3, new SQLDouble(max[columnPosition - 1]));
                row.setColumn(4, new SQLDouble(nonZeros[columnPosition - 1]));
                row.setColumn(5, new SQLDouble(variance[columnPosition - 1]));
                row.setColumn(6, new SQLDouble(mean[columnPosition - 1]));
                row.setColumn(7, new SQLDouble(normL1[columnPosition - 1]));
                row.setColumn(8, new SQLDouble(normL2[columnPosition - 1]));
                row.setColumn(9, new SQLLongint(count));
                rows.add(row);
            }
            return rows;
        } catch (Exception e) {
            throw StandardException.newException(e.getLocalizedMessage());
        }
    }

    private static IteratorNoPutResultSet wrapResults(EmbedConnection conn, Iterable<ExecRow> rows) throws StandardException {
        Activation lastActivation = conn.getLanguageConnection().getLastActivation();
        IteratorNoPutResultSet resultsToWrap = new IteratorNoPutResultSet(rows, STATEMENT_STATS_OUTPUT_COLUMNS, lastActivation);
        resultsToWrap.openCore();
        return resultsToWrap;
    }
}
And that's it!