
New approximate attention mechanism HyperAttention: friendly to long contexts, speeding up LLM inference by 50%

WBOY · 2023-11-13

Transformers have been successful in a variety of learning tasks in areas such as natural language processing, computer vision, and time series prediction. Despite this success, these models still face severe scalability limitations: exact computation of the attention layer incurs quadratic (in the sequence length) running time and memory. This poses a fundamental challenge to extending Transformer models to longer context lengths.

The industry has explored various methods to address the quadratic-time attention layer, and one noteworthy direction is to approximate the intermediate matrices inside the attention layer. Approaches include approximation via sparse matrices, low-rank matrices, or a combination of both.

However, these methods do not provide end-to-end guarantees on the approximation of the attention output matrix. They aim to approximate individual components of attention faster, but none of them provides an end-to-end approximation of full dot-product attention. These methods also do not support causal masks, which are an essential part of modern Transformer architectures. Moreover, recent theoretical lower bounds suggest that, in general, an entrywise approximation of the attention matrix cannot be computed in subquadratic time.

However, a recent study called KDEFormer showed that, under the assumption that the entries of the attention matrix are bounded, provable approximations can be obtained in subquadratic time. Still, current KDE algorithms lack practical efficiency, and even in theory there remains a gap between KDEFormer's runtime and a theoretically feasible O(n)-time algorithm. Other work has shown that, under the same bounded-entry assumption, a near-linear-time algorithm is in principle possible; however, that algorithm relies on polynomial approximations of the softmax, which is likely to be impractical.

In this article, researchers from Yale University, Google Research, and other institutions provide an algorithm with the best of both worlds: it is both practical and efficient, achieving a near-linear-time guarantee. Furthermore, the method supports causal masking, which was not possible in previous work.

Paper link: https://arxiv.org/abs/2310.05869

This article proposes an approximate attention mechanism called HyperAttention to address the computational challenges posed by long contexts in large language models. Recent research shows that, in the worst case, quadratic time is necessary unless the entries of the attention matrix are bounded or the stable rank of the matrix is low.

The researchers introduce two parameters to measure the difficulty of the problem: (1) the maximum column norm of the normalized attention matrix, and (2) the ratio of row norms in the unnormalized attention matrix after the large entries are removed. As long as these fine-grained parameters are small, a linear-time sampling algorithm can be achieved even if the matrix has unbounded entries or a large stable rank.

HyperAttention has a modular design and can easily integrate other fast low-level implementations, in particular FlashAttention. Empirically, when the LSH algorithm is used to identify large entries, HyperAttention outperforms existing methods and achieves significant speedups over state-of-the-art solutions such as FlashAttention. The researchers verified the performance of HyperAttention on a variety of datasets with different context lengths.

For example, HyperAttention makes ChatGLM2's inference 50% faster at a 32k context length, while perplexity increases from 5.6 to 6.3. With larger context lengths (e.g., 131k) and causal masks, HyperAttention is 5x faster on a single attention layer.

Method Overview

Dot-product attention involves three input matrices: Q (queries), K (keys), and V (values), all of size n × d, where n is the number of tokens in the input sequence and d is the dimension of the latent representation. The output of this process is as follows:

Att = D^(-1) A V

Here, the matrix A := exp(QK^T) is the elementwise exponential of QK^T, and D is an n×n diagonal matrix whose entries are the row sums of A, i.e., D_ii = Σ_j A_ij. In this setting, A is called the "attention matrix" and (D^-1)A is called the "softmax matrix". Note that computing the attention matrix A directly requires Θ(n²d) operations and storing it consumes Θ(n²) memory, so computing Att directly requires Ω(n²d) runtime and Ω(n²) memory.
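To make the cost concrete, below is a minimal NumPy sketch of exact dot-product attention exactly as defined above; the function name and the random test inputs are illustrative only. The n×n matrix A is materialized explicitly, which is precisely the quadratic bottleneck HyperAttention avoids.

```python
import numpy as np

def exact_attention(Q, K, V):
    """Exact Att = D^-1 A V with A = exp(Q K^T): O(n^2 d) time, O(n^2) memory."""
    A = np.exp(Q @ K.T)               # n x n attention matrix (the quadratic bottleneck)
    d_inv = 1.0 / A.sum(axis=1)       # row sums of A form the diagonal matrix D
    return (d_inv[:, None] * A) @ V   # softmax matrix (D^-1 A) times V

# Illustrative usage with small random inputs (scaled to keep exp well-behaved).
n, d = 1024, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((n, d)) / d**0.5 for _ in range(3))
out = exact_attention(Q, K, V)        # shape (n, d)
```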

The researchers' goal is to efficiently approximate the output matrix Att while preserving its spectral properties. Their strategy has two parts: designing a near-linear-time estimator for the diagonal scaling matrix D, and quickly approximating the matrix product between the softmax matrix D^(-1)A and V by subsampling. More specifically, the goal is to find a sampling matrix S with a small number of rows and a diagonal matrix D̃ that together satisfy the following operator-norm error constraint:

[Equation (1): the operator-norm error of the sampled product D̃^(-1)(A Sᵀ)(S V) relative to the exact output D^(-1)AV must be at most ε.]

The researchers show that by defining the sampling matrix S based on the row norms of V, the matrix-multiplication part of the attention approximation problem in Equation (1) can be solved efficiently. The more challenging problem is how to obtain a reliable approximation of the diagonal matrix D. In recent work, Zandieh exploited fast KDE solvers to obtain high-quality approximations of D. The researchers simplify the KDEFormer procedure and show that uniform sampling suffices to achieve the required spectral guarantee, without kernel-density-based importance sampling. This significant simplification allows them to develop a practical, provably linear-time algorithm.
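The following Python sketch illustrates the standard squared-row-norm sampling idea behind the matrix-product step, under the assumption that an estimate of the diagonal D is already available (passed in as d_inv); the function name, the importance-sampling weights, and the way columns of A are materialized are illustrative rather than the paper's implementation.

```python
import numpy as np

def sampled_attention_product(Q, K, V, d_inv, m, rng):
    """Approximate (D^-1 A) @ V without forming the full n x n matrix A.
    Rows of V (equivalently, columns of A) are sampled with probability
    proportional to their squared l2 norms, then reweighted to stay unbiased."""
    n = V.shape[0]
    probs = (V ** 2).sum(axis=1)
    probs = probs / probs.sum()
    idx = rng.choice(n, size=m, p=probs)       # m << n sampled key/value indices
    A_cols = np.exp(Q @ K[idx].T)              # only m columns of A are materialized
    scale = 1.0 / (m * probs[idx])             # importance-sampling reweighting
    return d_inv[:, None] * ((A_cols * scale) @ V[idx])
```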

Unlike previous work, this method does not require bounded entries or bounded stable rank. Moreover, even if the entries of the attention matrix or its stable rank are large, the fine-grained parameters introduced to analyze the time complexity may still be small.

As a result, HyperAttention is significantly faster: forward and backward passes are more than 50x faster at sequence length n = 131k. With causal masks, the method still achieves a substantial 5x speedup. Moreover, when the method is applied to a pre-trained LLM (such as chatglm2-6b-32k) and evaluated on the long-context benchmark LongBench, it maintains performance close to the original model even without fine-tuning. The researchers also evaluated individual tasks and found that summarization and code completion are more robust to approximate attention layers than question-answering tasks.

Algorithm

To obtain a spectral guarantee when approximating Att, the first step is to compute a 1 ± ε approximation of the diagonal entries of D. The matrix product between D^(-1)A and V is then approximated by sampling according to the squared row ℓ₂-norms of V.

The procedure for approximating D consists of two steps. First, an algorithm based on Hamming sorted LSH is used to identify the dominant entries of the attention matrix, as specified in Definition 1. Second, a small subset of keys K is chosen uniformly at random. The paper shows that, under certain mild assumptions on the matrices A and D, this simple approach yields spectral bounds for the estimated matrix. The goal is to find an approximate matrix D̃ that is accurate enough to satisfy:

[Equation (2): the spectral guarantee required of the estimate D̃, bounding the operator-norm error of D̃^(-1)A relative to D^(-1)A.]

The paper assumes that the column norms of the softmax matrix exhibit a relatively uniform distribution. More precisely, the researchers assume that for every i ∈ [n] the ℓ2 norm of the i-th column of D^(-1)A is bounded by a common threshold; this corresponds to the fine-grained column-norm parameter introduced earlier.

The first step of the algorithm is to identify large entries in the attention matrix A by hashing keys and queries into uniformly sized buckets using Hamming sort LSH (sortLSH). Algorithm 1 details this process and Figure 1 illustrates it visually.

[Figure 1 and Algorithm 1: sortLSH hashes queries and keys into uniformly sized buckets to locate the large entries of the attention matrix.]
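As a rough illustration of the bucketing idea (not the paper's Algorithm 1 itself), the sketch below hashes queries and keys with random hyperplanes, sorts them by hash code, and only treats pairs that fall into the same fixed-size block as candidate large entries; all names and parameter choices are placeholders.

```python
import numpy as np

def lsh_candidate_blocks(Q, K, num_bits=8, block_size=256, seed=0):
    """Return blocks of (query indices, key indices) whose dot products are likely
    to be large, by sorting SimHash-style codes and pairing equal-size blocks."""
    rng = np.random.default_rng(seed)
    n, d = Q.shape
    H = rng.standard_normal((d, num_bits))                 # random hyperplanes
    powers = 1 << np.arange(num_bits)
    code_q = ((Q @ H) > 0).astype(np.int64) @ powers       # pack sign bits into one integer code
    code_k = ((K @ H) > 0).astype(np.int64) @ powers
    order_q, order_k = np.argsort(code_q), np.argsort(code_k)
    blocks = []
    for start in range(0, n, block_size):
        qi = order_q[start:start + block_size]
        ki = order_k[start:start + block_size]
        blocks.append((qi, ki))    # only these O(n * block_size) entries are inspected
    return blocks
```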

Algorithm 1 returns a sparse mask that isolates the dominant entries of the attention matrix. Given this mask, Algorithm 2 computes an approximation of the matrix D that satisfies the spectral guarantee in Equation (2), by combining the attention values selected by the mask with a randomly chosen set of columns of the attention matrix. The algorithm is broadly applicable: it can also be used efficiently with a predefined mask that specifies the positions of the dominant entries. The main guarantee of this algorithm is given in Theorem 1.

[Algorithm 2 and Theorem 1 are stated in the paper.]


What remains is a subroutine that combines the matrix products between the approximate diagonal D̃ and the approximation of A with the value matrix V. To this end, the researchers introduce HyperAttention, an efficient algorithm that approximates the attention mechanism with the spectral guarantee of Equation (1) in near-linear time. Algorithm 3 takes as input a mask M^H that defines the positions of the dominant entries of the attention matrix. This mask can be generated with the sortLSH algorithm (Algorithm 1) or can be a predefined mask, similar to the approach in [7]. The large-entry mask M^H is assumed to be sparse by design, with a bounded number of non-zero entries.

As shown in Figure 2, the method rests on an important observation: the causally masked attention M^C⊙A can be decomposed into three non-zero blocks, each half the size of the original attention matrix. The block A_21, which lies entirely below the diagonal, is unmasked attention, so its row sums can be approximated using Algorithm 2.

The two diagonal blocks shown in Figure 2 are themselves causal attention, each half the original size. To handle them, the researchers use a recursive approach: the blocks are divided into smaller blocks and the process is repeated. The pseudocode for this procedure is given in Algorithm 4.

[Figure 2: the causal mask splits the attention matrix into two half-size causal diagonal blocks and one unmasked off-diagonal block A_21; Algorithm 4 applies this decomposition recursively.]
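The sketch below mirrors the decomposition logic of Figure 2 and Algorithm 4 in plain NumPy. For clarity it computes the off-diagonal block A_21 exactly, whereas HyperAttention would approximate it with the near-linear unmasked routine; the function name, the base-case size, and the (U, s) merging convention are illustrative.

```python
import numpy as np

def causal_attention_recursive(Q, K, V, base=4096):
    """Return (U, s) with U = (M^C * exp(Q K^T)) @ V (unnormalized output) and
    s = row sums of M^C * exp(Q K^T). The final output is U / s[:, None]."""
    n = Q.shape[0]
    if n <= base:
        A = np.exp(Q @ K.T) * np.tril(np.ones((n, n)))   # exact causal base case
        return A @ V, A.sum(axis=1)
    h = n // 2
    # Two half-size causal diagonal blocks, handled recursively.
    U1, s1 = causal_attention_recursive(Q[:h], K[:h], V[:h], base)
    U2, s2 = causal_attention_recursive(Q[h:], K[h:], V[h:], base)
    # The block below the diagonal (A_21) carries no causal mask; HyperAttention
    # would approximate this product in near-linear time, here it is exact.
    A21 = np.exp(Q[h:] @ K[:h].T)
    U21, s21 = A21 @ V[:h], A21.sum(axis=1)
    # Merge: bottom-half rows see both their causal block and the full A_21 block.
    return np.vstack([U1, U2 + U21]), np.concatenate([s1, s2 + s21])
```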

Experiments and results

The researchers benchmark the algorithm by extending existing large language models to process long-range sequences. All experiments were run on a single 40GB A100 GPU and used FlashAttention 2 for exact attention computation.


HyperAttention is first evaluated on two pre-trained LLMs. Two models with different architectures that are widely used in practice were selected: chatglm2-6b-32k and phi-1.5.

In these experiments, the final ℓ attention layers are patched by replacing them with HyperAttention, where ℓ varies from 0 to the total number of attention layers in each LLM. Note that attention in both models requires a causal mask, so Algorithm 4 is applied recursively until the input sequence length n falls below 4,096. For all sequence lengths, the bucket size b and the number of sampled columns m are set to 256. The performance of these monkey patched models is evaluated in terms of perplexity and speedup.
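A hedged sketch of how such monkey patching could look for a generic HuggingFace-style model is shown below; the attribute path model.transformer.layers[i].self_attention and the make_hyper_attention wrapper are hypothetical stand-ins, since the real chatglm2-6b-32k and phi-1.5 module names differ.

```python
import torch.nn as nn

def patch_last_layers(model: nn.Module, num_patched: int, make_hyper_attention):
    """Replace the final `num_patched` self-attention modules with HyperAttention,
    keeping the pre-trained weights of the remaining layers untouched."""
    layers = model.transformer.layers            # hypothetical attribute path
    total = len(layers)
    for i in range(total - num_patched, total):
        original_attn = layers[i].self_attention # hypothetical submodule name
        # `make_hyper_attention` wraps the original module so its projection
        # weights are reused while the attention kernel itself is replaced.
        layers[i].self_attention = make_hyper_attention(original_attn)
    return model

# Sweeping num_patched from 0 to the total layer count, then measuring perplexity
# and speed, reproduces the style of experiment described above.
```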

The researchers used LongBench, a collection of long-context benchmark datasets, which contains 6 different task types: single- and multi-document question answering, summarization, few-shot learning, synthetic tasks, and code completion. They selected the subsets whose encoded sequence length exceeds 32,768 tokens and truncated any sequence longer than 32,768. They then computed the perplexity of each model, i.e., the loss on predicting the next token. To highlight scalability to long sequences, they also calculated the total speedup across all attention layers, whether executed by HyperAttention or FlashAttention.

The results are shown in Figure 3 below. Even after being monkey patched with HyperAttention, chatglm2-6b-32k still exhibits reasonable perplexity. For example, after 20 layers are replaced, the perplexity increases by approximately 1 and continues to rise slowly until 24 layers are replaced, while the runtime of the attention layers improves by approximately 50%. If all layers are replaced, the perplexity rises to 12 and the model runs 2.3x faster. The phi-1.5 model shows a similar pattern, although its perplexity increases linearly with the number of HyperAttention layers.

[Figure 3: perplexity and speedup of chatglm2-6b-32k and phi-1.5 as a function of the number of attention layers replaced by HyperAttention.]

In addition, the researchers evaluated the monkey patched chatglm2-6b-32k on the LongBench datasets and computed evaluation scores for each task: single- and multi-document question answering, summarization, few-shot learning, synthetic tasks, and code completion. The results are shown in Table 1 below.

While replacing attention with HyperAttention generally incurs a performance penalty, its impact varies by task. For example, summarization and code completion are the most robust relative to the other tasks.

[Table 1: LongBench evaluation scores of the monkey patched chatglm2-6b-32k across tasks and numbers of replaced layers.]

Notably, when half of the attention layers (i.e., 14 layers) are patched, the performance drop on most tasks does not exceed 13%. On the summarization task in particular, performance remains almost unchanged, indicating that this task is the most robust to partial modification of the attention mechanism. At n = 32k, the attention layers compute 1.5x faster.

Single self-attention layer

The researchers further explored the speedup of HyperAttention as the sequence length varies from 4,096 to 131,072. They measured the wall-clock time of forward and forward+backward passes when computed with FlashAttention or accelerated by HyperAttention, with and without a causal mask. All inputs Q, K, and V have the same length, the head dimension is fixed at d = 64, and the number of attention heads is 12.
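A minimal sketch of how such a wall-clock measurement can be set up with PyTorch CUDA events is shown below; attn_fn stands in for either attention kernel, and the warm-up and iteration counts are illustrative choices.

```python
import torch

def time_forward_backward(attn_fn, Q, K, V, iters=10, warmup=3):
    """Average wall-clock time (ms) of one forward + backward pass of attn_fn.
    Q, K, V are assumed to be CUDA tensors with requires_grad=True."""
    for _ in range(warmup):                       # warm up kernels and caches
        attn_fn(Q, K, V).sum().backward()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        attn_fn(Q, K, V).sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```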

The same HyperAttention parameters as before are used. As shown in Figure 4, HyperAttention is up to 54x faster without a causal mask and 5.4x faster with one. Although causal and non-causal masking have the same time complexity, the actual algorithm for causal masking (Algorithm 4) involves additional operations, such as partitioning Q, K, and V and merging the attention outputs, which increases the actual runtime. The speedup grows as the sequence length n increases.

The researchers believe these results are applicable not only to inference but also to training or fine-tuning LLMs to handle longer sequences, which opens up new possibilities for scaling self-attention.

