Tim Cvetko

Summary

The author tested Apple's new MLX framework against Torch on M2 Air, comparing training, inference, and CPU usage for the BERT model, with MLX showing better performance.

Abstract

On Tuesday, Apple's AI team released MLX, a new machine learning framework designed for Apple Silicon Chips. The author, who owns a Macbook M2 Air and regularly trains ML models locally, decided to test MLX against PyTorch using the standard BERT transformers model. The results showed that MLX outperformed PyTorch in terms of training time, inference time, and CPU utilization during training. The author also provided a quick guide into MLX, highlighting its Torch-like syntax and higher-level packages like mlx.nn and mlx.optimizers. The MLX examples repo contains examples of transformer language model training, large-scale text generation with LLaMA and finetuning with LoRA, generating images with Stable Diffusion, and speech recognition with OpenAI’s Whisper.


I Tested Apple’s New MLX Framework Against Torch on M2 Air

MLX vs Torch on BERT — Training, Inference, and CPU Usage Comparison

On Tuesday, Apple’s AI team released “MLX”, a new machine learning framework designed specifically for Apple silicon chips. The design of MLX was inspired by frameworks like NumPy, PyTorch, JAX, and ArrayFire.


Is MLX really faster than Torch on Mac?

As I own a Macbook M2 Air and regularly train ML models locally, I decided to put this hypothesis to the test by training the standard BERT transformers model on both MLX and PyTorch. The results are staggering!

Who should read this?

Who is this blog post useful for? Mac (M1, M2, M3) owners who are looking for a faster ML framework for training and inference.

How advanced is this post? Anybody previously acquainted with ML terms should be able to follow along.

Replicate my code here: https://github.com/Timothy102/mlx-bert

Quick Guide into MLX

MLX is an array framework for machine learning on Apple silicon, brought to you by Apple machine learning research.

  • MLX is very Torch-like in its syntax. It has higher-level packages like mlx.nn and mlx.optimizers with APIs that closely follow PyTorch to simplify building more complex models (a minimal sketch follows below).
  • MLX has a Python API that closely follows NumPy, as well as a fully featured C++ API that closely mirrors the Python one.

Install it with:

pip install mlx
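To give a feel for the Torch-like syntax, here is a minimal, illustrative sketch of defining and training a tiny model with mlx.nn and mlx.optimizers. It is not taken from my BERT code; the layer sizes and hyperparameters are arbitrary.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class TinyMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def __call__(self, x):
        return self.fc2(nn.relu(self.fc1(x)))

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y).mean()

model = TinyMLP(16, 32, 2)
optimizer = optim.Adam(learning_rate=1e-3)
loss_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((8, 16))         # synthetic batch of 8 examples
y = mx.random.randint(0, 2, (8,))     # synthetic labels
loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)  # MLX is lazy; this forces the computation

If you have written PyTorch modules before, the structure should feel familiar; the main difference is that MLX evaluates lazily, so mx.eval is what actually runs the computation.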

The MLX examples repo has a variety of examples, including:

  • Transformer language model training.
  • Large-scale text generation with LLaMA and finetuning with LoRA.
  • Generating images with Stable Diffusion.
  • Speech recognition with OpenAI’s Whisper.

Here’s THE Thing

The goal of this experiment was to compare MLX against PyTorch on the standard ~400 MB BERT model from Hugging Face. Here’s what I wanted to test:

  • Training time [sec]
  • Inference time [sec]
  • CPU Utilization during training [%]

Model Setup

The MLX model was built separately to account for the syntax differences, but it was initialized with the same model weights as the Torch model.

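To sketch what that weight hand-off can look like, the snippet below (illustrative, not my exact code) exports the Hugging Face weights to an .npz file that an MLX module can load. The parameter-name remapping and the MLX-side model definition are omitted, and bert_weights.npz is just a placeholder path.

import numpy
from transformers import BertModel

# Export the pre-trained PyTorch weights as plain NumPy arrays.
torch_model = BertModel.from_pretrained("bert-base-uncased")
tensors = {name: tensor.detach().numpy() for name, tensor in torch_model.state_dict().items()}
numpy.savez("bert_weights.npz", **tensors)

# On the MLX side, assuming mlx_model is an mlx.nn.Module whose parameter
# names match the keys saved above:
# mlx_model.load_weights("bert_weights.npz")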

Training Comparison

For MLX, the training process involved loading a pre-trained BERT model, updating its weights with weights converted from the PyTorch model, and evaluating the performance using synthetic data.


The training time was measured across varying data sizes, specifically different batch sizes, to showcase how MLX handles the training workload.

Training time per batch size per framework, in seconds
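The measurement itself is plain wall-clock timing around each training step. A rough sketch of the loop is below; train_step is a placeholder for one forward/backward/update pass in either framework, and the batch sizes are illustrative.

import time

batch_sizes = [8, 16, 32, 64]
timings = {}
for batch_size in batch_sizes:
    start = time.perf_counter()
    train_step(batch_size)  # placeholder: one training step at this batch size
    # On the MLX side, call mx.eval(...) on the updated parameters before
    # stopping the timer, since MLX builds its compute graph lazily.
    timings[batch_size] = time.perf_counter() - start

print(timings)  # seconds per training step, keyed by batch size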

Inference

For MLX, the inference process involved loading a pre-trained BERT model, specifically the MLXBertModel, and providing it with synthetic input data. The input data consisted of randomly generated sequences, including input_ids, token_type_ids, and attention_mask, mimicking the structure of typical BERT inputs. These synthetic inputs were then passed through the MLXBertModel to measure the time taken for inference, capturing the efficiency of the MLX framework in processing BERT-based tasks.
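Roughly, generating those synthetic inputs and timing one forward pass in MLX looks like the sketch below. MLXBertModel is the MLX-side model from the repo; its construction and exact call signature are not reproduced here, and the batch size, sequence length, and vocabulary size are arbitrary.

import time
import mlx.core as mx

batch_size, seq_len, vocab_size = 8, 128, 30522  # illustrative sizes

# Synthetic BERT-style inputs
input_ids = mx.random.randint(0, vocab_size, (batch_size, seq_len))
token_type_ids = mx.zeros((batch_size, seq_len), dtype=mx.int32)
attention_mask = mx.ones((batch_size, seq_len), dtype=mx.int32)

start = time.perf_counter()
outputs = mlx_model(input_ids, token_type_ids, attention_mask)  # an MLXBertModel instance
mx.eval(outputs)  # MLX is lazy; evaluate before stopping the clock
print(f"MLX inference time: {time.perf_counter() - start:.4f} s")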


On the PyTorch side, the inference setup mirrored that of MLX. A pre-trained BERT model, sourced from the Hugging Face Transformers library, was loaded into a PyTorch environment. Similarly, synthetic input data, comprising input_ids, token_type_ids, and attention_mask tensors, was generated randomly. The PyTorch BERT model processed this synthetic input data, and the inference time was recorded.
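For completeness, the PyTorch side of that measurement can be sketched with the standard Hugging Face API (sizes arbitrary, matching the MLX sketch above):

import time
import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

batch_size, seq_len, vocab_size = 8, 128, 30522
inputs = {
    "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)),
    "token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.long),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
}

with torch.no_grad():
    start = time.perf_counter()
    outputs = model(**inputs)
    print(f"PyTorch inference time: {time.perf_counter() - start:.4f} s")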

Inference Time Comparison; Data size in MBs

CPU Utilization

During both MLX and PyTorch BERT model training, I captured CPU utilization with the psutil library.
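For reference, psutil reports CPU utilization as a percentage of time spent busy since the previous call, so the sampling can be as simple as the sketch below (train_step is again a placeholder for one training iteration):

import psutil

psutil.cpu_percent(interval=None)              # first call primes the counter
train_step()                                   # placeholder: one training iteration
cpu_usage = psutil.cpu_percent(interval=None)  # % CPU used since the previous call
print(f"CPU utilization: {cpu_usage:.1f}%")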

CPU usage per batch size per framework (%)

Conclusion

To wrap things up, here are the final results of the comparison experiment.

On the Apple M2 chip, the mlx library outperformed torch by an 8% margin in CPU utilization during training.

Final Comparison Table

These results from Apple’s research team definitely seem promising.

Hey, thanks for reading!

Thanks for getting to the end of this article. My name is Tim. I love breaking down ML research papers and ML applications, with an emphasis on business use cases.
