Troubleshooting Flash Attention Installation
I have repeatedly run into issues getting flash-attention to work correctly with whatever versions of PyTorch and CUDA I happen to be using. I have found a pattern that works, at least on the platform I tend to work on (Databricks). This note is a quick summary.
The Problem
I kept getting an "undefined symbol" error when trying to load a model with flash attention (or even just when importing the flash-attention library).
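A minimal reproduction, and a useful first diagnostic, is to print the PyTorch build details and then import the library (the exact symbol named in the error varies from build to build, so the comment below is only indicative):

```
import torch

# The CUDA version the installed torch wheel was built with:
print(torch.__version__)
print(torch.version.cuda)

# With a mismatched flash-attn build, this import is where the
# "undefined symbol" error appears.
import flash_attn
```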
Solution
The following approach worked.
- Verify the CUDA version and install the matching build of PyTorch (see the check after this list).
- Clone the flash-attention repository and build it from source (don't just pip install it).
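On Databricks the CUDA version can be checked from a notebook cell; a quick sketch (the driver output, and whether nvcc is present, will vary by cluster image):

```
%sh
# CUDA version supported by the NVIDIA driver
nvidia-smi
# Toolkit version, if the CUDA toolkit is on the image
nvcc --version
```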
So in the case of my most recent project:
```
%pip install --upgrade torch
```

was fine because it's compiled for CUDA 12.
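Had the default wheel not matched the cluster's CUDA version, a build for a specific CUDA release can be pulled explicitly from the PyTorch wheel index, for example (assuming CUDA 12.1):

```
%pip install --upgrade torch --index-url https://download.pytorch.org/whl/cu121
```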
To install flash-attention:
```
%sh
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
# Build against the torch already installed in this environment,
# rather than in an isolated build environment with its own torch.
pip install . --no-build-isolation
```
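The from-source build can take a while and is memory-hungry; the flash-attention README suggests capping the number of parallel compile jobs on machines with limited RAM, along these lines:

```
%sh
cd flash-attention
# Limit ninja to 4 parallel compilation jobs to keep memory use down
MAX_JOBS=4 pip install . --no-build-isolation
```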
We can then make sure everything works (without needing to spend the extra time loading a model, for example) like this:
```
import torch
print(torch.__version__)
print(torch.version.cuda)

import flash_attn
print(flash_attn.__version__)
```
```
2.2.2+cu121
12.1
2.5.7
```
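To go one step further and confirm that a model actually loads with flash attention, something like the following works with the transformers library; the model id below is a placeholder, and the model must be loaded in fp16/bf16 on a GPU for Flash Attention 2 to be used:

```
import torch
from transformers import AutoModelForCausalLM

# "your-org/your-model" is a placeholder; substitute the model you actually use.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",
    torch_dtype=torch.bfloat16,               # flash attention needs fp16 or bf16
    attn_implementation="flash_attention_2",  # fails fast if flash-attn isn't usable
    device_map="auto",
)
```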
What didn't work
I wasn't able to get any variant of pip install flash-attn to work, regardless of the --no-build-isolation flag, pinning specific versions, and so on.