Apple lets large models learn to be lazy: spit out the first token faster and maintain accuracy
Be lazy to work better.
Llama 3.1 has just been released. Have you tried it yet? Even with a recent top-of-the-line PC, you may still experience significant lag when running the smallest 8B version. To improve inference efficiency, researchers have come up with a variety of methods, but many of them sacrifice some accuracy.
Recently, a research team from Apple and Meta AI proposed a new method that speeds up the pre-filling stage of Llama 2 by more than 2x without a significant drop in accuracy. This may provide some inspiration for accelerating Llama 3.1. They call this approach LazyLLM, or Lazy Large Language Model.
Paper title: LazyLLM: Dynamic Token Pruning for Efficient Long Context LLM Inference
Paper address: https://arxiv.org/abs/2407.14057
So how do they make an LLM lazy? To understand their approach, we first need to know what the standard prompt-based LLM inference process looks like. Briefly, the process is divided into two stages, pre-filling and decoding, as shown in Figure 1.
In the pre-filling stage, the model computes and saves the KV cache of each token in the prompt and predicts the first token. The time spent in the pre-filling stage is called the "Time to First Token" (TTFT).
The pre-filling stage is followed by the decoding stage. In this stage, the model reuses the cached KV to iteratively decode the next token until a stopping criterion is met.
In the pre-filling stage, every Transformer layer uses all tokens in the prompt. TTFT can be slow when the prompt is long, because current state-of-the-art Transformer-based LLMs are both deep and wide, and the cost of computing attention grows quadratically with the number of tokens in the prompt. For example, Llama 2 (7B version) stacks 32 Transformer layers and has a model dimension of 4096. In this case, TTFT takes 21 times the wall-clock time of each subsequent decoding step and accounts for approximately 23% of the total generation time on the LongBench benchmark.
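To make the quadratic growth concrete, here is a rough back-of-the-envelope sketch in Python. It counts only the entries of the attention score matrices (one n x n matrix per head per layer) and ignores everything else, such as the feed-forward layers and causal masking. The function name and the default head count of 32 are assumptions made for illustration, not figures from the article.

```python
def attention_score_entries(num_tokens: int, num_layers: int = 32, num_heads: int = 32) -> int:
    """Count attention-score entries computed while pre-filling a prompt.

    Each head in each layer builds a num_tokens x num_tokens score matrix.
    Causal masking and all non-attention work are ignored in this estimate.
    """
    return num_layers * num_heads * num_tokens * num_tokens


full = attention_score_entries(4096)
half = attention_score_entries(2048)

# Halving the number of prompt tokens cuts this term to a quarter.
print(full / half)  # 4.0
```

This is why dropping prompt tokens before they reach the attention computation can pay off more than linearly for long prompts.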
Therefore, to make LLM inference efficient, optimizing TTFT is a very critical step.
Although LLM inference optimization is an active research area, many methods focus on improving the inference speed of the decoding stage. Researchers have paid little attention to improvements in TTFT. Some compression-based research results can implicitly improve TTFT by reducing the size of the LLM.
Another research direction is to improve TTFT under the static Transformer architecture. For this research direction, a question naturally arises: Are all prompt tokens essential when generating the first token?
Figure 2 shows the results of LLM analysis on the LongBench benchmark.
It can be seen that, for the first generated token, the attention scores over the input tokens are very sparse. This suggests that many tokens in the input prompt are redundant and that removing them would not affect the prediction of the next token. This observation is the basis for the team's proposed LazyLLM.
LazyLLM's advantages include broad applicability, no need for training, and good results. Figure 3 compares standard LLM with LazyLLM.
LazyLLM
Figure 4 shows the overall framework of LazyLLM.
Starting from the complete context, LazyLLM gradually prunes tokens, so that the amount of computation shrinks toward the later layers of the model. Note that LazyLLM allows the model to select different subsets of tokens at different generation steps, even if some of them were pruned in previous steps. Compared with static pruning (which prunes all tokens at once), dynamic pruning optimizes the prediction of the next token at each generation step, which helps maintain the model's performance.
Progressive token pruning
Some previous studies have successfully used token pruning to optimize LLM inference. However, these methods need to accumulate the complete attention maps from predicting the first few tokens in order to analyze the importance of prompt tokens before pruning begins. Therefore, they are not suitable for reducing TTFT, because they still need to compute the full KV cache during the pre-filling stage.
In comparison, LazyLLM is "very lazy": starting from the first iteration of inference (the pre-filling step), it computes only the tokens that are important for predicting the next token.
In the first iteration, a key problem is determining the importance of each token. Inspired by previous research showing that token hidden states evolve as they pass through Transformer layers, the team's solution is to prune tokens layer by layer at each generation step. Specifically, they use the attention map of each layer to determine the importance of each input token to the token being predicted.
After computing a token's confidence score, another difficult problem is choosing the threshold for pruning.
For different layers and different tasks, this threshold may change as the attention scores change. The team's solution is a top-k percentile selection strategy: if the confidence score of a token is less than the kth percentile of the input tokens, it is pruned. Once a token is pruned, it no longer participates in the computation of any subsequent layer.
In other words, the tokens used by subsequent layers are a subset of the tokens used by previous layers.
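The percentile rule described above can be sketched in a few lines of Python. This is only an illustration under stated assumptions: the function name, the example scores, and the nearest-rank definition of the percentile are all choices made here, not details taken from the paper.

```python
import math


def prune_by_percentile(scores: dict[int, float], k: float) -> list[int]:
    """Return the positions of the tokens to keep, in prompt order.

    scores maps token position -> importance score (here a made-up
    attention weight from the token being predicted to that position).
    Tokens scoring below the k-th percentile (k in [0, 100]) are pruned.
    """
    ordered = sorted(scores.values())
    # Nearest-rank percentile: the smallest score such that at least
    # k percent of all scores are less than or equal to it.
    rank = max(1, math.ceil(k / 100 * len(ordered)))
    threshold = ordered[rank - 1]
    return [pos for pos in sorted(scores) if scores[pos] >= threshold]


# Hypothetical scores for a six-token prompt at one layer.
scores = {0: 0.30, 1: 0.02, 2: 0.05, 3: 0.01, 4: 0.40, 5: 0.22}
print(prune_by_percentile(scores, k=50))  # [0, 2, 4, 5]
```

Because the threshold is derived from the scores themselves, it adapts automatically when the score distribution differs between layers or tasks.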
Later experiments show that when the position of the pruning layer and the number of pruned tokens are different, the performance will also change. Specifically, for the same Transformer layer, as more and more tokens are removed by pruning, the performance of the model will gradually decrease.
They also found that compared to pruning in early layers, better performance will be obtained when pruning in later layers, which shows that later layers are less sensitive to token pruning. To better balance speed and accuracy, the team used progressive pruning as shown in Figure 4, retaining more tokens in early layers and then gradually reducing the number of tokens as they flow to later layers.
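As a rough illustration of progressive pruning, the Python sketch below keeps a shrinking fraction of the prompt as tokens move to later layers. The function name, the score tables, and the keep fractions are hypothetical values chosen for the example; the real method derives its scores from each layer's attention map.

```python
def progressive_prune(layer_scores: list[list[float]], keep_fractions: list[float]) -> list[list[int]]:
    """Return, for each layer, the token positions that layer processes.

    layer_scores[l][pos] is a made-up importance score of token pos at layer l.
    keep_fractions[l] is the fraction of the original prompt that is passed
    from layer l to layer l + 1.
    """
    num_tokens = len(layer_scores[0])
    active = list(range(num_tokens))
    used_per_layer = [active]  # the first layer always sees the full prompt
    for scores, fraction in zip(layer_scores, keep_fractions):
        num_keep = max(1, min(len(active), int(fraction * num_tokens)))
        # Rank only the still-active tokens, so a pruned token cannot
        # reappear in a later layer within the same generation step.
        ranked = sorted(active, key=lambda pos: scores[pos], reverse=True)
        active = sorted(ranked[:num_keep])
        used_per_layer.append(active)
    return used_per_layer


layer_scores = [
    [0.20, 0.05, 0.10, 0.02, 0.03, 0.25, 0.05, 0.30],
    [0.30, 0.01, 0.15, 0.02, 0.02, 0.20, 0.05, 0.25],
    [0.10, 0.00, 0.40, 0.00, 0.00, 0.15, 0.00, 0.35],
]
used = progressive_prune(layer_scores, keep_fractions=[0.75, 0.5, 0.25])
for layer, positions in enumerate(used):
    print(layer, positions)
# 0 [0, 1, 2, 3, 4, 5, 6, 7]
# 1 [0, 1, 2, 5, 6, 7]
# 2 [0, 2, 5, 7]
# 3 [2, 7]
```

Each layer's token set is a subset of the previous layer's, and the early layers keep the most tokens, which mirrors the trade-off described above.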
Aux Cache (auxiliary cache)
There is no KV cache yet in the pre-filling stage, and each token is represented by its hidden state. Therefore, progressive token pruning can be achieved by removing the hidden states of pruned tokens. However, extending progressive token pruning to the subsequent decoding steps is not straightforward. The reason is that each decoding step computes attention using the KV cache built in the pre-filling stage. Since LazyLLM performs progressive token pruning in the pre-filling stage, the KV of a token pruned at a certain layer will not appear in the KV cache of the next layer.
As a reminder, the LazyLLM framework allows each generation step to pick a different subset of tokens from the complete input token sequence, regardless of whether they were pruned in previous steps. For example, in a subsequent decoding step, pruned tokens that do not exist in the KV cache may be reselected for the attention computation. In this case, the model cannot retrieve the KV cache for these tokens.
An intuitive solution is to pass these tokens through the Transformer again from the beginning. However, this results in repeated computation for the same token and ultimately slows down overall generation.
To solve this problem, the team introduced another cache in addition to the original KV cache: Aux Cache (auxiliary cache).
If the KV of a pruned token (T4 and T7 in Figure 4) does not appear in the KV cache of subsequent layers, its hidden state is saved in the Aux Cache for retrieval in later iterations.
As shown in Figure 4, at each decoding step, each Transformer layer first retrieves the KV cache of past tokens (if it exists). For tokens that are not in the KV cache, the hidden state is retrieved directly from the Aux Cache of the previous layer, without having to go through the previous layers again. The Aux Cache ensures that each token is computed at most once in each Transformer layer, and that LazyLLM, even in its worst case, is no slower than the standard LLM.
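The bookkeeping behind the Aux Cache can be illustrated with a toy Python sketch. Everything here is an assumption made for illustration: the class name, the "hidden + 1.0" stand-in for a real Transformer layer, and the hand-picked token selections. The sketch tracks only prompt tokens and assumes that, within one step, each layer's selection is a subset of the previous layer's.

```python
from collections import Counter


class LazyCaches:
    """Toy bookkeeping for a per-layer KV cache plus an Aux Cache."""

    def __init__(self, num_layers: int):
        # kv[l]: positions whose KV is already cached at layer l.
        self.kv = [set() for _ in range(num_layers)]
        # aux[l][pos]: output of layer l for a token that layer l + 1
        # has not processed yet (that is, a token pruned after layer l).
        self.aux = [dict() for _ in range(num_layers)]
        self.computed = Counter()  # (layer, pos) -> times computed

    def run_step(self, embeddings: list[float], selected_per_layer: list[set[int]]) -> None:
        for layer, selected in enumerate(selected_per_layer):
            for pos in sorted(selected):
                if pos in self.kv[layer]:
                    continue  # KV already cached: nothing to recompute
                # Layer 0 reads the embedding; deeper layers read the hidden
                # state parked in the previous layer's Aux Cache.
                hidden = embeddings[pos] if layer == 0 else self.aux[layer - 1].pop(pos)
                self.kv[layer].add(pos)
                self.computed[(layer, pos)] += 1
                self.aux[layer][pos] = hidden + 1.0  # stand-in for the layer


caches = LazyCaches(num_layers=3)
embeddings = [0.0, 1.0, 2.0, 3.0]
# Pre-filling: token 1 is pruned after layer 0, token 2 after layer 1.
caches.run_step(embeddings, [{0, 1, 2, 3}, {0, 2, 3}, {0, 3}])
# A later decoding step revives token 1 in the deeper layers.
caches.run_step(embeddings, [{0, 1, 2, 3}, {0, 1, 3}, {1, 3}])

print(max(caches.computed.values()))  # 1
print(caches.aux[1])  # {2: 4.0}
```

When token 1 is revived, layers 1 and 2 pick up its saved hidden state rather than rerunning layer 0, so no (layer, token) pair is ever computed twice.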
Experiments
The team tested this "lazy" new approach on two large language models: Llama 2 7B and XGen 7B. The standard LLM used for comparison is the same publicly released pre-trained checkpoint model without any additional training.
The experimental benchmark is LongBench, a multi-task benchmark for long-context understanding. LongBench contains 16 datasets covering 6 tasks: single-document Q&A, multi-document Q&A, summarization, few-shot learning, synthetic tasks, and code completion.
Each method is evaluated on its effectiveness and efficiency, measured as the trade-off between TTFT speedup and accuracy.
Results
Table 1 gives the TTFT speedup and accuracy results for LazyLLM, standard LLM, and other baseline methods.
In this table, baseline refers to standard LLM inference. Random token drop prunes tokens at random. Static token pruning performs one-shot pruning of the input tokens in the pre-filling stage, based on the attention scores of the first few Transformer layers. Prompt Compression uses an LLM to remove redundancy from the input context.
As Table 1 shows, LazyLLM delivers the best TTFT speedup across the board, while the drop in accuracy is basically negligible. Note that compressing prompts with an LLM is itself computationally intensive. Therefore, even though inference on the compressed prompt is faster, the actual TTFT of Prompt Compression is longer than that of the standard LLM.
Impact on the overall generation speed
To evaluate the impact of the new method on the overall generation speed, the team analyzed the percentage of prompt tokens used in the calculation and the generation acceleration, see Table 2.
It can be seen that the proportion of tokens used by LazyLLM is always less than 100%. In other words, LazyLLM has not used every token in the prompt by the end of generation, even though in theory the model could use all of them. This provides additional acceleration of the overall generation process across different tasks.
Discard rate of different layers
The team also analyzed the impact of the position of the pruning layer and the number of pruned tokens. The results are shown in Figure 6.
It can be seen that when pruning is performed at the same Transformer layer, the fewer tokens left, the worse the performance of the model. This is also consistent with our intuitive understanding. In addition, compared to performing pruning in earlier Transformer layers, pruning in later layers will result in better performance, which shows that later layers are less sensitive to token pruning.
These observations support the effectiveness of progressive token pruning.
Progressive KV growth
Finally, the team also tried to understand the internals of the model under the token pruning logic. Specifically, they wanted to know the cumulative proportion of prompt tokens that are used and the corresponding proportion that is never used. This "cumulative token usage" can be equivalently defined as the KV cache size at each step. Figure 7 shows the cumulative prompt token usage for each stage of LazyLLM.
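The "cumulative token usage" defined above can be sketched as the running union of the tokens selected so far, divided by the prompt length. The function name and the example selections below are hypothetical and are not taken from the paper's measurements.

```python
def cumulative_token_usage(selected_per_step: list[set[int]], prompt_length: int) -> list[float]:
    """Fraction of prompt tokens used so far, after each generation step."""
    seen: set[int] = set()
    usage = []
    for selected in selected_per_step:
        seen |= selected  # a token counts once, however often it is reselected
        usage.append(len(seen) / prompt_length)
    return usage


# Hypothetical selections over three generation steps for an 8-token prompt.
steps = [{0, 4, 7}, {0, 2, 7}, {0, 4, 7}]
print(cumulative_token_usage(steps, prompt_length=8))  # [0.375, 0.5, 0.5]
```

The curve can only stay flat or grow, and a plateau below 1.0 means some prompt tokens were never selected at all.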
The results in Figure 7 support the hypothesis that many tokens will never be selected by the model (even though in theory the model can use all tokens in the prompt).
Considering that the model still maintains its accuracy on the tasks, the conclusion is that the model can effectively discard tokens without affecting output quality.