
Transformer large models now train twice as fast! A Stanford PhD's solo work

Author: QbitAI

Feng Se, reporting from Aofei Temple

QbitAI | Official account QbitAI

Training and inference of existing large language models can now get a bit faster.

How much faster? 2-4 times.

FlashAttention, which many large models already use, officially released and open-sourced its second generation today. Any model built on the Transformer architecture can use it for acceleration.


The first-generation method, released last June, speeds up attention and reduces its memory footprint without any approximation.

Now FlashAttention-2 upgrades it again: the core attention operation is 2x faster, end-to-end Transformer training is 1.3x faster, and training on the NVIDIA A100 reaches 72% model FLOPs utilization (typical models sit around 50%).


Given that training a large language model now costs tens of millions of dollars, FlashAttention-2's optimizations could directly save millions of dollars!

Netizens were so stunned that they resorted to swearing (tongue in cheek).


At present, the project has received 4.4k stars on GitHub.


We also noticed that its author has just completed his Ph.D. at Stanford and joined the large-model startup Together AI.

Concrete implementation

The first generation of FlashAttention is an algorithm that reorders the attention computation. It uses classic techniques such as tiling to speed up the computation significantly and to reduce memory usage from quadratic to linear in sequence length.

Tiling means loading blocks of the inputs from HBM (GPU memory) into SRAM (fast on-chip cache), computing attention on that block, and then updating the output in HBM.

In standard attention, repeated reads and writes to HBM are the main performance bottleneck.


By never writing the large intermediate attention matrices to HBM, this approach cuts the amount of memory reads and writes and delivers a 2-4x wall-clock speedup.
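To make the tiling idea concrete, here is a minimal pure-Python sketch for a single query row, assuming small list-based inputs and no GPU. `naive_attention` materializes every score at once, while the hypothetical `tiled_attention` walks over K and V in blocks and keeps only a running max, a running softmax denominator, and a running output (kept normalized after each block, in the style of the first generation). It illustrates the online-softmax technique, not the actual CUDA kernel.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, K, V):
    """Standard attention for one query row: all N scores are materialized at once."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * dot(q, k) for k in K]          # length-N intermediate
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, V)) / total
            for j in range(len(V[0]))]


def tiled_attention(q, K, V, block_size):
    """Tiled attention with an online softmax: at most block_size scores exist at a time."""
    scale = 1.0 / math.sqrt(len(q))
    d = len(V[0])
    m = -math.inf        # running max of the scores seen so far
    l = 0.0              # running softmax denominator, relative to m
    out = [0.0] * d      # running output, kept normalized after every block
    for start in range(0, len(K), block_size):
        K_blk = K[start:start + block_size]   # stands in for "load a block into SRAM"
        V_blk = V[start:start + block_size]
        s = [scale * dot(q, k) for k in K_blk]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)            # corrects statistics from earlier blocks
        p = [math.exp(x - m_new) for x in s]
        l_new = alpha * l + sum(p)
        for j in range(d):
            pv = sum(pi * v[j] for pi, v in zip(p, V_blk))
            out[j] = (l * alpha * out[j] + pv) / l_new
        m, l = m_new, l_new
    return out


q = [0.5, -1.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention(q, K, V))
print(tiled_attention(q, K, V, block_size=2))
```

Both functions return the same vector up to floating-point rounding; the tiled version never holds more than `block_size` scores, which is why memory grows linearly rather than quadratically with sequence length.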

However, the algorithm still has inefficiencies that keep it slower than optimized matrix multiplication (GEMM): it reaches only 25-40% of the theoretical maximum FLOPs/s (e.g., up to 124 TFLOPs/s on the A100).

The cause is suboptimal partitioning of work between thread blocks and between warps on the GPU.

FlashAttention-2 improves on this in three ways.

First, it tweaks the algorithm to reduce the number of non-matmul (non-matrix-multiplication) FLOPs.

The reason is that modern GPUs have dedicated compute units that make matmul much faster. For example, on the A100 the theoretical peak throughput of FP16/BF16 matmul is 312 TFLOPs/s, while non-matmul FP32 throughput is only 19.5 TFLOPs/s.

Put another way, each non-matmul FLOP is 16 times more expensive than a matmul FLOP, so spending as much of the time as possible on matmul FLOPs keeps throughput high.
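The 16x figure follows directly from the two A100 peak numbers. The sketch below checks that ratio and then applies it to a hypothetical workload (the 1000/100 GFLOP split is an assumption chosen for illustration) to show how even a small share of non-matmul FLOPs can dominate an idealized runtime. It ignores memory stalls and any overlap between the two kinds of work.

```python
# Peak A100 throughput figures quoted in the article, in TFLOPs/s.
MATMUL_FP16_TFLOPS = 312.0      # FP16/BF16 matmul
NON_MATMUL_FP32_TFLOPS = 19.5   # non-matmul FP32 operations

cost_ratio = MATMUL_FP16_TFLOPS / NON_MATMUL_FP32_TFLOPS
print(f"One non-matmul FLOP costs about {cost_ratio:.0f}x a matmul FLOP")


def ideal_time_ms(matmul_gflop, non_matmul_gflop):
    """Idealized runtime if each kind of FLOP ran at its peak rate.

    GFLOP divided by TFLOPs/s gives milliseconds directly.
    """
    return (matmul_gflop / MATMUL_FP16_TFLOPS,
            non_matmul_gflop / NON_MATMUL_FP32_TFLOPS)


# Hypothetical workload: 1000 GFLOP of matmul plus only 10% as many non-matmul FLOPs.
mm_ms, other_ms = ideal_time_ms(1000.0, 100.0)
print(f"matmul: {mm_ms:.2f} ms, non-matmul: {other_ms:.2f} ms")
```

Under these assumptions the non-matmul tenth of the work takes longer than all of the matmuls, which is the motivation for trimming it.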

To this end, the author rewrote the online softmax trick in FlashAttention to reduce the number of rescaling operations, as well as bounds-checking and causal-masking operations, without changing the output.
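One way to reduce rescaling, as described in the FlashAttention-2 write-up, is to keep the output accumulator unnormalized and divide by the softmax denominator only once at the end. The sketch below is a simplified single-query-row illustration under that assumption; the function name `online_attention` and the division counter are hypothetical, and counting divisions is only a stand-in for the non-matmul work saved in the real kernel.

```python
import math


def online_attention(q, K, V, block_size, deferred):
    """Online-softmax attention for one query row.

    deferred=False: keep the output normalized after every block
                    (one division per output element per block).
    deferred=True:  keep an unnormalized accumulator and divide once at the end.
    Returns (output, divisions) so the extra non-matmul work can be compared.
    """
    scale = 1.0 / math.sqrt(len(q))
    d = len(V[0])
    m, l = -math.inf, 0.0
    acc = [0.0] * d
    divisions = 0
    for start in range(0, len(K), block_size):
        K_blk, V_blk = K[start:start + block_size], V[start:start + block_size]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K_blk]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)          # correction for earlier blocks
        p = [math.exp(x - m_new) for x in s]
        l_new = alpha * l + sum(p)
        for j in range(d):
            pv = sum(pi * v[j] for pi, v in zip(p, V_blk))
            if deferred:
                acc[j] = alpha * acc[j] + pv                 # no division in the loop
            else:
                acc[j] = (l * alpha * acc[j] + pv) / l_new
                divisions += 1
        m, l = m_new, l_new
    if deferred:
        acc = [x / l for x in acc]                           # single final rescale
        divisions += d
    return acc, divisions


q = [0.2, 0.7]
K = [[float(i % 3), float(i % 2)] for i in range(8)]
V = [[float(i), float(8 - i)] for i in range(8)]
out1, div1 = online_attention(q, K, V, block_size=2, deferred=False)
out2, div2 = online_attention(q, K, V, block_size=2, deferred=True)
print(div1, div2)  # 8 2
```

Both variants produce the same output; the deferred one replaces a division per element per block with a single division per element.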

Second, it parallelizes better when the batch size is small, achieving higher occupancy.

The first-generation FlashAttention parallelizes over batch size and number of attention heads.

It uses one thread block per attention head, so there are (batch_size * num_heads) thread blocks in total, and each is scheduled to run on a streaming multiprocessor (SM).

On a GPU with 108 SMs such as the A100, this scheduling works well when there are many thread blocks (e.g., >= 80).

With long sequences, which usually means a small batch size or few heads, FlashAttention-2 additionally parallelizes over the sequence-length dimension to make better use of the GPU's multiprocessors.
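A quick count shows why splitting over sequence length matters. This sketch compares the number of thread blocks under the two schemes for a hypothetical long-context setting (batch 1, 16 heads, 16k tokens). The tile size `block_m=128` and the helper names are assumptions for illustration, and "one block per SM in the first wave" is a simplification of real GPU scheduling.

```python
import math

A100_NUM_SMS = 108  # streaming multiprocessors on an A100, as stated in the article


def thread_blocks_fa1(batch_size, num_heads):
    """First generation: one thread block per (sequence, head) pair."""
    return batch_size * num_heads


def thread_blocks_fa2(batch_size, num_heads, seq_len, block_m=128):
    """Also split the query sequence into chunks of block_m rows.

    block_m=128 is an assumed tile size for illustration.
    """
    return batch_size * num_heads * math.ceil(seq_len / block_m)


def busy_sm_fraction(num_blocks, num_sms=A100_NUM_SMS):
    """Fraction of SMs given a block in the first wave (one block per SM assumed)."""
    return min(num_blocks, num_sms) / num_sms


# Hypothetical long-context setting: batch 1, 16 heads, 16k tokens.
b, h, n = 1, 16, 16384
fa1 = thread_blocks_fa1(b, h)
fa2 = thread_blocks_fa2(b, h, n)
print(fa1, f"{busy_sm_fraction(fa1):.0%}")   # 16 blocks: most SMs idle
print(fa2, f"{busy_sm_fraction(fa2):.0%}")   # 2048 blocks: every SM has work
```

With only 16 blocks, roughly 85% of the A100's SMs would sit idle; splitting the query dimension yields far more blocks than SMs.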

This improvement is also a big reason for the significant speed increase in FlashAttention-2.

Finally, it improves work partitioning.

Within each thread block, the work must be divided among warps. Typically 4 or 8 warps are used per block; the author improved this partitioning to reduce synchronization and communication between warps, and with it the number of shared-memory reads and writes.

As shown on the left of the figure below, first-generation FlashAttention splits K and V across 4 warps while keeping Q accessible to all of them. As a result, every warp must write its intermediate results to shared memory, synchronize, and then add the results together. This is very inefficient and slows down the forward pass.


In FlashAttention-2, the author instead splits Q across the 4 warps while keeping K and V accessible to all warps.

After each warp computes its slice of QK^T, it only needs to multiply that slice by the shared V to obtain its slice of the output. No communication between warps is needed, so there are far fewer shared-memory reads and writes, and the kernel runs faster.
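The difference between the two partitions can be seen in a simplified model of the P·V product, where P stands for a block of attention probabilities and Python loop iterations stand in for warps. This is an illustrative sketch, not kernel code: it ignores the softmax rescaling that a real split-K scheme would also need, and the helper names are hypothetical. Splitting K/V leaves every warp with a full-size partial output that must be summed; splitting Q gives each warp disjoint output rows.

```python
def matmul(A, B):
    """Plain list-of-lists matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def split_k(P, V, num_warps):
    """Simplified first-generation split: each warp owns a slice of the key/value rows."""
    chunk = -(-len(V) // num_warps)                    # ceiling division
    partials = []
    for w in range(num_warps):
        V_w = V[w * chunk:(w + 1) * chunk]
        if not V_w:
            continue
        P_w = [row[w * chunk:(w + 1) * chunk] for row in P]
        partials.append(matmul(P_w, V_w))              # full-size partial output per warp
    # Cross-warp reduction: on a GPU this means shared-memory writes plus a sync.
    out = [[sum(part[i][j] for part in partials) for j in range(len(V[0]))]
           for i in range(len(P))]
    return out, len(partials)


def split_q(P, V, num_warps):
    """Simplified FlashAttention-2 split: each warp owns a slice of the query rows."""
    chunk = -(-len(P) // num_warps)
    out = []
    for w in range(num_warps):
        out.extend(matmul(P[w * chunk:(w + 1) * chunk], V))  # disjoint rows, no reduction
    return out


P = [[1, 2, 0, 1], [0, 1, 3, 2], [2, 0, 1, 1], [1, 1, 1, 1]]
V = [[1, 0], [0, 1], [2, 2], [1, 3]]
out_k, num_partials = split_k(P, V, num_warps=4)
out_q = split_q(P, V, num_warps=4)
print(out_k == out_q, num_partials)
```

Both partitions give the same result, but the split-K version produces four partial matrices that must be combined, which is exactly the shared-memory traffic FlashAttention-2 avoids.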

Beyond these three major improvements, FlashAttention-2 also brings two smaller changes:

First, the supported head dimension increases from 128 to 256, which means models such as GPT-J, CodeGen and CodeGen2, and Stable Diffusion 1.x can use FlashAttention-2 for speedups and memory savings;

Second, it supports multi-query attention (MQA) and grouped-query attention (GQA).
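In MQA and GQA, several query heads share one K/V head. A common convention, assumed here, is that consecutive query heads form a group mapped to the same K/V head; the helper `kv_head_for_query_head` is hypothetical and only illustrates the mapping, not any library's internal code. MQA is the special case of one K/V head, and ordinary multi-head attention is the case of one K/V head per query head.

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Return the K/V head a given query head reads from under grouped-query attention."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads   # query heads sharing one K/V head
    return q_head // group_size


gqa = [kv_head_for_query_head(h, 8, 2) for h in range(8)]   # GQA: 2 K/V heads
mqa = [kv_head_for_query_head(h, 8, 1) for h in range(8)]   # MQA: 1 K/V head
mha = [kv_head_for_query_head(h, 8, 8) for h in range(8)]   # standard multi-head
print(gqa)  # [0, 0, 0, 0, 1, 1, 1, 1]
print(mqa)  # [0, 0, 0, 0, 0, 0, 0, 0]
print(mha)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Fewer K/V heads mean a smaller K/V cache at inference time, which is the main reason models adopt these variants.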

Experimental evaluation

The author measured runtime on an A100 80GB SXM4 GPU under different configurations (with or without a causal mask, head dimension 64 or 128).

The results show:

FlashAttention-2 is about 2x faster than FlashAttention (as well as other implementations in the xformers library and in Triton). This means a model with 16k context (i.e., double the context length) can be trained for the same price as an 8k-context model previously.

Compared to the standard attention implementation in PyTorch, FlashAttention-2 is up to 9 times faster.


In addition, simply running the same FlashAttention-2 implementation on an H100 GPU reaches up to 335 TFLOPs/s, without using any special instructions to exploit new hardware features such as TMA and fourth-generation Tensor Cores.


When used to train GPT-style models end to end, FlashAttention-2 reaches up to 225 TFLOPs/s per A100 (72% model FLOPs utilization), a further 1.3x speedup over the already well-optimized FlashAttention.
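Model FLOPs utilization (MFU) is simply achieved throughput divided by the hardware's theoretical peak. The sketch below reproduces the 72% figure from the two numbers in the article, using the A100's 312 TFLOPs/s FP16/BF16 matmul peak as the denominator, and converts the "typical" 50% utilization into absolute throughput. The helper name is hypothetical.

```python
A100_PEAK_TFLOPS = 312.0  # peak FP16/BF16 matmul throughput on the A100, per the article


def model_flops_utilization(achieved_tflops, peak_tflops=A100_PEAK_TFLOPS):
    """MFU: achieved model FLOPs/s divided by the hardware's theoretical peak."""
    return achieved_tflops / peak_tflops


print(f"FlashAttention-2 end to end: {model_flops_utilization(225.0):.0%}")        # 72%
print(f"Typical ~50% MFU corresponds to {0.5 * A100_PEAK_TFLOPS:.0f} TFLOPs/s")  # 156
```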


The author joins a large-model startup

The FlashAttention-2 paper lists only one author: Tri Dao. He was also one of the two co-first authors of the first-generation FlashAttention.


Tri Dao's research lies at the intersection of machine learning and systems, and last year he received an ICML 2022 Outstanding Paper Runner-up Award.

He recently received his Ph.D. in computer science from Stanford University and will soon become an assistant professor at Princeton University. He has also announced that he is joining the generative AI startup Together AI as chief scientist. Together AI's main goal is to build a cloud platform for running, training, and fine-tuning open-source models.


One More Thing

Finally, some netizens noticed that besides FlashAttention-2, a series of similar results has appeared recently, including DeepSpeed's ZeRO++ and ReLoRA from the University of Massachusetts.

All of them aim to speed up pre-training and fine-tuning of large models, and these results led one netizen to remark:

Training large models on low-VRAM low-bandwidth consumer graphics cards in the future doesn't seem to be a dream anymore.

What do you think?

Paper:

https://tridao.me/publications/flash2/flash2.pdf

Blog post:

https://princeton-nlp.github.io/flash-atttention-2/

GitHub:

https://github.com/Dao-AILab/flash-attention

