How to Achieve a 9ms Inference Time for Transformer Models

User-generated content online has grown steadily over the past decades, and harmful online content is growing in both volume and diversity.

Bhaskar
Published February 14, 2023

It is crucial for technology platforms to moderate harmful content as early as possible. Most modern moderation tools take anywhere from a few hundred milliseconds to a few seconds to detect harmful content, so action is often taken only after the harm is done, when thousands or millions of end users have already seen the content. In this world of “Live Streaming” and “Instant Messaging,” detecting and acting on harm in real time is essential.

The expected response time for a real-time “Send Message API” that submits user-generated content is under 100ms, and under 10ms in some scenarios. Such a narrow window leaves little room to moderate content before it is published, so machine learning models must be optimized for low-latency, high-throughput inference.

At Stream, we serve over two billion end users with our activity feed and in-app chat solutions. Serving that many users and that much content efficiently requires high throughput: the ability to serve millions of concurrent users who send thousands of messages per second. Adding more inference servers quickly becomes costly and eventually infeasible, so low inference latency is equally important, because it lets a single server process more requests per second.

In this article, we explore the methods Stream adopted to achieve high throughput and low latency for its moderation services. First, we look at selecting the model and reducing its memory footprint. Next, we look at selecting the hardware to run inference on and reducing CPU usage. Then, we look at the optimizations applied at inference time to reduce computation. Finally, we present the results and conclude with ideas for future work.

Step 1: Reduce Memory Footprint of the ML Model

Transformers are state-of-the-art machine learning models for natural language processing (NLP). Unlike traditional ML models that look at individual words or phrases, transformers consider the message text as a whole. Because predictions are based on the intent and meaning of the entire message, they remain robust to missing or misspelled words. This matters for moderation, since bad actors often obfuscate known harmful words to evade simple policies such as blocklists.

Bidirectional Encoder Representations from Transformers (BERT) is a language model pre-trained on over three billion words. It can be fine-tuned for various natural language tasks by training additional classification layers on top of it. However, the BERT-base model requires more than 400MB of memory and takes a few hundred milliseconds per inference on CPU instances, which makes it unsuitable for real-time predictions in production environments.
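As a rough check on these memory figures, weight memory is approximately the parameter count multiplied by the bytes per parameter. The sketch below assumes the commonly cited sizes of about 110M parameters for BERT-base and 66M for DistilBERT; it counts weights only, so actual process memory will be somewhat higher.

```python
# Rough weight-memory estimate: parameter count x bytes per parameter.
# Ignores activations, runtime overhead, and tokenizer/vocabulary files.

BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}


def weight_memory_mb(num_params: int, dtype: str = "fp32") -> float:
    """Approximate size of a model's weights in megabytes (10**6 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6


if __name__ == "__main__":
    # Approximate published parameter counts, assumed here for illustration.
    models = [("BERT-base", 110_000_000), ("DistilBERT", 66_000_000)]
    for name, n in models:
        fp32 = weight_memory_mb(n)
        int8 = weight_memory_mb(n, "int8")
        print(f"{name}: {fp32:.0f} MB fp32, {int8:.0f} MB int8")
```

Running it prints about 440MB (fp32) for BERT-base and 264MB for DistilBERT, in line with the figures quoted in this article; storing weights as int8 shrinks each by roughly 4x.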

Figure: Reducing model size and reducing inference time

Optimization 1: Knowledge Distillation

Knowledge distillation is a technique for training a smaller model (the student) from a large pre-trained model (the teacher). The student learns by mimicking the teacher's outputs. DistilBERT is a distilled version of BERT that halves the number of layers, making it 60% faster than BERT-base. It requires less than 300MB of memory and takes less than 100ms per inference on CPU instances, which makes it suitable for real-time predictions in production environments.
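To make the student-teacher idea concrete, here is a minimal plain-Python sketch of a classic distillation loss for a single example: the student is pulled toward the teacher's temperature-softened output distribution and, separately, toward the true label. The `temperature` and `alpha` defaults are illustrative assumptions, and this is not necessarily the exact loss Stream or DistilBERT used.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; subtracting the max keeps exp() numerically stable."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
    temperature: float = 2.0,  # assumed value; higher T gives softer targets
    alpha: float = 0.5,        # assumed weight between soft and hard losses
) -> float:
    """Blend of soft-target KL divergence (teacher -> student) and hard-label cross-entropy."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student) on the softened distributions, scaled by T^2 so the
    # soft term keeps a magnitude comparable to the hard-label term.
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    soft_loss = (temperature ** 2) * kl
    # Ordinary cross-entropy against the true label, at temperature 1.
    hard_loss = -math.log(softmax(student_logits)[label])
    return alpha * soft_loss + (1 - alpha) * hard_loss
```

In practice this would be computed on batched framework tensors during training; DistilBERT's own pre-training combined a distillation term like this with additional losses.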

Optimization 2: Model Quantization

Quantization is a technique that speeds up inference by converting 32-bit floating-point numbers (FP32) to lower-bit integers (int8). This enables the high-performance vectorized operations available on many CPUs. A DistilBERT model quantized using ONNX (Open Neural Network Exchange) Runtime requires less than 100MB of memory and takes less than 50ms per inference. Intel, OpenVINO, and Hugging Face also provide hardware-specific quantization libraries similar to ONNX Runtime. Because the benefit of a quantization scheme depends on the hardware it runs on, it is important to apply a compatible quantization on the right hardware. Next, we look at the hardware we chose and the hardware-specific optimizations we applied to achieve high throughput and low latency on CPU instances.
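ONNX Runtime exposes this as dynamic quantization via `onnxruntime.quantization.quantize_dynamic` (for example with `weight_type=QuantType.QInt8`). The sketch below does not call that library; it only illustrates the underlying FP32-to-int8 arithmetic with a simplified symmetric, per-tensor scheme. ONNX Runtime's actual scale and zero-point choices may differ.

```python
from typing import List, Sequence, Tuple


def quantize_int8(values: Sequence[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization: map [-max|x|, max|x|] onto [-127, 127]."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0  # avoid dividing by zero for all-zero tensors
    # Round to the nearest integer code and clamp to the int8 range.
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize_int8(q: Sequence[int], scale: float) -> List[float]:
    """Recover approximate float values from int8 codes."""
    return [v * scale for v in q]


if __name__ == "__main__":
    weights = [0.5, -1.27, 0.0, 1.27, 0.003]
    codes, scale = quantize_int8(weights)
    print(codes, scale)
    print(dequantize_int8(codes, scale))
```

Storing each weight in 1 byte instead of 4 is where the roughly 4x memory reduction comes from; the speedup comes from running matrix multiplications in int8 with vectorized CPU instructions, which this sketch does not attempt to show. The rounding error per value is at most half of `scale`.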

Step 2: Selecting the Right Hardware

On a GPU instance, even BERT-base takes less than 10ms per inference, but it cannot be sped up further with these quantization techniques. By contrast, with quantization and the other optimizations described here, we achieved similar inference times on CPU instances that cost a quarter as much as the GPU instances.
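The cost argument can be framed as cost per million messages. This sketch assumes, as the paragraph above implies, that CPU and GPU instances sustain similar per-message throughput for this real-time workload; the prices in the usage block are relative units, not actual AWS pricing.

```python
def cost_per_million_messages(hourly_price: float, messages_per_second: float) -> float:
    """Instance cost (in the price's unit) to process one million messages at a sustained rate."""
    messages_per_hour = messages_per_second * 3600
    return hourly_price / messages_per_hour * 1_000_000


if __name__ == "__main__":
    # Relative prices only: the GPU instance is assumed to cost 4x the CPU instance,
    # and both are assumed to handle the same number of messages per second.
    cpu = cost_per_million_messages(hourly_price=1.0, messages_per_second=100)
    gpu = cost_per_million_messages(hourly_price=4.0, messages_per_second=100)
    print(f"GPU/CPU cost ratio: {gpu / cpu:.1f}")
```

With equal throughput, the cost ratio simply equals the price ratio. A GPU could narrow the gap by batching many messages to raise throughput, which is why the comparison depends on the one-message-at-a-time, real-time setting.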

Optimization 3: Hardware-Specific Optimizations

Figure: Hourly cost of AWS instances for varying instance types
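Quantized inference benefits from vectorized CPU instructions, so it can help to confirm what an instance actually supports. On Linux, one quick check is to read the `flags` line of `/proc/cpuinfo`. The flag list below (AVX, AVX2, AVX-512F, AVX-512 VNNI) is our assumption of what is relevant to int8 inference on x86, not a list taken from Stream's setup.

```python
from typing import Dict, Iterable

# x86 SIMD flags as they appear in Linux /proc/cpuinfo (assumed set of interest).
FLAGS_OF_INTEREST = ("avx", "avx2", "avx512f", "avx512_vnni")


def simd_support(cpuinfo_text: str, flags: Iterable[str] = FLAGS_OF_INTEREST) -> Dict[str, bool]:
    """Report which flags appear on the first 'flags' line of /proc/cpuinfo output."""
    for line in cpuinfo_text.splitlines():
        if line.startswith("flags"):
            present = set(line.split(":", 1)[1].split())
            return {f: f in present for f in flags}
    # No 'flags' line (e.g., non-x86 CPUs report 'Features' instead).
    return {f: False for f in flags}


if __name__ == "__main__":
    try:
        with open("/proc/cpuinfo") as fh:  # Linux only
            print(simd_support(fh.read()))
    except OSError:
        print("/proc/cpuinfo is not available on this platform")
```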

Optimization 4: Optimizing PyTorch Threads

Figure: Inference time with a varying number of threads

We experimented with varying the number of threads available to a single process on a single instance. The best results were achieved with four threads (torch.set_num_threads(4)); increasing the thread count further gave no significant speedup. With this configuration, we built a horizontally scalable infrastructure that handles over 100 messages per second. Note that the number of instances required also depends on the average message length, since inference time grows with the length of the input message. Next, we look at the optimizations made during inference that make real-time predictions possible for messages of varying length.
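The link between latency, message length, and instance count can be sketched as simple capacity arithmetic. This assumes each process handles one message at a time (with `torch.set_num_threads(4)` parallelizing work inside each inference), so per-instance throughput is roughly the inverse of per-message latency. The `utilization` headroom factor and the example latencies for longer messages are hypothetical.

```python
import math


def per_instance_throughput(latency_ms: float) -> float:
    """Messages per second for one process that handles one message at a time."""
    return 1000.0 / latency_ms


def instances_needed(target_msgs_per_sec: float, latency_ms: float, utilization: float = 0.7) -> int:
    """Instances needed to sustain a target rate while keeping each at the given utilization."""
    usable = per_instance_throughput(latency_ms) * utilization
    return math.ceil(target_msgs_per_sec / usable)


if __name__ == "__main__":
    # Hypothetical latencies: longer messages -> higher per-message latency.
    for latency in (9, 20, 40):
        print(
            f"{latency} ms: {per_instance_throughput(latency):.0f} msg/s per instance, "
            f"{instances_needed(1000, latency)} instances for 1000 msg/s"
        )
```

At 9ms per message, a single process tops out around 111 messages per second, which is why halving latency roughly halves the number of instances needed for the same load.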

Step 3: Real-time Optimizations

At Stream, we focus on reducing harmful content by “nudging” users in real time rather than asynchronously flagging or blocking them. Proactively bouncing potentially harmful content prevents other users from ever seeing it, and nudging helps promote positive behavior on the platform. Consistently low latency is the key requirement for nudging a user before the message reaches anyone else.

Figure: Impact on inference time when padding is skipped

Optimization 5: Skip Padding

Deep learning models are designed to process inputs in batches, which requires every input message in a batch to have the same length. Shorter messages are padded with zeros (the padding token), while longer messages are truncated. For real-time inference, we have just a single input message, i.e., the batch size is one, so we can skip padding and run inference faster on shorter messages. We still truncate longer messages to keep inference time in check. Since most chat messages are short, skipping padding allows us to achieve high throughput.
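The difference between batched and single-message preparation is easy to see on token IDs. With Hugging Face tokenizers, the single-message path corresponds to passing `truncation=True` and a `max_length` while leaving `padding` at its default (no padding). The sketch below works on plain lists; `PAD_ID = 0` follows BERT's vocabulary, and `max_len` values are hypothetical.

```python
from typing import List, Sequence

PAD_ID = 0  # BERT-style [PAD] token ID (assumed)


def pad_batch(batch: Sequence[Sequence[int]], max_len: int) -> List[List[int]]:
    """Batched inference: truncate each sequence to max_len, then pad all to the same length."""
    out = []
    for ids in batch:
        ids = list(ids[:max_len])
        out.append(ids + [PAD_ID] * (max_len - len(ids)))
    return out


def prepare_single(ids: Sequence[int], max_len: int) -> List[int]:
    """Real-time inference with batch size 1: truncate long inputs but never pad."""
    return list(ids[:max_len])


if __name__ == "__main__":
    short_message = [101, 7592, 102]  # a 3-token message
    print(len(pad_batch([short_message], 128)[0]))  # padded: 128 tokens to process
    print(len(prepare_single(short_message, 128)))  # unpadded: 3 tokens to process
```

Since self-attention cost grows with sequence length, processing 3 tokens instead of 128 is where the savings for short chat messages come from.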

Production Results

Figure: Varying messages per second with two instances

To summarize, we trained a DistilBERT model and quantized it using ONNX Runtime. The optimized models were deployed on CPU instances supporting AVX extensions, with a single process using four parallel threads on each instance. Together with skipping the padding step, these optimizations achieved a latency of less than 10ms and supported a throughput of over 100 messages per second.
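Latency results like these are usually summarized as percentiles over many requests (p50 for the typical case, p99 for the tail) rather than as a single number. Here is a minimal nearest-rank percentile sketch; the sample latencies in the usage lines are made up for illustration.

```python
import math
from typing import Sequence


def percentile(samples_ms: Sequence[float], p: float) -> float:
    """Nearest-rank percentile: smallest sample with at least p% of samples at or below it."""
    if not samples_ms:
        raise ValueError("no samples")
    ordered = sorted(samples_ms)
    rank = max(1, math.ceil(p * len(ordered) / 100))
    return ordered[rank - 1]


if __name__ == "__main__":
    latencies = [8.1, 9.0, 8.7, 9.3, 12.5, 8.9, 41.0, 9.1]  # hypothetical measurements (ms)
    print("p50:", percentile(latencies, 50), "p99:", percentile(latencies, 99))
```

Tracking p99 alongside p50 matters for nudging, because a slow tail means some users see their message delayed even when the median looks fast.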

Conclusions and Future Work

We consistently achieved a p50 latency of 9ms and a p99 latency of less than 50ms after applying the optimizations described above. In the future, we would also like to use pruning to further reduce the model's footprint: we have observed a 2ms improvement after pruning 10% of the attention heads, without significantly impacting the model's predictive performance. We would also like to experiment with other hardware optimizations, such as Intel Neural Compressor (INC) and the AWS Neuron SDK.
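Head pruning needs a rule for which heads to drop. One simple option is to rank all heads by an importance score and remove the lowest-scoring fraction; the sketch below does that and returns a `{layer: [head indices]}` mapping, the shape accepted by Hugging Face Transformers' `prune_heads` method. How the importance scores are computed (for example, by ablation or gradient-based estimates) is left as an assumption; the source does not say how its heads were chosen.

```python
from typing import Dict, List, Sequence


def heads_to_prune(importance: Sequence[Sequence[float]], fraction: float = 0.10) -> Dict[int, List[int]]:
    """Pick the globally least-important attention heads; returns {layer: [head indices]}."""
    # Flatten to (score, layer, head) so heads can be ranked across all layers.
    scored = [
        (score, layer, head)
        for layer, layer_scores in enumerate(importance)
        for head, score in enumerate(layer_scores)
    ]
    n_prune = int(len(scored) * fraction)
    plan: Dict[int, List[int]] = {}
    for _, layer, head in sorted(scored)[:n_prune]:
        plan.setdefault(layer, []).append(head)
    return {layer: sorted(heads) for layer, heads in plan.items()}


if __name__ == "__main__":
    # DistilBERT has 6 layers x 12 heads; these scores are placeholders, not real measurements.
    fake_scores = [[(layer * 12 + head) / 72 for head in range(12)] for layer in range(6)]
    print(heads_to_prune(fake_scores, fraction=0.10))
```

With 72 heads, a 10% fraction prunes 7 of them; after pruning, latency and accuracy should be re-measured, since removing many heads from one layer can hurt more than spreading removals across layers.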

Try Stream Auto-Moderation today and gain a sidekick for your trust and safety teams!