The document discusses optimizing and deploying PyTorch models for production use at scale. It covers techniques like quantization, distillation, and conversion to TorchScript to optimize models for low-latency inference. It also discusses deploying optimized models using TorchServe, including packaging models with MAR files and writing custom handlers. Key lessons were that a distilled and quantized BERT model could meet latency SLAs of under 40 ms on CPU and under 10 ms on GPU, while sustaining a throughput of 1,500 requests per second.
Serving BERT Models in Production with TorchServe
1. Serving BERT in production
https://bit.ly/pytorch-workshop-2021
2. Agenda
1. Machine Learning at Walmart
2. Review some DL concepts + Hands-on
3. Optimizing the model for production + Hands-on
4. Break
5. Deploying the model + Hands-on
6. Lessons Learned
7. Q/A
6. ML in Search: Query / Item Understanding
Query: Nike Men Shoes
Brand: NIKE
Gender: Men
Product Type:
- Athletic Shoes: 99%
- Casual and Dress Shoes: 55%
Item
Brand: Nike
Color: Black / White
Shoe Size: 7.5M
7. What were the requirements?
- Current production models were written in TensorFlow and served via the Java Native Interface
- Serve large models like BERT
- Support PyTorch / ONNX
- SLA < 40 ms
14. Models before BERT
- Word embedding layer (Word2Vec, GloVe, ELMo)
- Bidirectional LSTM
Cons:
- Slow to train RNN cells due to recurrence
- Not deeply bidirectional, so context is lost: left-to-right and right-to-left contexts are learned separately and then concatenated
- Transformer models help alleviate these issues
15. BERT
Source: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
Some Use Cases:
- Neural Machine Translation
- Question Answering
- Sentiment Analysis
Problems Solved:
- Contextualized embeddings
- Can be trained on a larger corpus in a semi-supervised manner
- General language model trained on masked-word prediction
- The WordPiece tokenizer handles out-of-vocabulary words
Architecture:
- Based on "Attention Is All You Need"
- A stack of Transformer encoder layers
16. BERT: Different Architectures
For environments with limited compute resources, the BERT recipe comes in different model sizes, varying the hidden size (H) and number of layers (L). The following table shows some of them with their corresponding GLUE scores.

Model         H     L   SST-2   QQP
BERT Tiny     128   2   83.2    62.2/83.4
BERT Mini     256   4   85.9    66.4/86.2
BERT Small    512   4   89.7    68.1/87.0
BERT Medium   512   8   89.6    69.6/87.9
17. Bi-LSTM vs BERT
Query: bedding for twin xl mattress
Bi-LSTM Predictions:
- Mattress: 0.19
- Mattress Encasements: 0.10
- Bed Sheets: 0.05
BERT Predictions:
- Bed Sheets: 0.72
- Bedding Sets
- Bed-in-a-Bag: 0.26
F1 score improved from 0.75 to 0.91
19. Tokenization (pre-processing)
- BERT uses the WordPiece tokenizer (see the sketch below)
- A token is a word or a piece of a word
- The vocabulary is smaller than GloVe's (30k vs. 300k)
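A quick sketch of what WordPiece tokenization looks like in practice, assuming the Hugging Face transformers library and the bert-base-uncased vocabulary (the exact sub-word splits depend on the vocabulary used in the workshop):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# In-vocabulary words stay whole; out-of-vocabulary words are split
# into sub-word pieces prefixed with "##"
print(tokenizer.tokenize("bedding for twin xl mattress"))

# Encoding adds the special [CLS] / [SEP] tokens and returns tensors
# ready to feed to the model
encoded = tokenizer("nike men shoes", return_tensors="pt")
print(encoded["input_ids"])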
24. Post Training Optimization: Quantization
- Techniques for performing computations and storing tensors at lower bit widths than floating-point precision; reduces memory use and speeds up compute
- Hardware support for INT8 computation is typically 2 to 4 times faster than FP32 compute
Reference: Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT
25. Post Training Optimization: Quantization
- Dynamic Quantization (example below):
  - Weights are quantized ahead of time
  - Activations are quantized on the fly for fast INT8 compute, but are read from and written to memory in floating point, so there is no memory benefit for activations
Reference: Dynamic Quantization in PyTorch
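A minimal sketch of the eager-mode dynamic quantization API; the tiny Linear stack here is just a stand-in for a real trained model such as BERT:

import torch

# Stand-in for a trained FP32 model; in practice this would be the BERT model
model = torch.nn.Sequential(torch.nn.Linear(768, 768), torch.nn.ReLU()).eval()

# Dynamic quantization: weights stored as INT8, activations quantized on the fly
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},   # layer types to quantize; Linear layers dominate BERT's compute
    dtype=torch.qint8,
)

print(quantized_model)   # Linear layers are replaced with dynamically quantized versions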
26. Post Training Optimization: Quantization
- Static Quantization (workflow sketched below):
  - Weights and activations are quantized using calibration data after training
  - Quantized values are passed between operations instead of converting to float and back to int between every operation, resulting in a significant speed-up
- Quantization Aware Training:
  - Quantization error becomes part of the training error, so training learns to reduce it
  - Leads to the greatest accuracy
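A minimal eager-mode sketch of the static quantization workflow (stubs, observers, calibration, convert); the toy model and random calibration batches are placeholders, and a model like BERT typically needs more care than this:

import torch

model = torch.nn.Sequential(
    torch.quantization.QuantStub(),     # converts FP32 input to INT8
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.quantization.DeQuantStub(),   # converts INT8 output back to FP32
).eval()

model.qconfig = torch.quantization.get_default_qconfig("fbgemm")  # x86 server backend
prepared = torch.quantization.prepare(model)       # inserts observers

# Calibration pass: run representative data through the model so the
# observers can record activation ranges
for _ in range(8):
    prepared(torch.randn(4, 16))

quantized = torch.quantization.convert(prepared)   # INT8 weights and activations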
27. Quantization results

Model          FP32 accuracy   INT8 accuracy   Strategy                       Inference performance gain   Notes
BERT           0.902           0.895           Dynamic                        1.8x (581 ms -> 313 ms)      Single thread; 1.6 GHz
ResNet-50      76.9            75.9            Static                         2x (214 ms -> 103 ms)        Single thread; 1.6 GHz
MobileNet-v2   71.9            71.6            Quantization Aware Training    5.7x (97 ms -> 17 ms)        Samsung S9

Reference: Introduction to Quantization on PyTorch; pytorch.org
28. Post Training Optimization: Distillation
- Larger models have a higher knowledge capacity than small models, but not all of it may be needed for a given task
- Transfer knowledge from a trained large model to a smaller one (as sketched below)
Image Reference: https://towardsdatascience.com/knowledge-distillation-simplified-dd4973dbc764
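One common way to do this, sketched below: train the student on a mix of the usual hard-label loss and a KL-divergence term against the teacher's temperature-softened logits. The temperature T and weight alpha are generic hyperparameters, not values from the talk:

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft targets: match the teacher's softened probability distribution
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: the usual cross-entropy on the true labels
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard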
29. Distillation Results
Distilling Task-Specific Knowledge from BERT into Simple Neural Networks, 2019, Tang et al.

Model              SST-2   QQP
BERT (Large)       94.9    72.1
BERT Base          93.5    71.2
ELMo               90.4    63.1
BiLSTM             86.7    63.8
Distilled BiLSTM   90.7    68.2
30. Eager Execution vs Script Mode
Eager Execution
- Good for training / prototyping / debugging
- Requires a Python process
- Slow for production use cases
- Flexible; can apply any modeling technique; easy to change
- Can use any library in the Python ecosystem
Script Mode
- Model runs in graph execution mode, which is better optimized for speed
- Only needs a C++ runtime, so it can be loaded from other languages
- TorchScript is an optimized intermediate representation of a PyTorch model
- Doesn't support everything that Python / eager mode supports
31. TorchScript JIT: Tracing vs Scripting
Tracing (example below)
- Does not capture Python data structures or control flow in the model
- Requires a sample input
Scripting
- Supports custom control flow and Python data structures
- Only requires the model
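A small sketch of the two modes; the toy module with an if/else is only there to show why data-dependent control flow forces scripting:

import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        # data-dependent control flow: tracing would bake in whichever
        # branch the sample input happened to take
        if x.sum() > 0:
            return x * 2
        return x - 1

# Tracing: records the ops executed for one sample input
traced = torch.jit.trace(torch.nn.Linear(4, 4), torch.randn(1, 4))

# Scripting: compiles the module directly and preserves the if/else
scripted = torch.jit.script(Gate())
scripted.save("gate.pt")   # the saved TorchScript file only needs a C++ runtime to load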
33. TorchScript Timing
- Inference time for the TorchScript model on CPU and GPU:

Model         CPU (ms)      GPU (ms)
PyTorch       1339          460
TorchScript   768 (-42%)    360 (-21%)

Reference: Benchmarking Transformers: PyTorch and TensorFlow, from HuggingFace (see the full table there)
34. Hands-on: Optimizing the model
- Quantizing the model
- Converting the BERT model with TorchScript
https://bit.ly/pytorch-workshop-2021
Notebook: 03_optimizing_model
36. Options for deploying a PyTorch model
- Plain Flask / FastAPI
- Model servers: TorchServe / NVIDIA Triton / TensorFlow Serving
37. Benefits of TorchServe
- Optimized for serving models
- Configurable support for GPU / CUDA environments
- Multi-model serving
- Model versioning for A/B testing
- Server-side batching
- Support for pre- and post-processing
38. Packaging a model / MAR
- Torch Model Archiver: a utility from the PyTorch team (example invocation below)
- Packages model code and artifacts into a single .mar file
- Can include model-specific Python dependencies
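A sketch of an archiver invocation; all file names here (the traced model, handler, and extra files) are placeholders rather than the exact workshop artifacts:

torch-model-archiver \
  --model-name bert_classifier \
  --version 1.0 \
  --serialized-file traced_model.pt \
  --handler handler.py \
  --extra-files "index_to_name.json,vocab.txt" \
  --export-path model_store

This produces model_store/bert_classifier.mar, which TorchServe can load from its model store.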
39. PyTorch Base Handler
Model Serving Handler:
● Acts as a layer on top of the BaseHandler provided by TorchServe
● All common initialization, pre-processing, error handling and metric collection steps can be abstracted into this custom handler class
● Any model-specific logic can be added in a child class that inherits from the custom handler class
● This provides a resilient wrapper that can be extended and used by multiple teams to serve models through TorchServe with minimal config and code changes (see the sketch below)
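A minimal sketch of such a handler; the tokenizer name, payload handling, and output logic are illustrative assumptions, not the exact workshop code:

import torch
from transformers import BertTokenizer
from ts.torch_handler.base_handler import BaseHandler

class BertClassifierHandler(BaseHandler):
    def initialize(self, context):
        super().initialize(context)   # BaseHandler loads the serialized model onto self.model
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    def preprocess(self, requests):
        # TorchServe passes a list of requests; the payload is under "data" or "body"
        texts = []
        for req in requests:
            payload = req.get("data") or req.get("body")
            if isinstance(payload, (bytes, bytearray)):
                payload = payload.decode("utf-8")
            texts.append(payload)
        return self.tokenizer(texts, padding=True, return_tensors="pt")

    def inference(self, inputs):
        with torch.no_grad():
            return self.model(**inputs)[0]   # logits

    def postprocess(self, logits):
        # one response entry per request in the batch
        return logits.argmax(dim=-1).tolist()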
40. Built-in handlers
- TorchServe ships with built-in handlers for these use cases:
  - image_classifier
  - image_segmenter
  - object_detector
  - text_classifier
- Just need to provide the model file and an index_to_name.json
42. APIs
A. Management API: http://localhost:8081
List models:
- curl "http://localhost:8081/models"
Get details on a model:
- curl http://localhost:8081/models/model_name
B. Inference API: http://localhost:8080
Run inference (examples below):
- curl -X POST http://localhost:8080/predictions/model_name
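For example (the model name and input file below are placeholders):
Register a packaged model from the model store:
- curl -X POST "http://localhost:8081/models?url=bert_classifier.mar&initial_workers=1"
Send an inference request with a payload file:
- curl -X POST http://localhost:8080/predictions/bert_classifier -T sample_input.txt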
43. Hands-on: Deploying the Model
- Package the given model using Torch Model Archiver
- Write a custom handler to support pre-processing and post-processing
https://bit.ly/pytorch-workshop-2021
Notebook: 04_packaging
45. Lessons Learned
- A distilled BERT Base model with quantization and other optimizations can be served on CPU and meet the < 40 ms SLA
- The same model on an NVIDIA T4 can stay under 10 ms at peak traffic
- Throughput of 1500 RPS while staying within the SLA
46. Q/A
Check out Walmart Global Tech for openings for Software Engineers, Machine
Learning Engineers and Data Scientists: https://careers.walmart.com/technology
Editor's Notes
Intro (10 mins)
Setup
Introduce the deep learning BERT model: classification models at Walmart
Walk through the notebooks and the Google Colab setup
Show the end model served
Review Some Deep Learning Concepts (10 mins)
Review sample trained pytorch model code
Review sample model transformer architecture
Tokenization / pre and post processing
https://jalammar.github.io/illustrated-bert/
Transition:
that was a brief view of the HF ecosystem
how to run inference with a trained model
I hope you saw that large models have a pretty big compute cost
Let's look into strategies used to reduce inference time
We will look into 3 optimization techniques
Break off here for a hands-on
Transition:
The first technique we will look at is quantization
Content:
When we trained our model, we trained it with 32-bit floating point precision
For inference, do we need that much precision? Most of the time we can convey the same information with 16-bit or 8-bit values
Not only do we reduce the storage/memory requirements, but on most hardware INT8 computations are also faster
In the image, the top half shows the full floating-point range, which is mapped to the reduced integer range below
Transition:
In PyTorch, we have 3 different types
Content:
3 main quantization techniques
In all quantization techniques, the weights are quantized
In dynamic quantization, the activations are quantized dynamically on the fly for fast INT8 computation, but are read/written back to memory as floats
Transition:
Let's look at the other two techniques
Content:
In dynamic quantization, we didn't statically quantize the activations
In static quantization, we quantize the activations as well by giving the model calibration data; because the activations are also quantized, we avoid the constant float-to-int conversions, which further speeds up inference
If you know you will need quantization, you can optimize for it while training by making quantization error part of the training error; this leads to the best accuracy
Transition:
Let's look at the results of these techniques in practice
Content:
BERT: less than a 1% accuracy drop, but a 1.8x speedup (581 ms -> 313 ms)
ResNet: we saw a 1% accuracy drop, but got a 2x speedup with static quantization
MobileNet: a 0.3% drop with quantization-aware training; it saw a 5.7x improvement running on a mobile phone, where floating-point calculations are slower
Quantization is a powerful technique, especially in non-GPU environments
Transition:
The next optimization we will look at is distillation
Content:
Large models like BERT have a large knowledge capacity, but do you need all of it for your problem?
You might instead use a smaller model like a Bi-LSTM
What if you could teach the smaller model to learn from the larger model's predictions and loss?
Transition:
Let's see the results of this in practice
Content:
On the Stanford sentiment task, BERT Base gets an accuracy of 93.5%
A regular Bi-LSTM model only got 86.7%
After distillation, the accuracy jumped by 4 points
Transition:
Let's look at the last optimization technique
Content:
One of the reasons PyTorch became popular is that it runs in eager execution mode by default
Eager:
Gives us immediate feedback from operations on our layers and great debugging support
You can also use any libraries in Python
But this flexibility makes it a bit slow for production and ties it to a Python runtime
Script mode
In script mode, you create an execution graph of your model that is optimized for speed
Read points
Transition:
Let's learn how we can convert a model to script mode using PyTorch's just-in-time compiler
Content:
The PyTorch just-in-time compiler supports two modes: tracing and scripting
Read points …
Transition:
Let's see a code snippet
Content:
Because we have an if/else block, we need to use script mode
Transition:
Let's see how much improvement we got
Content:
On CPU, we see a 42% improvement
On GPU, we got a 21% speedup
Transition:
We covered three techniques
Let's explore them in a lab
Package the given model using Torch Model Archiver
Write a custom handler to support preprocessing and post processing