MLX Quickstart
These are my notes on the MLX quick start guide and usage notes. It's a work in progress. Ultimately, I'm interested in learning what MLX will let me do with LLMs on my laptop. I might write something more substantial on that topic in the future. For now, you're probably better off consulting the docs yourself than looking at my notes on them.
What is MLX?
MLX is an array framework from Apple ML Research. Its API follows NumPy. It has higher-level packages that follow PyTorch's API for building more complex models modularly. It features:
- composable function transformations
- lazy computation
- dynamic computational graph construction
- unified memory model. Arrays live in shared memory, and operations on MLX arrays can be performed on CPU or GPU without the need to transfer data.
Quickstart Guide
First, install with pip install mlx or conda install -c conda-forge mlx.
Next, we'll work through the basic features, as shown in the quickstart guide linked above.
Basics
import mlx.core as mx

a = mx.array([1, 2, 3, 4])
a.dtype
mlx.core.int32
Operations are lazy.
b = mx.array([1.0, 2.0, 3.0, 4.0])
c = a + b
c is not computed until it is explicitly used or until we call eval. See Lazy Evaluation in More Detail below.
mx.eval(c)
c
array([2, 4, 6, 8], dtype=float32)
MLX has grad, like PyTorch.
x = mx.array(0.0)
mx.sin(x)
mx.grad(mx.sin)(x)
array(1, dtype=float32)
Lazy Evaluation in More Detail
When you perform an operation:
- No computation happens
- A compute graph is recorded
- Computation happens once an eval() is performed.
PyTorch uses eager evaluation, while TensorFlow and JAX use lazy evaluation, each with its own approach to when evaluation happens. The key difference: TF/JAX graphs are compiled, while MLX graphs are built dynamically.
One LLM-relevant use case: initializing model weights. You might initialize a model with model = Model(). The actual weight loading won't happen until you perform an eval(). This is useful if you e.g. subsequently update the model with float16 weights: you don't take the memory hit you'd get with eager execution, which would load the float32 weights first.
It enables this pattern:
model = Model()  # no memory used yet
model.load_weights("weights_fp16.safetensors")
When to evaluate
It's a tradeoff between:
- letting graphs get too large
- not batching enough to do useful work
There's a lot of flexibility.
Luckily, a wide range of compute graph sizes work pretty well with MLX: anything from a few tens of operations to many thousands of operations per evaluation should be okay.
Example of a good pattern for a training loop:
for batch in dataset:
    # Nothing has been evaluated yet
    loss, grad = value_and_grad_fn(model, batch)

    # Still nothing has been evaluated
    optimizer.update(model, grad)

    # Evaluate the loss and the new parameters which will
    # run the full gradient computation and optimizer update
    mx.eval(loss, model.parameters())
Note: whenever you print an array or convert it to a numpy array, it is evaluated. Saving arrays will also evaluate them.
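For instance, each of these lines forces an evaluation of c (my own illustration; I'm assuming mx.save and the NumPy conversion behave as the note describes):

import numpy as np

print(c)             # printing forces an evaluation
np.array(c)          # so does converting to a NumPy array
mx.save("c.npy", c)  # and so does saving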
Using arrays for control flow will trigger an eval.
def fun(x):
    h, y = first_layer(x)
    if y > 0:  # An evaluation is done here!
        z = second_layer_a(h)
    else:
        z = second_layer_b(h)
    return z
Unified Memory
You do not need to specify the location of an MLX array in memory. CPU and GPU share memory.
Instead of moving arrays to devices, you specify the device when you run an operation.
a = mx.random.normal((100,))
b = mx.random.normal((100,))
mx.add(a, b, stream=mx.cpu)
array([-0.999945, -0.255963, 1.04271, ..., 1.08311, -0.993303, -1.48334], dtype=float32)
mx.add(a, b, stream=mx.gpu)
array([-0.999945, -0.255963, 1.04271, ..., 1.08311, -0.993303, -1.48334], dtype=float32)
The MLX scheduler will manage dependencies to avoid race conditions. In other words, this is fine.
c = mx.add(a, b, stream=mx.cpu)
d = mx.add(a, c, stream=mx.gpu)
This can be useful if we e.g. send compute-dense operations to the GPU and smaller, overhead-bound operations to the CPU; a sketch of that pattern follows.
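A sketch of that pattern (the function name, shapes, and specific ops here are my own choices, loosely modeled on what the docs describe):

def split_work(big, small, d1, d2):
    # One compute-dense matmul: well suited to d1 (e.g. the GPU)
    x = mx.matmul(big, small, stream=d1)
    # Many small, overhead-bound ops: better on d2 (e.g. the CPU)
    for _ in range(250):
        small = mx.exp(small, stream=d2)
    return x, small

big = mx.random.uniform(shape=(4096, 512))
small = mx.random.uniform(shape=(512, 4))
x, small = split_work(big, small, mx.gpu, mx.cpu)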
Indexing arrays
- Is the same as NumPy in most cases
- EXCEPT:
  - It does not perform bounds checking. Indexing out of bounds is undefined behavior. Why? Exceptions can't propagate from the GPU.
  - Boolean mask indexing is not supported (yet).
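A quick sketch of the NumPy-style cases that do work, plus the unsupported one (my own example; the integer-array case reflects my understanding of the API):

v = mx.arange(10)
v[2:5]               # slicing works as in NumPy
v[mx.array([1, 3])]  # so does indexing with an integer array
# v[v > 5]           # boolean mask indexing is not supported (yet)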
Saving and Loading
Support for NumPy (.npy), NumPy archive (.npz), safetensors, and GGUF.
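A sketch of what that looks like (the exact function names here are my recollection of the API, so treat them as assumptions):

arr = mx.array([1.0, 2.0, 3.0])
mx.save("arr.npy", arr)                                  # NumPy .npy
mx.savez("arrays.npz", arr=arr)                          # NumPy archive .npz
mx.save_safetensors("arrays.safetensors", {"arr": arr})  # safetensors
loaded = mx.load("arr.npy")                              # format inferred from the extension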
Function transforms
MLX uses composable function transformations for autodiff, vectorization, graph optimization. Main idea: every transformation returns a function that can be further transformed. Here is an example.
dfdx = mx.grad(mx.sin)
dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
The output of grad on sin is another function: the gradient of the sine function. To get the second derivative, just do mx.grad(mx.grad(mx.sin)). You can compose any function transform in any order to any depth.
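For example, composing grad twice gives the second derivative (my own continuation of the example above):

d2fdx2 = mx.grad(mx.grad(mx.sin))
# The second derivative of sin is -sin, so at pi/2 this is -1
d2fdx2(mx.array(mx.pi / 2))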
Automatic Differentiation
Autodiff works on functions, not on implicit graphs; this is a key difference from PyTorch, where autodiff works on an implicit graph built up as operations run on tensors.
By default, the gradient is computed with respect to the first argument, but we can specify a different argument with argnums.
def loss_fn(w, x, y):
    return mx.mean(mx.square(w * x - y))

w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])

# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
dloss_dw
array(-1, dtype=float32)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
dloss_dx
array([-1, 1], dtype=float32)
The value_and_grad function provides an efficient way to get both the value and the gradient, e.g. of the loss.
# Computes the value and gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)
# Prints array(1, dtype=float32)
print(loss)
# Prints array(-1, dtype=float32)
print(dloss_dw)
array(1, dtype=float32)
array(-1, dtype=float32)
You can use stop_gradient() to stop gradients from propagating through a part of the function.
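A minimal sketch of the idea (my own toy example, not from the guide):

def partly_stopped(x):
    # The second term is treated as a constant when differentiating,
    # so only the first term contributes to the gradient.
    return x * x + mx.stop_gradient(x * x)

mx.grad(partly_stopped)(mx.array(3.0))  # array(6, dtype=float32), not 12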
Automatic Vectorization
vmap() automatically vectorizes complex functions.
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
in_axes specifies which dimensions of the inputs to vectorize over. out_axes specifies where they should be in the output.
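Continuing that example with concrete shapes (my own sketch, assuming the default out_axes=0):

x = mx.random.normal((2, 8))  # mapped over axis 1 (size 8)
y = mx.random.normal((8, 2))  # mapped over axis 0 (size 8)
# Each mapped call adds a (2,)-slice of x to a (2,)-slice of y; with the
# default out_axes=0 the mapped axis comes first, so the result is (8, 2)
vmap_add(x, y).shape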
Compilation
MLX has a compile function for compiling computational graphs. What does compilation mean in this context? Compilation makes smaller graphs by merging common work and fusing common operations.
The first time you call a compiled function, MLX builds and optimizes the compute graph and generates and compiles the code. This can be slow, but the resulting compiled function is cached, so subsequent calls do not initiate a new compilation.
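A minimal sketch of the workflow (the toy function is mine, not from the notes):

def toy(x, y):
    return mx.exp(-x) + y

compiled_toy = mx.compile(toy)

x, y = mx.array(1.0), mx.array(2.0)
# First call: the graph is traced, optimized, and compiled (slower)
compiled_toy(x, y)
# Later calls with the same shapes and dtypes reuse the cached compilation
compiled_toy(x, y)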
What causes a function to be recompiled?
- changing shape or number of dimensions
- changing the type of any inputs
- changing the number of inputs
Don't compile functions that are created and destroyed frequently.
Debugging can be tricky. When a compiled function is first called, it is traced with placeholder inputs, so it will crash if there's a print statement. For debugging purposes, disable compilation with disable_compile or by setting the MLX_DISABLE_COMPILE flag.
Compiled functions should be pure: they should not have side effects. See the compile section of the docs if your function updates some saved state.
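A toy illustration of what "not pure" means here (my own example; the docs' compile section covers the supported way to handle captured state):

state = []

def impure(x):
    y = mx.exp(x)
    state.append(y)  # side effect: mutates state captured from outside
    return y

# Compiling impure as-is won't behave the way you might expect;
# see the compile docs if you need this pattern.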
Streams
All operations take an optional stream keyword specifying which Stream the operation should run on. In practice, this is how you specify which device an operation runs on.
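For example, reusing a and b from the Unified Memory section (a sketch based on my reading of the docs: I'm assuming mx.new_stream exists and that the stream argument also accepts a plain device):

s = mx.new_stream(mx.gpu)    # a new stream on the GPU
mx.add(a, b, stream=s)       # run the op on that specific stream
mx.add(a, b, stream=mx.gpu)  # passing a device uses its default stream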