Thank you for having us tonight.
I’m Andrew Ayres, this is Qing Lan, and we’re going to introduce you to “Using Java to Deploy Deep Learning Models with MXNet.”
Let’s quickly go through what we plan to cover and what’s in it for you.
We’ll cover some background explaining what deep learning is and why you should care.
We’ll cover what Apache MXNet is.
We’ll show you how to get started using MXNet Java.
We’ll go over some of the more interesting technical challenges we’ve faced.
And finally, we’ll leave some time for Q&A.
So by the end of this session, you will understand what deep learning is, and how you can get started using it.
So, what is deep learning?
Let’s start with the basics: AI -> ML -> DL.
The term AI was coined in the 50s. Alan Turing, often referred to as the father of computer science, changed the fundamental question from "can machines think?" to "can machines imitate us humans?" and proposed the famous Turing test.
AI is a broad field that covers areas of both philosophy and science.
ML, which came in the 80s and saw massive adoption in the 90s and early 2000s, is a subfield of AI that studies algorithms that allow machines to learn patterns from data. Examples include regression, decision trees, support vector machines, and genetic algorithms.
ML introduced a major shift in how programming is done. Classic programming was about coding explicit rules and then running those rules on input data to produce results.
With ML, however, this paradigm changes. We provide the ML algorithm with labeled data (inputs and their results), and the machine learns the business rules by itself. That's a pretty massive paradigm shift.
And then there is deep learning, which is a specific ML technique. Traditionally, this would probably have been referred to as Neural Networks.
DL is what we’ll be talking about today and is particularly interesting because it has been outperforming other ML techniques across a growing number of problems.
You can see Deep Learning applied in more and more domains, with a growing impact on our lives.
If you look at the breadth of AI applied within Amazon alone, you can see DL in the retail website, powering personalization and recommendations; you can see it optimizing Amazon’s logistics; you’ve probably noticed the boom in voice-enabled personal assistants; and you may have heard that Amazon drones also rely on deep learning, just as other autonomous vehicle technology does. And of course the list goes on.
Beyond the growing usage of DL in applications and devices around us, there is another interesting aspect to deep learning, and that is that for some tasks, it’s able to outperform us.
One of the first areas where Deep Learning demonstrated state-of-the-art results was Computer Vision. A classic problem in that domain is Object Classification: given an image, identify the most prominent object in it from a set of predefined classes. In 2012, a DNN presented by Alex Krizhevsky (AlexNet) leapfrogged the best-known algorithm to date, cutting the top-5 error rate from roughly 26% to roughly 15%. That was a major leap, and every year since then the best algorithms for Object Classification, and for many other vision tasks, have been based on Deep Learning, with results that keep getting better.
It’s been shown that DNNs already outperform humans in Object Classification. This is quite remarkable, as this is a task that evolution has specialized us for. If you introduce noise into those images, however, we still reign supreme. For now, at least.
AlexNet paper: https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
Humans vs DNNs paper: https://arxiv.org/pdf/1706.06969.pdf
OK, hopefully now you are convinced that DL is more than just hype.
Now we get to do something fun and challenging. We’re going to take an entire field, with decades of research, and I’m going to teach it to you in about 10 minutes.
We’ll start with the basic artificial neuron.
Artificial neurons are inspired by the neurons of the human brain. We have roughly 100 billion of them, with as many as a quadrillion synapses interconnecting them.
But in fact, an artificial neuron is a simple computational construct: numeric inputs are fed into the neuron, multiplied by weights, and summed into a single numeric value, which then goes through a non-linear activation function. The result is a single numeric scalar, typically in the range of zero to one (for example, with a sigmoid activation).
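A minimal Java sketch of the computation just described, assuming a sigmoid activation and adding a bias term (a common addition not mentioned above). The `Neuron` class is a hypothetical illustration and is unrelated to any MXNet API.

```java
import java.util.Arrays;

// A single artificial neuron: weighted sum of inputs plus a bias, then a sigmoid.
final class Neuron {
    private final double[] weights;
    private final double bias;

    Neuron(double[] weights, double bias) {
        this.weights = Arrays.copyOf(weights, weights.length); // defensive copy
        this.bias = bias;
    }

    // Logistic sigmoid: maps any real number into the range (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Multiply each input by its weight, sum, add the bias, apply the activation.
    double activate(double[] inputs) {
        if (inputs.length != weights.length) {
            throw new IllegalArgumentException("expected " + weights.length + " inputs");
        }
        double sum = bias;
        for (int i = 0; i < inputs.length; i++) {
            sum += weights[i] * inputs[i];
        }
        return sigmoid(sum);
    }
}
```

With weights `{0.5, -0.5}` and inputs `{1.0, 1.0}` the weighted sum is 0, so the neuron outputs exactly 0.5; large positive sums push the output toward 1.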
Neural networks are composed of neurons. The neurons are organized into layers. There’s the input layer, the output layer, and in between is what we refer to as the hidden layers.
If there are “many” hidden layers, we call the network deep. Hence the term Deep Learning.
Since the neurons in the network have a non-linear activation function, the whole network is non-linear and able to approximate complex functions.
Deep networks are able to learn hierarchical features from the data.
The architecture is scalable: you can increase the number of layers, or the number of neurons in a layer.
But this gets computationally expensive.
(read)
Let’s go briefly into some different types of layers.
The most straightforward type is the fully connected layer.
Fully connected simply means that each neuron in the layer has a connection with each neuron in the preceding layer.
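The same neuron computation extended to a whole fully connected layer, as a plain-Java sketch. The `DenseLayer` name and the sigmoid activation are assumptions for illustration: each output neuron holds one weight per neuron in the preceding layer, plus a bias.

```java
// Forward pass of a fully connected (dense) layer.
final class DenseLayer {
    private final double[][] weights; // weights[out][in]: one row per output neuron
    private final double[] biases;    // one bias per output neuron

    DenseLayer(double[][] weights, double[] biases) {
        if (weights.length != biases.length) {
            throw new IllegalArgumentException("need one bias per output neuron");
        }
        this.weights = weights;
        this.biases = biases;
    }

    // Every output neuron sees every input: weighted sum, bias, sigmoid.
    double[] forward(double[] input) {
        double[] out = new double[weights.length];
        for (int o = 0; o < weights.length; o++) {
            if (weights[o].length != input.length) {
                throw new IllegalArgumentException("row " + o + " expects " + weights[o].length + " inputs");
            }
            double sum = biases[o];
            for (int i = 0; i < input.length; i++) {
                sum += weights[o][i] * input[i];
            }
            out[o] = 1.0 / (1.0 + Math.exp(-sum));
        }
        return out;
    }

    // Total learnable parameters: all weights plus all biases.
    int parameterCount() {
        int count = biases.length;
        for (double[] row : weights) {
            count += row.length;
        }
        return count;
    }
}
```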
The next common layer type is the Convolution layer.
These are often used when working with images or videos.
In a convolutional layer, the operator is a kernel that slides across the input. At each position, we multiply the patch of input under the kernel element-wise by the kernel values and sum the products into a single output value.
The kernel values are the learned parameters of the layer.
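A plain-Java sketch of the sliding-kernel operation, assuming a single channel, stride 1, and no padding. Like most DL frameworks, it does not flip the kernel, so strictly speaking it computes a cross-correlation. `Conv2D` is a hypothetical name, not an MXNet class.

```java
// "Valid" 2D convolution: output shrinks by (kernel size - 1) in each dimension.
final class Conv2D {
    static double[][] convolve(double[][] input, double[][] kernel) {
        int kh = kernel.length;
        int kw = kernel[0].length;
        int oh = input.length - kh + 1;
        int ow = input[0].length - kw + 1;
        if (oh <= 0 || ow <= 0) {
            throw new IllegalArgumentException("kernel is larger than the input");
        }
        double[][] out = new double[oh][ow];
        for (int r = 0; r < oh; r++) {
            for (int c = 0; c < ow; c++) {
                // Element-wise multiply the patch under the kernel, then sum.
                double sum = 0.0;
                for (int i = 0; i < kh; i++) {
                    for (int j = 0; j < kw; j++) {
                        sum += input[r + i][c + j] * kernel[i][j];
                    }
                }
                out[r][c] = sum;
            }
        }
        return out;
    }
}
```

The same small kernel is reused at every position, which is why the number of learned values depends only on the kernel size, not on the image size.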
Why do this instead of fully connected layers?
It reduces the dimensionality of the problem, that is, the number of weights the layer has to learn. Images have very large input sizes: in a fully connected layer, even for a small 100x100 pixel image, each neuron in the second layer would have 10,000 weights.
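The arithmetic behind that claim, as a small Java sketch. The 100 hidden neurons and the 3x3 kernel are illustrative assumptions, and biases are ignored: a dense layer over a 100x100 image needs 100 x 100 x 100 = 1,000,000 weights, while a single 3x3 kernel needs only 9, because the same kernel is shared across every position.

```java
// Back-of-the-envelope weight counts (biases ignored).
final class ParamCount {
    // Dense: every neuron connects to every input pixel.
    static long denseWeights(int height, int width, int neurons) {
        return (long) height * width * neurons;
    }

    // Conv: each kernel is shared across all positions, so image size drops out.
    static long convWeights(int kernelHeight, int kernelWidth, int kernels) {
        return (long) kernelHeight * kernelWidth * kernels;
    }
}
```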
This approach also preserves spatial correlations. When we look at an image, we don’t look at it pixel by pixel. My son has a tendency, when he has something to show me, to put it right in my face. I have to pull back, not only because I don’t like Play-Doh in my face, but because it’s hard to see things when you’re too close.
Convolution layers follow a similar principle. They preserve the spatial correlations and allow us to look at more than just a pixel.
When we combine multiple convolution layers, we are able to learn increasingly abstract features.
This illustration shows input image patches that activate certain kernels the most – showcasing how the conv layers are handling the input data
Convolutional layer 1 can detect edges
Layer 2 can use those edges to detect curves
Eventually, you can use these abstract features to do things like classification
The final type of layer we’ll go over is the recurrent layer.
Consider this example: if we’ve got a function that looks like this, how do we know whether we’re at position A or B?
For some problems, your current state isn’t enough, and it’s important to know how you got there.
The output of the layer is fed back into the layer. Essentially, it remembers the past by using loops.
Some examples of this are time series, speech recognition and natural language problems.
We can unroll an RNN layer to better understand how it works
Each of these is a timestep.
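An unrolled recurrent layer can be sketched in a few lines of Java. This assumes the simplest form, a scalar Elman-style cell with a tanh activation; the `SimpleRnn` name and its weights are hypothetical. At each timestep, the new hidden state combines the current input with the hidden state from the previous timestep.

```java
// A scalar recurrent cell unrolled over a sequence.
final class SimpleRnn {
    private final double inputWeight;
    private final double recurrentWeight;
    private final double bias;

    SimpleRnn(double inputWeight, double recurrentWeight, double bias) {
        this.inputWeight = inputWeight;
        this.recurrentWeight = recurrentWeight;
        this.bias = bias;
    }

    // One timestep: mix the current input with the previous hidden state.
    double step(double x, double previousHidden) {
        return Math.tanh(inputWeight * x + recurrentWeight * previousHidden + bias);
    }

    // The loop feeds each step's output back in as the next step's state.
    double[] run(double[] inputs) {
        double[] hidden = new double[inputs.length];
        double h = 0.0; // initial hidden state
        for (int t = 0; t < inputs.length; t++) {
            h = step(inputs[t], h);
            hidden[t] = h;
        }
        return hidden;
    }
}
```

Running `new SimpleRnn(1.0, 0.5, 0.0).run(new double[]{0, 1, 0})` gives a hidden state of 0 at the first step but a positive value at the last step, even though the input is 0 both times: the layer remembers how it got there.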
The difficult part is training the network, so we can find the right weights that will approximate the function modeling the problem we are trying to solve.
We start with the “Forward Pass”, in which we take a sample from our labeled input data, feed it through the network to get the inference, or prediction result.
We then do the “Backwards Pass”, also called “Backprop”, where we calculate the loss, i.e. how badly the network did compared to the “Ground Truth” (the label of the sample input data), and then back-propagate the loss across the network, computing the gradient of the loss with respect to each weight to identify the direction of the error.
We then update the weights across the network in the direction opposite to the gradient, by an amount that is typically a fraction of the gradient. This fraction is called the “Learning Rate”.
The Backwards Pass is where learning happens. Through repeated iterations, we use the gradient to drive down the loss until we converge to a low error rate.
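The forward pass, backward pass, and weight update, shrunk to the smallest possible example: a single weight `w`, a prediction `w * x`, and a mean-squared-error loss whose gradient is worked out by hand. This is a hypothetical Java sketch of plain gradient descent; real frameworks compute these gradients automatically across every layer.

```java
// Gradient descent on a one-weight model with mean-squared-error loss.
final class TinyTrainer {
    private double w = 0.0; // the single learnable weight

    double weight() {
        return w;
    }

    // Forward pass: the prediction (also all that inference needs).
    double predict(double x) {
        return w * x;
    }

    // Loss: mean squared error against the ground-truth labels.
    double loss(double[] xs, double[] ys) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double err = predict(xs[i]) - ys[i];
            sum += err * err;
        }
        return sum / xs.length;
    }

    // Backward pass: dLoss/dw = (1/n) * sum of 2 * (w * x - y) * x.
    double gradient(double[] xs, double[] ys) {
        double g = 0.0;
        for (int i = 0; i < xs.length; i++) {
            g += 2.0 * (predict(xs[i]) - ys[i]) * xs[i];
        }
        return g / xs.length;
    }

    // Repeated iterations: step opposite the gradient, scaled by the learning rate.
    void train(double[] xs, double[] ys, double learningRate, int epochs) {
        for (int e = 0; e < epochs; e++) {
            w -= learningRate * gradient(xs, ys);
        }
    }
}
```

Training on inputs `{1, 2, 3, 4}` with labels `{3, 6, 9, 12}`, a learning rate of 0.05, and 100 epochs drives `w` to 3. After that, only `predict` is needed.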
After we’ve finished training the model, we no longer do the backwards pass. Instead we do only the forward pass to make predictions. We refer to this process as inference.
Just a bit of background on MXNet:
MXNet is a deep learning framework for building, training, and deploying Deep Neural Nets.
This last part, deploying DNNs, is probably the most interesting to you and what we’ll be focusing on today.
MXNet is an Apache project. This means that no one group or company controls the project. Decisions about the direction of MXNet are made by the MXNet community.
It originated in academia, at CMU and UW.
AWS adopted MXNet in late 2016 as its “deep learning framework of choice”; there’s a nice blog post by AWS CTO Werner Vogels explaining this in more detail. A lot of it comes down to scalability and MXNet being well suited for production use.
Speaking of community
This is by no means a complete list, but there are a lot of companies that use MXNet, that are part of the MXNet community, and that we get to collaborate with on the project.
It’s exciting to see the project grow.
MXNet supports 8 languages, including 3 JVM languages.
It is highly performant and scales close to linearly as you add GPUs and machines.
It has ONNX support. ONNX, the Open Neural Network Exchange, is an open format for models spearheaded by AWS, Facebook, and Microsoft. Frameworks use their own formats to save models, but models are your intellectual property, and you don’t want to be married to a framework. MXNet supports ONNX out of the box, so you can bring in models from other frameworks that support ONNX, or take your MXNet-trained model and serve it using other frameworks.
Why did we decide to work on MXNet Java?
We were seeing a common theme in the community. Most model training occurs in Python. After training is finished, what do we do with the model? Most of the time, the answer is that we want to start using it.
One problem is that the person responsible for training the network and producing the best possible model is often not the same person who will deploy that model into a production environment.
We wanted to enable Java developers to use DL in their existing Java workflows without having to become experts. This meant creating a simple, easy-to-use API focused on deploying existing models.
With that, I’ll turn it over to Qing to introduce you to MXNet Java.
Most importantly, please install Java, and install the right version. People often make mistakes by installing a different version of Java. The MXNet package supports compiling with Java 8, so make sure that’s the version you install.
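If you are unsure which Java version your code is running on, a small check can help. This is a hypothetical helper, not part of MXNet. It parses the standard `java.version` system property, which uses strings like `1.8.0_201` up to Java 8 and strings like `11.0.2` from Java 9 on.

```java
// Extracts the major Java version from a java.version-style string.
final class JavaVersionCheck {
    static int majorVersion(String version) {
        // Java 8 and earlier: "1.8.0_201" -> drop the leading "1."
        String v = version.startsWith("1.") ? version.substring(2) : version;
        // Read the leading run of digits ("8.0_201" -> 8, "11.0.2" -> 11, "9-ea" -> 9).
        int end = 0;
        while (end < v.length() && Character.isDigit(v.charAt(end))) {
            end++;
        }
        if (end == 0) {
            throw new IllegalArgumentException("unrecognized version: " + version);
        }
        return Integer.parseInt(v.substring(0, end));
    }

    // True when the running JVM reports major version 8.
    static boolean isJava8() {
        return majorVersion(System.getProperty("java.version")) == 8;
    }
}
```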
Hopefully by now you’re excited about DL and want to learn more. Let’s get to it.
If you remember from earlier, MXNet has support for many languages.
The architecture for JVM languages is currently set up like this:
Designing the Java API as a wrapper around the Scala API allowed us to move quickly, gauge community interest in a Java API, keep JVM efforts concentrated, and benefit from all the work already done in Scala.
When we started implementing the Java API, we had great difficulty using the operators, because their arguments need to be passed as key-value pairs to the C backend.
In this case, users are required to know which arguments each function takes. Some operators have more than 10 arguments, which makes them very hard to construct and undoubtedly raises the bar for new users of the Java API. There were two ways to solve this problem: hand-craft every method with its arguments, or generate these methods. We chose the latter, since it seemed to take less time and be easier to maintain.
We actually spent more time planning, designing, and implementing this approach than we would have spent copying and pasting the implementations over by hand.
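A rough Java sketch of the generation idea: given an operator’s name and its arguments, emit source for a builder with one typed setter per argument that records the key-value pairs the backend expects. Everything here (`OpBuilderGenerator`, the builder shape, the argument names) is hypothetical and illustrative; it is not MXNet’s actual generator.

```java
import java.util.Map;

// Emits Java source for a typed builder from an operator's argument list.
final class OpBuilderGenerator {
    static String generate(String opName, Map<String, String> args) {
        if (opName.isEmpty()) {
            throw new IllegalArgumentException("operator name must not be empty");
        }
        String cls = Character.toUpperCase(opName.charAt(0)) + opName.substring(1) + "Builder";
        StringBuilder src = new StringBuilder();
        src.append("public final class ").append(cls).append(" {\n");
        src.append("    private final java.util.Map<String, String> kwargs = new java.util.LinkedHashMap<>();\n");
        // One typed setter per argument; each stores a key-value pair for the backend.
        for (Map.Entry<String, String> arg : args.entrySet()) {
            String name = arg.getKey();
            String setter = "set" + Character.toUpperCase(name.charAt(0)) + name.substring(1);
            src.append("    public ").append(cls).append(' ').append(setter)
               .append('(').append(arg.getValue()).append(" value) {\n")
               .append("        kwargs.put(\"").append(name).append("\", String.valueOf(value));\n")
               .append("        return this;\n")
               .append("    }\n");
        }
        src.append("    public java.util.Map<String, String> toKwargs() { return kwargs; }\n");
        src.append("}\n");
        return src.toString();
    }
}
```

For example, generating from `convolution` with arguments `numFilter -> int` and `noBias -> boolean` yields a `ConvolutionBuilder` with `setNumFilter(int value)` and `setNoBias(boolean value)`, so users get compile-time checking and IDE completion instead of hand-built key-value maps.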
Many of the MXNet objects utilize native memory (memory in the C++ heap) to maximize performance.
JVM Garbage Collector only manages the JVM heap and is not aware of these native objects.
Since JVM languages do not have destructors, these objects must be deallocated explicitly.
In order to manage these resources effectively, a multi-pronged approach has been implemented.
Once the GC found the object holding a C++ pointer unreachable, it would release the native memory in an arbitrary order. In multithreaded code, this sometimes meant a crash. Other times, memory usage simply exploded, because the GC only tracks the JVM heap and may not run before native memory runs out.
If your memory is fine to be collected periodically, they can be kept in t
This is meant as the last line of defense against memory leaks.
It is a great replacement for the finalizer.
Threading problems
Once the Java object loses its last reference and becomes eligible for collection, its phantom reference, which carries the native C pointer, is placed on a reference queue so that the pointer can be freed.
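The pattern can be sketched with the JDK’s `java.lang.ref.PhantomReference` and `ReferenceQueue`. The `NativeHandleReaper` class and the `nativeFree` callback (standing in for a JNI call that frees C++ memory) are hypothetical; in a real system a background thread would call `reap` in a loop.

```java
import java.lang.ref.PhantomReference;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.function.LongConsumer;

// Frees native handles after their Java owners become unreachable.
final class NativeHandleReaper {
    // Holds only the raw native handle, never the Java owner itself,
    // so the owner can still become unreachable and be collected.
    private static final class HandleRef extends PhantomReference<Object> {
        final long handle;

        HandleRef(Object owner, long handle, ReferenceQueue<Object> queue) {
            super(owner, queue);
            this.handle = handle;
        }
    }

    private final ReferenceQueue<Object> queue = new ReferenceQueue<>();
    // Reference objects must stay strongly reachable until processed,
    // otherwise they could be collected before ever being enqueued.
    private final Set<HandleRef> pending = Collections.synchronizedSet(new HashSet<>());
    private final LongConsumer nativeFree;

    NativeHandleReaper(LongConsumer nativeFree) {
        this.nativeFree = nativeFree;
    }

    // Register a Java wrapper object together with the native pointer it owns.
    void track(Object owner, long handle) {
        pending.add(new HandleRef(owner, handle, queue));
    }

    // Frees the handle of every owner the GC has found unreachable.
    // Waits up to timeoutMs for the first one (timeoutMs <= 0: don't wait).
    // Returns how many handles were freed.
    int reap(long timeoutMs) {
        int freed = 0;
        try {
            Reference<?> ref = timeoutMs > 0 ? queue.remove(timeoutMs) : queue.poll();
            while (ref != null) {
                HandleRef handleRef = (HandleRef) ref;
                nativeFree.accept(handleRef.handle);
                handleRef.clear();
                pending.remove(handleRef);
                freed++;
                ref = queue.poll();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve interrupt status
        }
        return freed;
    }
}
```

Note that the native memory is freed only after the GC happens to run, which is exactly why this works as a last line of defense rather than a primary strategy.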
If your service doesn’t need the memory to be collected in a timely manner, phantom references should suffice. Otherwise, use try-with-resources.
In some cases, the JVM GC might not run often enough to sufficiently manage the native resources.
For this reason, a ResourceScope object is provided which implements AutoCloseable and tracks the MXNet Objects created within the scope.
When leaving the scope, these objects are automatically freed.
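A minimal plain-Java sketch of the scope idea, not MXNet’s actual `ResourceScope`: a thread-local current scope lets every resource created inside a try-with-resources block register itself, and closing the scope frees them in reverse creation order. `SimpleScope` and `NativeBuffer` are hypothetical names.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Tracks resources created while it is open and frees them on close().
final class SimpleScope implements AutoCloseable {
    // Innermost open scope on this thread (null when none is open).
    private static final ThreadLocal<SimpleScope> CURRENT = new ThreadLocal<>();

    private final SimpleScope parent;
    private final Deque<AutoCloseable> resources = new ArrayDeque<>();

    SimpleScope() {
        parent = CURRENT.get();
        CURRENT.set(this);
    }

    // Called by resource constructors; a no-op when no scope is open.
    static void register(AutoCloseable resource) {
        SimpleScope scope = CURRENT.get();
        if (scope != null) {
            scope.resources.push(resource);
        }
    }

    @Override
    public void close() {
        CURRENT.set(parent); // restore the enclosing scope, if any
        RuntimeException failure = null;
        while (!resources.isEmpty()) {
            try {
                resources.pop().close(); // most recently created first
            } catch (Exception e) {
                if (failure == null) {
                    failure = new RuntimeException("failed to free resource", e);
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }
}

// Stand-in for an object that owns native memory.
final class NativeBuffer implements AutoCloseable {
    private boolean freed;

    NativeBuffer() {
        SimpleScope.register(this); // auto-tracked by the current scope
    }

    boolean isFreed() {
        return freed;
    }

    @Override
    public void close() {
        if (!freed) {
            freed = true; // a real wrapper would release native memory here
        }
    }
}
```

Usage: inside `try (SimpleScope scope = new SimpleScope()) { ... }`, every `NativeBuffer` created is freed as soon as the block exits, without waiting for the GC.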