Daniel Liden


MLX Quickstart

These are my notes on the MLX quick start guide and usage notes. It's a work in progress. Ultimately, I'm interested in learning what MLX will let me do with LLMs on my laptop. I might write something more substantial on that topic in the future. For now, you're probably better off consulting the docs yourself than looking at my notes on them.

What is MLX?

MLX is an array framework from Apple ML Research. Its array API closely follows NumPy's, and it has higher-level packages that follow PyTorch's API for building more complex models modularly. It features:

  • composable function transformations
  • lazy computation
  • dynamic computational graph construction
  • unified memory model: arrays live in shared memory, and operations on MLX arrays can be performed on the CPU or GPU without transferring data

Quickstart Guide

First, install with pip install mlx or conda install -c conda-forge mlx.

Next, we'll work through the basic features covered in the quickstart guide.

Basics

import mlx.core as mx

a = mx.array([1,2,3,4])
a.dtype
mlx.core.int32

Operations are lazy.

b = mx.array([1.0, 2.0, 3.0, 4.0])
c = a + b

c is not computed until its value is actually needed or until we call eval explicitly. See Lazy Evaluation in More Detail.

mx.eval(c)
c
array([2, 4, 6, 8], dtype=float32)

MLX has a grad transformation. It is a composable function transformation, closer to JAX's grad than to PyTorch's tape-based autograd: it takes a function and returns a new function that computes the gradient.

x = mx.array(0.0)
mx.sin(x)
array(0, dtype=float32)

mx.grad(mx.sin)(x)
array(1, dtype=float32)

Lazy Evaluation in More Detail

When you perform an operation:

  • No computation happens
  • A compute graph is recorded
  • Computation happens once an eval() is performed.

PyTorch uses eager evaluation. TensorFlow (in graph mode) and JAX (under jit) also defer evaluation, but they trace and compile static graphs; MLX's graphs are built dynamically at runtime, and compiling them is optional (see Compilation below).

One LLM-relevant use case: initializing model weights. You might initialize a model with model = Model(), but the weights aren't actually materialized until you perform an eval(). This is useful if you subsequently replace them, e.g. by loading float16 weights from disk: you never take the memory hit of materializing the default float32 weights, as you would with eager execution.

It enables this pattern:

model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors")
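
A more complete sketch of the same pattern, assuming a hypothetical Model built with mlx.nn (the class, layer sizes, and weight file name are made up for illustration):

import mlx.core as mx
import mlx.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Lazy: the float32 random init is recorded but never materialized
        self.linear = nn.Linear(4096, 4096)

    def __call__(self, x):
        return self.linear(x)

model = Model()                                 # no weight memory used yet
model.load_weights("weights_fp16.safetensors")  # replaces params with fp16 weights
mx.eval(model.parameters())                     # weights materialize here, in fp16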

When to evaluate

It's a tradeoff between:

  • letting graphs get too large
  • not batching enough to do useful work

There's a lot of flexibility.

Luckily, a wide range of compute graph sizes work pretty well with MLX: anything from a few tens of operations to many thousands of operations per evaluation should be okay.

Example of a good pattern for a training loop:

for batch in dataset:

    # Nothing has been evaluated yet
    loss, grad = value_and_grad_fn(model, batch)

    # Still nothing has been evaluated
    optimizer.update(model, grad)

    # Evaluate the loss and the new parameters which will
    # run the full gradient computation and optimizer update
    mx.eval(loss, model.parameters())

Note: whenever you print an array or convert it to a numpy array, it is evaluated. Saving arrays will also evaluate them.
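
A quick sketch of these implicit evaluation points (the array and file name are just examples):

import mlx.core as mx
import numpy as np

a = mx.random.normal((4,))
b = a * 2            # recorded in the graph, not computed

print(b)             # printing forces an eval
np_b = np.array(b)   # converting to NumPy forces an eval
mx.save("b.npy", b)  # saving forces an eval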

Using arrays for control flow will trigger an eval.

def fun(x):
    h, y = first_layer(x)
    if y > 0:  # An evaluation is done here!
        z  = second_layer_a(h)
    else:
        z  = second_layer_b(h)
    return z

Unified Memory

You do not need to specify the location of an MLX array in memory. CPU and GPU share memory.

Instead of moving arrays to devices, you specify the device when you run an operation.

a = mx.random.normal((100,))
b = mx.random.normal((100,))

mx.add(a, b, stream=mx.cpu)
array([-0.999945, -0.255963, 1.04271, ..., 1.08311, -0.993303, -1.48334], dtype=float32)

mx.add(a, b, stream=mx.gpu)
array([-0.999945, -0.255963, 1.04271, ..., 1.08311, -0.993303, -1.48334], dtype=float32)

The MLX scheduler will manage dependencies to avoid race conditions. In other words, this is fine.

c = mx.add(a, b, stream=mx.cpu)
d = mx.add(a, c, stream=mx.gpu)

This can be useful if we, e.g., send compute-dense operations to the GPU and smaller, overhead-bound operations to the CPU, as in the sketch below.
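
A minimal sketch of that pattern, loosely based on the example in the unified memory docs (the shapes and iteration count are arbitrary):

import mlx.core as mx

def fun(a, b, d1, d2):
    # Compute-dense: one large matmul goes to the GPU
    x = mx.matmul(a, b, stream=d1)
    # Overhead-bound: many tiny ops go to the CPU
    for _ in range(500):
        b = mx.exp(b, stream=d2)
    return x, b

a = mx.random.uniform(shape=(4096, 512))
b = mx.random.uniform(shape=(512, 4))
x, b = fun(a, b, mx.gpu, mx.cpu)
mx.eval(x, b)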

Indexing arrays

  • Is the same as NumPy in most cases (see the sketch after this list)
  • EXCEPT:
    • It does not perform bounds checking. Indexing out of bounds is undefined behavior. Why? Exceptions can't propagate from the GPU.
    • Boolean mask indexing is not supported (yet).
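
A few examples of NumPy-style indexing that do work in MLX (the expected results in the comments are mine, not copied from a session):

import mlx.core as mx

arr = mx.arange(10)

arr[3]       # single element: array(3, dtype=int32)
arr[-2]      # negative indexing: array(8, dtype=int32)
arr[2:8:2]   # slicing with a step: array([2, 4, 6], dtype=int32)
arr[None]    # add a new axis: shape (1, 10)

# In place of boolean masks, functions like mx.where can often do the job:
mx.where(arr > 5, arr, 0)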

Saving and Loading

MLX supports saving and loading arrays as NumPy (.npy), NumPy archive (.npz), safetensors, and GGUF files.
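
A quick sketch of the corresponding save/load functions (file names are arbitrary):

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])

mx.save("a.npy", a)                   # single array -> .npy
mx.savez("arrays.npz", a=a, b=b)      # several arrays -> .npz archive
mx.save_safetensors("arrays.safetensors", {"a": a, "b": b})
mx.save_gguf("arrays.gguf", {"a": a, "b": b})

a2 = mx.load("a.npy")
arrs = mx.load("arrays.safetensors")  # returns a dict of arrays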

Function transforms

MLX uses composable function transformations for autodiff, vectorization, graph optimization. Main idea: every transformation returns a function that can be further transformed. Here is an example.

dfdx = mx.grad(mx.sin)
dfdx(mx.array(mx.pi))

array(-1, dtype=float32)

The output of grad on sin is another function: the gradient of the sine function. To get the second derivative, just do mx.grad(mx.grad(mx.sin)). You can compose any function transform in any order to any depth.
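
For example, composing grad twice gives the second derivative of sine, which is -sin:

import mlx.core as mx

d2fdx2 = mx.grad(mx.grad(mx.sin))
d2fdx2(mx.array(mx.pi / 2))   # expect array(-1, dtype=float32), since -sin(pi/2) = -1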

Automatic Differentiation

Autodiff in MLX works on functions rather than on implicit graphs. This is a key difference from PyTorch, where autograd records an implicit graph as operations run on tensors.

By default, the gradient is computed with respect to the first argument, but we can specify a different argument with argnums.

def loss_fn(w, x, y):
    return mx.mean(mx.square(w * x - y))

w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])

# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
dloss_dw
array(-1, dtype=float32)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
dloss_dx
array([-1, 1], dtype=float32)

The value_and_grad function provides an efficient way to get both the value and the gradient of a function, e.g. the loss and its gradient, in a single call.

# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)

# Prints array(1, dtype=float32)
print(loss)

# Prints array(-1, dtype=float32)
print(dloss_dw)
array(1, dtype=float32)
array(-1, dtype=float32)

You can use stop_gradient() to stop gradients from propagating through a part of the function.
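
A small sketch of stop_gradient: gradients don't flow through the stopped term (the function here is made up for illustration):

import mlx.core as mx

def f(x):
    # The x**2 term is treated as a constant for differentiation
    return 3 * x + mx.stop_gradient(x ** 2)

mx.grad(f)(mx.array(2.0))   # expect array(3, dtype=float32), not 3 + 2x = 7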

Automatic Vectorization

vmap() automatically vectorizes complex functions.

# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))

in_axes specifies which dimensions of the input to vectorize over. out_axes specifies where they should be in the output.
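
A sketch of calling the vmapped function above with shapes consistent with in_axes=(1, 0) (the shapes are just illustrative):

import mlx.core as mx

vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))

x = mx.random.normal((4, 10))   # mapped over axis 1 (size 10)
y = mx.random.normal((10, 4))   # mapped over axis 0 (size 10)

out = vmap_add(x, y)
out.shape                       # (10, 4): the mapped axis lands at out_axes=0 by default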

Compilation

MLX has a compile function for compiling computational graphs. What does compilation mean in this context? Compilation makes smaller graphs by merging common work and fusing common operations.

The first time you call a compiled function, MLX builds and optimizes the compute graph and generates and compiles the code. This can be slow, but the resulting compiled function is cached, so subsequent calls do not initiate a new compilation.
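
A minimal sketch of mx.compile (the function being compiled is arbitrary):

import mlx.core as mx

def fun(x, y):
    return mx.exp(-x) + y

compiled_fun = mx.compile(fun)

x = mx.array(1.0)
y = mx.array(2.0)

print(fun(x, y))           # regular, uncompiled call
print(compiled_fun(x, y))  # first call traces, optimizes, and compiles the graph
print(compiled_fun(x, y))  # later calls with the same shapes/dtypes hit the cache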

What causes a function to be recompiled?

  • changing shape or number of dimensions
  • changing the type of any inputs
  • changing the number of inputs

Don't compile functions that are created and destroyed frequently.

Debugging can be tricky. When a compiled function is first called, it is traced with placeholder inputs, so printing an array inside it (which would force an evaluation) raises an error. For debugging purposes, disable compilation with mx.disable_compile() or by setting the MLX_DISABLE_COMPILE environment variable.

Compiled functions should be pure: they should not have side effects. See the compilation docs for how to handle functions that update some saved state.

Streams

All operations take an optional stream keyword specifying which Stream the operation should run on. This is how you specify the device: passing a device (mx.cpu or mx.gpu) runs the operation on that device's default stream.
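
A small sketch of passing devices and streams explicitly (I believe mx.new_stream works as shown, but treat that as an assumption):

import mlx.core as mx

a = mx.random.normal((100,))
b = mx.random.normal((100,))

# Pass a device: the op runs on that device's default stream
c = mx.add(a, b, stream=mx.gpu)

# Or create a specific stream on a device and pass it
s = mx.new_stream(mx.cpu)
d = mx.multiply(a, b, stream=s)

mx.eval(c, d)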

Date: 2024-04-01 Mon 00:00

Emacs 29.3 (Org mode 9.6.15)