Getting Started With TensorFlow Using Java/JavaScript
There is a great market for machine learning and deep learning enthusiasts. Unfortunately, the field also requires a different programming skill set than most developers have.
This article will help experienced Java/JavaScript developers use machine learning/deep learning in the TensorFlow ecosystem without any prior knowledge of Python.
Background
Nowadays, there is a great market for machine learning and deep learning enthusiasts because it is a niche technology, but it also requires a different programming skill set.
There are plenty of tools and libraries available for this, like NumPy, Keras, and TensorFlow, but most of them revolve around R, MATLAB, and, of course, Python.
The problem, then, is reskilling the existing programmer community, which is typically trained in languages like C, C++, and Java.
Solution
It is a common misconception that TensorFlow is made only for Python and its allies. Don't worry: if you are an experienced Java/JS programmer and want to get your hands dirty with ML/DL, you still can!
The solution is TensorFlow for Java and TensorFlow.js.
Learn by Example
Java
To start with Java, you first need to set up a Maven project and add the following dependency to your pom.xml.
<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.7.0</version>
</dependency>
This artifact pulls in two libraries:
- libtensorflow
- libtensorflow_jni
Together, these wrap the TensorFlow C++ library and provide the JNI connector for accessing it from Java programs.
By default, TensorFlow runs on the CPU. To support GPU acceleration, you can additionally add the following dependency.
<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow_jni_gpu</artifactId>
  <version>1.7.0</version>
</dependency>
Now, you are all set to use TensorFlow features in Java. You can use the following test program to check the environment.
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

public class HelloTF {
  public static void main(String[] args) throws Exception {
    try (Graph g = new Graph()) {
      final String value = "Hello from " + TensorFlow.version();
      // Construct the computation graph with a single operation, a constant
      // named "MyConst" with a value "value".
      try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
        // The Java API doesn't yet include convenience functions for adding operations.
        g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
      }
      // Execute the "MyConst" operation in a Session.
      try (Session s = new Session(g);
          Tensor output = s.runner().fetch("MyConst").run().get(0)) {
        System.out.println(new String(output.bytesValue(), "UTF-8"));
      }
    }
  }
}
This program builds a computation graph containing a single constant operation whose value is a string with the current TensorFlow version. To fetch that value, you need to run the operation inside a session.
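The define-then-run idea behind the graph and the session can be easier to see in isolation. The following is a toy sketch in plain Java, not TensorFlow code: the class ToyGraph and its methods constant, add, and fetch are hypothetical names invented for illustration. Operations are only registered when the graph is built, and nothing is computed until a named operation is fetched, which roughly mirrors what running an operation inside a session does.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;

// Toy illustration only: ToyGraph is hypothetical and NOT part of the TensorFlow API.
public class ToyGraph {
    // Each named operation is stored as a deferred computation.
    private final Map<String, Supplier<Double>> ops = new LinkedHashMap<>();

    // Register a constant operation; nothing is computed yet.
    public ToyGraph constant(String name, double value) {
        ops.put(name, () -> value);
        return this;
    }

    // Register an addition of two previously named operations.
    public ToyGraph add(String name, String left, String right) {
        ops.put(name, () -> fetch(left) + fetch(right));
        return this;
    }

    // Plays the role of a session run: evaluate the named operation on demand.
    public double fetch(String name) {
        Supplier<Double> op = ops.get(name);
        if (op == null) {
            throw new IllegalArgumentException("Unknown operation: " + name);
        }
        return op.get();
    }

    public static void main(String[] args) {
        // Building the graph only describes the computation.
        ToyGraph g = new ToyGraph()
                .constant("a", 2.0)
                .constant("b", 3.0)
                .add("sum", "a", "b");
        // Fetching "sum" is what actually triggers the evaluation.
        System.out.println(g.fetch("sum")); // prints 5.0
    }
}
```

In the real API, the Graph holds the operations and Session.runner().fetch(...).run() performs the evaluation, as in the HelloTF program above.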
JavaScript
To start with JavaScript, you just need a single package available through a CDN. It's available as an npm package as well, but it is easier to use the CDN when working in a browser.
Add the below script tag to your HTML file.
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.10.0"> </script>
That's it. Now, you can test TensorFlow features like training, re-training, and converting existing models into JavaScript-compatible models right in your browser.
To test, you can use the following code.
// Define a model for linear regression.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
// Prepare the model for training: Specify the loss and the optimizer.
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// Generate some synthetic data for training.
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// Train the model using the data.
model.fit(xs, ys).then(() => {
  // Use the model to do inference on a data point the model hasn't seen before:
  const result = model.predict(tf.tensor2d([5], [1, 1]));
  // Log the predicted value to the console.
  result.print();
});
This is a basic linear regression model trained with a mean squared error loss and a stochastic gradient descent (SGD) optimizer.
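For readers coming from Java, it may help to see what such a training loop does under the hood. The sketch below fits y = w * x + b to the same four data points using plain batch gradient descent on the mean squared error, with no TensorFlow involved. The class name LinearFit, the learning rate, and the epoch count are assumptions chosen for illustration; they are not values used by TensorFlow.js.

```java
// Plain-Java sketch; LinearFit and its parameters are illustrative, not TensorFlow API.
public class LinearFit {
    // Fits y = w * x + b by batch gradient descent on the mean squared error.
    // Returns the learned parameters as {w, b}.
    public static double[] fit(double[] xs, double[] ys, double learningRate, int epochs) {
        double w = 0.0;
        double b = 0.0;
        int n = xs.length;
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < n; i++) {
                // Prediction error for one sample.
                double error = (w * xs[i] + b) - ys[i];
                // Accumulate the gradient of mean((w*x + b - y)^2) w.r.t. w and b.
                gradW += 2.0 * error * xs[i] / n;
                gradB += 2.0 * error / n;
            }
            // Step against the gradient.
            w -= learningRate * gradW;
            b -= learningRate * gradB;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        // Same synthetic data as the JavaScript example (it follows y = 2x - 1).
        double[] xs = {1, 2, 3, 4};
        double[] ys = {1, 3, 5, 7};
        double[] params = fit(xs, ys, 0.05, 5000);
        double prediction = params[0] * 5 + params[1];
        System.out.printf("w=%.3f b=%.3f prediction for x=5: %.3f%n",
                params[0], params[1], prediction);
    }
}
```

With enough epochs, the parameters approach w = 2 and b = -1, so the prediction for x = 5 approaches 9. A library optimizer may initialize the weights and batch the data differently, so treat this only as a picture of the underlying idea.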
You can have a look at a complete example on GitHub.