Daniel Liden


Intro to QLoRA

I have a basic understanding of what QLoRA is, but given its popularity and apparent success, I am not nearly familiar enough with it. These are my notes on the Hugging Face blog post about QLoRA and quantization. Later, I will also make a note with some examples.

Link: Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA

1. Intro to Quantization

The post recommends reading the intro to this post first for some background on quantization. Key points:

  • Models are getting bigger, making them harder and more expensive to run.
  • Quantization is one way to make them smaller.
  • int8 inference doesn't appear to significantly degrade model performance.
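
To make "quantization" concrete, here is a minimal sketch of absmax quantization to int8 and back. It is an illustration only: the function names and the sample weights are my own, and the schemes discussed in the linked posts are more sophisticated than this single-scale version.

```python
def quantize_absmax_int8(values):
    """Map floats to signed 8-bit integers using one shared absmax scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest integer and clamp to the symmetric int8 range.
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate floats from the int8 codes."""
    return [q * scale for q in quantized]


weights = [0.12, -0.5, 0.33, 1.27, -0.02]  # made-up example values
q, scale = quantize_absmax_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print(q)
print(max_error)
```

Each weight now takes one byte instead of four, and the round-trip error is bounded by half the scale.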

Common data types used in ML

The size of a model is determined by the number of its parameters and their precision. Common types include:

  • Float32 (FP32) reserves 8 bits for the exponent, 23 bits for the mantissa, and 1 bit for the sign. Most hardware supports FP32 operations. This is called full precision (4 bytes).
  • Float16 (FP16) reserves 5 bits for the exponent, 10 for the mantissa, and 1 for the sign. It can therefore represent a far smaller range of numbers, which creates a risk of overflow or underflow. The largest representable number in FP16 is about 64k (65,504).
import numpy as np

# Define two numbers within the FP16 range
a = np.float16(10000)
b = np.float16(10000)

# Perform multiplication
result = a * b

# Check the result
result
/tmp/ipykernel_12/2565578634.py:8: RuntimeWarning: overflow encountered in scalar multiply
  result = a * b
  • The bfloat16 type was created to avoid these constraints. Bfloat16 reserves 8 bits for the exponent, 7 for the mantissa, and 1 for the sign. It keeps the range of FP32, so big numbers are no problem, but it has 3 fewer bits of precision than fp16. fp16 and bf16 are both referred to as half precision (2 bytes).
  • int8 is an 8-bit representation that can store 2^8 = 256 different values (between 0 and 255 for unsigned integers; between -128 and 127 for signed).
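
A quick sketch of what these types mean in practice, using only the standard library. The 7-billion-parameter count is a hypothetical example, and `to_bfloat16` simply truncates the low 16 bits of the float32 encoding (real hardware typically rounds to nearest), so treat the numbers as approximate.

```python
import struct

BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}


def model_size_gb(n_params, dtype):
    """Weights-only memory footprint in GB (10^9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


def to_float16(x):
    """Round a Python float to the nearest fp16 value (raises on overflow)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x):
    """Approximate bf16 by keeping the top 16 bits of the float32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


n_params = 7_000_000_000  # hypothetical model size
sizes = {dtype: model_size_gb(n_params, dtype) for dtype in BYTES_PER_PARAM}

# Range: 1e8 is far beyond the fp16 maximum, but fits easily in bf16.
try:
    to_float16(1e8)
    fp16_overflowed = False
except (OverflowError, struct.error):
    fp16_overflowed = True
bf16_big = to_bfloat16(1e8)

# Precision: fp16 has more mantissa bits, so it lands closer to 0.1.
fp16_err = abs(to_float16(0.1) - 0.1)
bf16_err = abs(to_bfloat16(0.1) - 0.1)

print(sizes)
print(fp16_overflowed, bf16_big)
print(fp16_err, bf16_err)
```

The trade-off shows up directly: bf16 survives the large value that fp16 cannot hold, while fp16 gets closer to 0.1.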

Different types in training and inference

"Ideally" we should use FP32 for training and inference. But it is two times slower than half precision, so a mixed-precision approach is preferred. Weights are held in FP32 as a precise "main weights" reference. Computations in the forward and backward passes are done in fp16/bf16 to improve training speed. The fp16/bf16 gradients are then used to update the FP32 weights.
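
Here is a toy sketch of why the FP32 reference copy matters. It fits a single weight with gradient descent, doing the forward and backward computation in emulated fp16 (via `struct`), and compares keeping an FP32 main weight against updating an fp16 weight directly. The one-parameter model, learning rate, and step count are all made up for illustration, and real frameworks add details this omits.

```python
import struct


def to_half(x):
    """Round to the nearest fp16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_single(x):
    """Round to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def train(steps, lr, x, target, keep_master):
    master_w = to_single(1.0)  # FP32 "main weights" reference
    half_w = to_half(1.0)      # fp16-only alternative
    for _ in range(steps):
        # Forward and backward pass in half precision.
        w = to_half(master_w) if keep_master else half_w
        pred = to_half(w * x)
        grad = to_half(2 * (pred - target) * x)  # d/dw of (pred - target)^2
        # The update is applied to whichever copy acts as the reference.
        if keep_master:
            master_w = to_single(master_w - lr * grad)
        else:
            half_w = to_half(half_w - lr * grad)
    return master_w if keep_master else half_w


w_mixed = train(steps=200, lr=1e-4, x=1.0, target=2.0, keep_master=True)
w_half_only = train(steps=200, lr=1e-4, x=1.0, target=2.0, keep_master=False)

print(w_mixed, w_half_only)
```

Near 1.0, adjacent fp16 values are about 0.001 apart, so each update of roughly 0.0002 rounds away to nothing in the fp16-only run. The FP32 copy accumulates those small updates and makes progress.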

During inference, there's no real need for the full-precision weights.

What happens when I load a model in bf16 and then train it? Is that still mixed precision? Or is bf16 the reference copy in that case?

The FP8 Format

FP8 comes in two variants: E4M3 and E5M2 (where E=exponent, M=mantissa). E4M3 (higher precision, smaller range) is best suited for the forward pass, while E5M2 (lower precision, larger range) is better suited for the backward pass.
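
The range/precision trade-off can be worked out from the bit layout. The sketch below assumes the conventions I understand the FP8 proposal to use: E5M2 follows IEEE-style rules (the top exponent code is reserved for infinity/NaN), while E4M3 reserves only a single bit pattern for NaN and so gets one extra exponent step. The `max_normal` helper is my own.

```python
def max_normal(exp_bits, man_bits, ieee_like):
    """Largest finite value for a given exponent/mantissa split."""
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        # Top exponent code is reserved for inf/NaN.
        max_exp = (2 ** exp_bits - 2) - bias
        max_frac = 2 ** man_bits - 1
    else:
        # Assumed E4M3 rule: only the all-ones pattern is NaN.
        max_exp = (2 ** exp_bits - 1) - bias
        max_frac = 2 ** man_bits - 2
    return (1 + max_frac / 2 ** man_bits) * 2.0 ** max_exp


e4m3_max = max_normal(4, 3, ieee_like=False)
e5m2_max = max_normal(5, 2, ieee_like=True)
fp16_max = max_normal(5, 10, ieee_like=True)  # sanity check against fp16

# Relative spacing between neighboring values is 2^-mantissa_bits.
e4m3_step = 2 ** -3
e5m2_step = 2 ** -2

print(e4m3_max, e5m2_max, fp16_max)
print(e4m3_step, e5m2_step)
```

Under these assumptions E4M3 tops out at 448 with a finer step, while E5M2 reaches 57,344 with a coarser one.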

QLoRA Paper Overview

QLoRA reduces the memory usage of LLM fine-tuning without appreciable loss of performance compared to half precision fine-tuning.

The model to be fine-tuned is loaded in 4-bit precision and then the weights are frozen. A small number of trainable parameters are added in the form of low-rank adapters. The LoRA adapters are the only parameters updated during training.

QLoRA typically stores the base model weights in 4-bit NormalFloat and uses bf16 for computation. Weights are dequantized from the storage type to perform the forward and backward passes. Gradients are only computed for the LoRA parameters, which use bf16. Weights are decompressed on an as-needed basis, so memory usage stays low.
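
To check my understanding, here is a pure-Python sketch of the idea: a frozen, 4-bit quantized weight matrix plus a trainable low-rank adapter. It is not the real implementation. The quantizer is a plain uniform absmax scheme standing in for NormalFloat, everything runs in ordinary Python floats rather than bf16, and the sizes, learning rate, and data are made up.

```python
import random


def quantize_4bit(row):
    """Uniform absmax quantization to signed 4-bit codes (stand-in for NF4)."""
    absmax = max(abs(v) for v in row) or 1.0
    codes = [max(-8, min(7, round(v / absmax * 7))) for v in row]
    return codes, absmax


def dequantize_4bit(codes, absmax):
    return [c / 7 * absmax for c in codes]


def matvec(matrix, vec):
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


random.seed(0)
d_in, d_out, rank = 8, 4, 2

# Frozen base weights, stored only in quantized form.
W = [[random.gauss(0, 0.1) for _ in range(d_in)] for _ in range(d_out)]
W_q = [quantize_4bit(row) for row in W]

# LoRA adapter: A is random, B starts at zero, so training starts from the base model.
A = [[random.gauss(0, 0.1) for _ in range(d_in)] for _ in range(rank)]
B = [[0.0] * rank for _ in range(d_out)]


def forward(x):
    # Dequantize on the fly, then add the low-rank correction B(Ax).
    W_deq = [dequantize_4bit(codes, absmax) for codes, absmax in W_q]
    base = matvec(W_deq, x)
    low_rank = matvec(B, matvec(A, x))
    return [b + l for b, l in zip(base, low_rank)]


def loss(y, target):
    return 0.5 * sum((yi - ti) ** 2 for yi, ti in zip(y, target))


x = [1.0] * d_in
target = [0.5] * d_out

y0 = forward(x)
loss_before = loss(y0, target)

# One gradient step on B only; W_q is never touched.
# (The gradient for A is zero on this first step because B is zero.)
h = matvec(A, x)
err = [yi - ti for yi, ti in zip(y0, target)]
lr = 0.5
for i in range(d_out):
    for j in range(rank):
        B[i][j] -= lr * err[i] * h[j]

loss_after = loss(forward(x), target)
print(loss_before, loss_after)
```

The only parameters that change are the adapter's, and there are far fewer of them than base weights: rank * (d_in + d_out) versus d_in * d_out.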

Key summary (from the original paper): "QLoRA backpropagates gradients through a frozen, 4-bit quantized pretrained language model into Low Rank Adapters (LoRA)."

Date: 2024-04-05 Fri 00:00
