New work from the author of Mamba: Distilling Llama3 into a hybrid linear RNN
The key to the Transformer's success in deep learning is the attention mechanism, which lets Transformer-based models focus on the relevant parts of the input sequence and thereby understand context better. Its drawback is computational cost: the cost of attention grows quadratically with input length, which makes it hard for Transformers to process very long texts.
Some time ago, Mamba changed this picture: its cost scales linearly as the context length grows. With Mamba, state space models (SSMs) can match or even surpass Transformers at small to medium scale while keeping linear scaling in sequence length, which gives Mamba favorable deployment characteristics.
Simply put, Mamba introduces a simple but effective selection mechanism that reparameterizes the SSM as a function of the input, allowing the model to retain necessary information indefinitely while filtering out irrelevant data.
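To make "reparameterize the SSM according to the input" concrete, here is a minimal one-dimensional sketch, not Mamba's actual implementation. It assumes a scalar state, a fixed negative decay `a`, and a hypothetical step size `delta_t = softplus(w_delta * x_t + b_delta)`; because `delta_t` depends on the input, the model decides per token whether to overwrite its state (large `delta_t`) or keep it (small `delta_t`, which also barely writes the new input).

```python
import math


def softplus(z: float) -> float:
    # Numerically stable softplus: log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_scan(xs, w_delta, b_delta, a=-1.0, b_in=1.0, c_out=1.0):
    """Toy 1-D selective SSM (all parameter names are hypothetical).

    The step size delta_t is computed from x_t, so the decay
    a_bar_t = exp(delta_t * a) changes from token to token.
    """
    h = 0.0
    ys = []
    for x in xs:
        delta = softplus(w_delta * x + b_delta)  # input-dependent step size
        a_bar = math.exp(delta * a)              # zero-order-hold decay, a < 0
        b_bar = delta * b_in                     # simplified (Euler) input scale
        h = a_bar * h + b_bar * x                # recurrent state update
        ys.append(c_out * h)                     # linear readout
    return ys


# With w_delta=20, b_delta=-10, a nonzero input gets a large step (written in),
# while zero inputs get a tiny step, so the stored value is retained.
print(selective_scan([1.0, 0.0, 0.0], w_delta=20.0, b_delta=-10.0))
```

The point of the sketch is only the selection effect; real Mamba layers use vector states, learned projections, and a hardware-efficient parallel scan.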
Recently, a paper titled "The Mamba in the Llama: Distilling and Accelerating Hybrid Models" showed that, by reusing the weights of the attention layers, large Transformers can be distilled into large hybrid linear RNNs with minimal extra computation while retaining most of their generation quality.
The resulting hybrid model, which keeps a quarter of the attention layers, achieves performance comparable to the original Transformer on chat benchmarks, and outperforms open-source hybrid Mamba models trained from scratch on trillions of tokens on both chat and general benchmarks. Additionally, the study proposes a hardware-aware speculative decoding algorithm that speeds up inference for Mamba and hybrid models.
Paper address: https://arxiv.org/pdf/2408.15237
The best-performing model in this study, distilled from Llama3-8B-Instruct, achieves a length-controlled win rate of 29.61 on AlpacaEval 2 against GPT-4 and a score of 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.
Methods
Knowledge distillation (KD) is a model compression technique that transfers knowledge from a large model (the teacher) to a smaller model (the student) by training the student network to imitate the teacher's behavior. Here, the goal is to distill a Transformer into a model whose performance is comparable to the original language model.
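The article does not give the paper's exact distillation objective. A common word-level choice, assumed here only for illustration, is the KL divergence between the teacher's and the student's next-token distributions; the `temperature` knob is a conventional addition, not something the article specifies.

```python
import math


def log_softmax(logits, temperature=1.0):
    # Log-probabilities computed with the max-subtraction trick for stability.
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def kd_loss(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) for a single next-token distribution.

    A higher temperature softens both distributions before comparing them.
    """
    log_p = log_softmax(teacher_logits, temperature)  # teacher = target
    log_q = log_softmax(student_logits, temperature)  # student = prediction
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
```

In practice this per-position loss would be averaged over all token positions of the training data; it is zero only when the student reproduces the teacher's distribution.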
This study proposes a multi-stage distillation method that combines progressive distillation, supervised fine-tuning (SFT), and direct preference optimization (DPO). Compared with standard distillation, this method achieves better perplexity and downstream evaluation results.
The study assumes that most of the Transformer's knowledge is retained in the MLP layers transferred from the original model, and therefore focuses on the fine-tuning and alignment steps of the distilled LLM. During this phase, the MLP layers stay frozen and only the Mamba layers are trained.
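The freeze-the-MLP, train-the-Mamba split can be expressed as a simple partition of parameter names. The naming scheme below (`.mlp.` for copied MLP weights, `.mamba.` for the new linear-RNN layers) is hypothetical, and freezing every other parameter is an assumption of this sketch, not a detail from the article.

```python
def split_trainable(param_names):
    """Partition parameter names into trainable and frozen groups.

    Hypothetical naming: '.mamba.' marks new linear-RNN weights (trained);
    '.mlp.' marks MLP weights copied from the teacher (frozen). All other
    parameters are frozen here purely as an assumption.
    """
    trainable, frozen = [], []
    for name in param_names:
        (trainable if ".mamba." in name else frozen).append(name)
    return trainable, frozen


names = [
    "layers.0.mlp.up_proj.weight",
    "layers.1.mamba.in_proj.weight",
    "embed_tokens.weight",
]
trainable, frozen = split_trainable(names)
print(trainable)  # only the Mamba projection
```

In PyTorch, for example, the frozen group would then have `requires_grad` set to `False` so the optimizer only updates the Mamba weights.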
The study argues that there is a natural connection between linear RNNs and the attention mechanism: the attention formula can be linearized by removing the softmax.
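For reference, a standard way to write causal softmax attention for one head at position $t$, and its linearized counterpart, is shown below. The notation (query and key vectors $q_t, k_s \in \mathbb{R}^D$, value vectors $v_s$, state $H_t$) is assumed here rather than taken from the article.

```latex
y_t = \sum_{s=1}^{t}
      \frac{\exp\!\left(q_t^\top k_s / \sqrt{D}\right)}
           {\sum_{r=1}^{t} \exp\!\left(q_t^\top k_r / \sqrt{D}\right)}\, v_s
\qquad\longrightarrow\qquad
y_t = \frac{1}{\sqrt{D}} \sum_{s=1}^{t} \left(q_t^\top k_s\right) v_s
    = \frac{1}{\sqrt{D}}\, H_t\, q_t ,
\qquad
H_t = H_{t-1} + v_t k_t^\top ,\quad H_0 = 0 .
```

The last step uses $(q_t^\top k_s)\, v_s = (v_s k_s^\top)\, q_t$. Because $H_t$ is updated once per token, the linearized form is a linear RNN with a matrix-valued state, which is why its cost grows only linearly with sequence length.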
Linearizing attention, however, degrades model capability. To design an efficient distilled linear RNN, the study stays as close as possible to the original Transformer parameterization while efficiently extending the capacity of the linear RNN. Rather than trying to make the new model reproduce the exact original attention function, it uses the linearized form as a starting point for distillation.
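A small numerical check helps show why the linearized form is a useful starting point: the same output can be computed either as a quadratic sum over all earlier positions or as a recurrence with constant work per token. This toy sketch uses plain Python lists and made-up vectors; it covers only the softmax-free form, not the paper's final layer.

```python
import math


def linear_attention_parallel(qs, ks, vs):
    """y_t = (1/sqrt(D)) * sum_{s<=t} (q_t . k_s) v_s, computed directly (O(T^2))."""
    d = len(qs[0])
    out = []
    for t in range(len(qs)):
        y = [0.0] * len(vs[0])
        for s in range(t + 1):
            score = sum(a * b for a, b in zip(qs[t], ks[s])) / math.sqrt(d)
            y = [yi + score * vi for yi, vi in zip(y, vs[s])]
        out.append(y)
    return out


def linear_attention_recurrent(qs, ks, vs):
    """Same output via the state H_t = H_{t-1} + v_t k_t^T (O(T) steps)."""
    d, dv = len(qs[0]), len(vs[0])
    h = [[0.0] * d for _ in range(dv)]  # H is a dv x d matrix
    out = []
    for q, k, v in zip(qs, ks, vs):
        for i in range(dv):
            for j in range(d):
                h[i][j] += v[i] * k[j]  # rank-1 state update
        out.append([sum(h[i][j] * q[j] for j in range(d)) / math.sqrt(d)
                    for i in range(dv)])
    return out
```

Both functions return the same outputs up to floating-point rounding; only the recurrent one avoids revisiting earlier tokens.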
As shown in Algorithm 1, the study feeds the standard Q, K, and V heads from the attention mechanism directly into the Mamba discretization and then applies the resulting linear RNN. This can be viewed as a coarse initialization from linear attention that still allows the model to learn richer interactions through an expanded hidden state.
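The idea of routing attention projections into a discretized linear RNN can be sketched for a single head as follows. This is a simplified illustration, not the paper's Algorithm 1: it assumes a scalar step size per token computed by a hypothetical linear map (`w_delta`, `b_delta`) of the value vector, a scalar decay `a`, and no state expansion. Values play the role of the SSM input `x`, keys the role of `B`, and queries the role of `C`.

```python
import math


def softplus(z):
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def mamba_from_attention(qs, ks, vs, a=-0.5, w_delta=None, b_delta=0.0):
    """Toy single-head sketch: reuse Q/K/V inside a discretized linear RNN.

    x_t <- v_t, B_t <- k_t, C_t <- q_t (assumed mapping).
    w_delta/b_delta are hypothetical parameters for the step size delta_t.
    """
    d, dv = len(ks[0]), len(vs[0])
    w_delta = w_delta if w_delta is not None else [0.0] * dv
    h = [[0.0] * d for _ in range(dv)]  # dv x d state, as in linear attention
    ys = []
    for q, k, v in zip(qs, ks, vs):
        delta = softplus(sum(w * x for w, x in zip(w_delta, v)) + b_delta)
        a_bar = math.exp(delta * a)       # per-token decay (a <= 0)
        b_bar = [delta * kj for kj in k]  # discretized input projection
        for i in range(dv):
            for j in range(d):
                h[i][j] = a_bar * h[i][j] + v[i] * b_bar[j]
        ys.append([sum(h[i][j] * q[j] for j in range(d)) for i in range(dv)])
    return ys
```

With `a = 0` and `delta_t = 1` (for example `b_delta = log(e - 1)` and zero `w_delta`), the decay disappears and the layer reduces to unnormalized linear attention, which is the sense in which linear attention serves as a coarse initialization; a negative `a` lets the learned step size control forgetting.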
The study directly replaces the Transformer attention heads with fine-tuned linear RNN layers, keeping the Transformer MLP layers unchanged and untrained. This approach also has to handle other components, such as grouped query attention (GQA), which shares keys and values across heads. The research team notes that this architecture, unlike the ones used in many Mamba systems, lets this initialization replace any attention block with a linear RNN block.
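The article only says that grouped query attention must be handled. One straightforward option, assumed here, is to duplicate each shared key/value projection so that every query head gets its own copy, which the corresponding linear RNN head can then fine-tune independently. The function name and the contiguous head grouping are illustrative assumptions.

```python
import copy


def expand_kv_heads(kv_heads, num_query_heads):
    """Give every query head its own copy of its shared key/value head.

    kv_heads: per-head weights (any objects), length G.
    Assumes query heads are grouped contiguously, so query head h uses
    KV head h // (num_query_heads // G).
    """
    groups = len(kv_heads)
    if num_query_heads % groups != 0:
        raise ValueError("num_query_heads must be a multiple of the KV head count")
    repeat = num_query_heads // groups
    # deepcopy so each expanded head can diverge during fine-tuning
    return [copy.deepcopy(kv_heads[h // repeat]) for h in range(num_query_heads)]
```

For example, Llama-3 8B uses 32 query heads and 8 key/value heads, so under this scheme each key/value head would be copied four times.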
The study also proposes a new speculative decoding algorithm for linear RNNs that uses hardware-aware multi-step generation.
Algorithm 2 and Figure 2 show the complete algorithm. The approach keeps only one RNN hidden state in the cache for verification and advances it lazily, depending on the success of the multi-step kernel. Because the distilled model still contains Transformer layers, the study also extends speculative decoding to hybrid attention/RNN architectures. In this setup, the RNN layers perform verification according to Algorithm 2, while the Transformer layers simply perform parallel verification.
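The single-cached-state idea can be illustrated with a sequential toy sketch of greedy speculative verification for a recurrent target model. `target_step` is a hypothetical stand-in model, and greedy acceptance is assumed; this does not reproduce the paper's fused multi-step GPU kernel or its handling of attention layers. The cache holds one state; verification runs ahead from it without overwriting it, and on a rejection the state is recomputed over only the accepted prefix.

```python
def target_step(state, token):
    """Hypothetical toy RNN: returns (new_state, greedy next token)."""
    new_state = (state * 31 + token) % 1009
    return new_state, new_state % 10


def multi_step(state, tokens):
    """Run the target over tokens; return the final state and per-step predictions."""
    preds = []
    for tok in tokens:
        state, pred = target_step(state, tok)
        preds.append(pred)
    return state, preds


def speculative_step(cached_state, last_token, draft):
    """One round of greedy verification that keeps a single cached RNN state.

    cached_state covers every token before last_token.
    Returns (new cached state, newly generated tokens).
    """
    end_state, preds = multi_step(cached_state, [last_token] + draft)
    n_ok = 0
    while n_ok < len(draft) and draft[n_ok] == preds[n_ok]:
        n_ok += 1
    new_tokens = draft[:n_ok] + [preds[n_ok]]  # accepted drafts + one target token
    if n_ok == len(draft):
        new_state = end_state  # all drafts accepted: reuse the run-ahead state
    else:
        # Rejection: advance the cache only over the accepted prefix.
        new_state, _ = multi_step(cached_state, [last_token] + draft[:n_ok])
    # new_tokens[-1] has not been consumed yet; it becomes the next last_token.
    return new_state, new_tokens
```

The output always matches plain greedy decoding with the target; drafts only change how many tokens are produced per verification round.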
To verify the effectiveness of this method, the study used Mamba 7B and Mamba 2.8B as target models for speculative decoding. The results are shown in Table 1.
Figure 3 shows the performance characteristics of the multi-step kernel itself.
Speedup on the H100 GPU. The proposed algorithm performs well on Ampere GPUs, as shown in Table 1. On the H100, however, it faces major challenges, mainly because GEMM operations are so fast there that the overhead of caching and recomputation becomes much more noticeable. Indeed, a naive implementation of the algorithm (using several separate kernel calls) achieved a considerable speedup on the 3090 GPU but no speedup at all on the H100.
Experiments and results
The study experiments with two LLM chat models: Zephyr-7B, which is fine-tuned from Mistral 7B, and Llama-3 Instruct 8B. For the linear RNN models, it uses hybrid versions of Mamba and Mamba2 in which 50%, 25%, 12.5%, and 0% of the layers are attention layers; the 0% variant is called a pure Mamba model. Mamba2 is a variant of the Mamba architecture designed primarily for recent GPU architectures.
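The article does not say which layers stay attention layers in each hybrid. As an illustrative assumption only, evenly spaced attention layers interleaved with Mamba layers can be chosen like this; for a 32-layer model, the 12.5%, 25%, and 50% selections come out nested, which would fit a progressive schedule that removes attention layers step by step.

```python
def attention_layer_indices(num_layers, attention_fraction):
    """Pick evenly spaced layers to keep as attention (assumed scheme).

    Returns sorted layer indices; all other layers would become Mamba layers.
    """
    n_attn = round(num_layers * attention_fraction)
    if n_attn == 0:
        return []  # 0% attention -> pure Mamba model
    stride = num_layers / n_attn
    return [int(i * stride) for i in range(n_attn)]


for frac in (0.5, 0.25, 0.125, 0.0):
    print(frac, attention_layer_indices(32, frac))
```

Both Mistral 7B and Llama-3 8B have 32 layers, so 25% attention corresponds to 8 attention layers under this scheme.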
Evaluation on the Chat Benchmark
Table 2 shows the models' performance on chat benchmarks. The main comparison is against large Transformer models. The results show:
The distilled hybrid Mamba model (50%) achieves scores similar to the teacher model on MT-Bench and is slightly better than the teacher on AlpacaEval in both LC win rate and overall win rate.
The distilled hybrid Mamba models (25% and 12.5%) perform slightly worse than the teacher on MT-Bench, but on AlpacaEval they still outperform some large Transformers that have more parameters.
The accuracy of the distilled pure (0%) Mamba model drops significantly.
Notably, the distilled hybrid models outperform Falcon Mamba, which was trained from scratch on more than 5T tokens.
General benchmark evaluation
Zero-shot evaluation. Table 3 shows the zero-shot performance on the LM Eval benchmark of Mamba and Mamba2 models distilled from different teacher models. The hybrid Mamba-Llama3 and Mamba2-Llama3 models distilled from Llama-3 Instruct 8B perform better than the open-source TRI Mamba and Nvidia Mamba models trained from scratch.
Benchmark evaluation. Table 4 shows that the distilled hybrid models match the best open-source linear RNN models on the Open LLM Leaderboard, while outperforming the corresponding open-source instruction-tuned models on GSM8K and CRUX.
Hybrid Speculative Decoding
For the 50% and 25% distilled models, the study achieves a speedup of over 1.8x on Zephyr-Hybrid compared with the non-speculative baseline.
The experiments also show that the 4-layer draft model trained in the study achieves a higher acceptance rate, but because the draft model is larger, its additional overhead grows as well. Future work will focus on making these draft models smaller.
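The trade-off between acceptance rate and draft overhead can be quantified with the simplified analysis commonly used for speculative decoding, which assumes every draft token is accepted independently with probability `alpha` and that one draft step costs `draft_cost_ratio` of a target step. The numbers in the usage lines are hypothetical, not measurements from the paper.

```python
def expected_tokens_per_step(alpha, k):
    """Expected tokens generated per target verification with k draft tokens,
    assuming independent acceptance with probability alpha."""
    if alpha == 1.0:
        return k + 1.0
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def expected_speedup(alpha, k, draft_cost_ratio):
    """Wall-clock speedup over plain decoding: tokens per round divided by the
    round's cost (k draft steps plus one target step, in target-step units)."""
    return expected_tokens_per_step(alpha, k) / (k * draft_cost_ratio + 1)


# Hypothetical: a cheap draft vs. a larger draft with a higher acceptance rate.
small = expected_speedup(alpha=0.6, k=3, draft_cost_ratio=0.05)  # ≈ 1.89
large = expected_speedup(alpha=0.75, k=3, draft_cost_ratio=0.2)  # ≈ 1.71
```

Under these assumptions, a higher acceptance rate does not pay off if the draft becomes too expensive, which is the motivation for shrinking the draft models.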
Comparison with other distillation methods: Table 6 (left) compares the perplexity of different model variants. The study distilled for one epoch using UltraChat as the seed prompts and compared perplexities; removing more attention layers makes the results worse. The study also compared its distillation method with previous baselines and found that the new method shows less degradation, whereas the Distill Hyena model, which was trained on the WikiText dataset with a much smaller model, shows a larger degradation in perplexity.
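Since Table 6 compares models by perplexity, a short reminder of the metric may help: perplexity is the exponential of the average per-token negative log-likelihood, so lower is better. This helper is a generic sketch, not code from the paper.

```python
import math


def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood (natural log) over tokens."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)


# A model that assigns probability 0.25 to every observed token
# has perplexity 4, i.e. it is as uncertain as a uniform 4-way choice.
print(perplexity([math.log(0.25)] * 4))
```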
Table 6 (right) shows that using SFT or DPO alone yields little improvement, while combining SFT and DPO gives the best scores.
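For readers unfamiliar with the DPO stage, the standard direct preference optimization loss for one preference pair is sketched below. Inputs are sequence log-probabilities of the chosen and rejected responses under the trained policy and a frozen reference model; `beta = 0.1` is an assumed default, not a value reported in the article.

```python
import math


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one preference pair.

    Each argument is log pi(y|x) summed over the response tokens.
    loss = -log sigmoid(beta * [(pc - rc) - (pr - rr)])
    """
    margin = beta * ((policy_chosen - ref_chosen) - (policy_rejected - ref_rejected))
    # -log(sigmoid(margin)) == softplus(-margin), written in a stable form
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
```

When the policy equals the reference model the loss is log 2; it falls as the policy raises the chosen response's likelihood relative to the rejected one.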
Table 7 presents ablation studies across several models. Table 7 (left) shows distillation results with various initializations, and Table 7 (right) shows the smaller gains from progressive distillation and from interleaving attention layers with Mamba layers.
Table 8 compares hybrid models using two different initialization methods; the results confirm that initializing from the attention weights is crucial.
Table 9 compares the performance of models with and without Mamba blocks. Models with Mamba blocks perform significantly better than models without them. This confirms that adding the Mamba layers is crucial and that the performance improvement is not solely due to the remaining attention layers.
Interested readers can read the original text of the paper to learn more about the research content.