
Compress 26 tokens into 1 new method to save space in the ChatGPT input box

PHPz
2023-05-09 14:10:37

Before getting into the details, first consider the prompt given to a Transformer language model (LM) such as ChatGPT.

With millions of users and queries every day, ChatGPT encodes prompts over and over using self-attention, whose time and memory complexity grow quadratically with input length. Caching the Transformer activations of a prompt avoids part of this recomputation, but the strategy still incurs significant memory and storage costs as the number of cached prompts grows. At scale, even a small reduction in prompt length can yield compute, memory, and storage savings, while also letting users fit more content into the LM's limited context window.

So how can the cost of prompting be reduced? A typical approach is to fine-tune or distill the model so that it behaves like the original model without the prompt, perhaps using parameter-efficient adaptation methods. A fundamental drawback of this approach, however, is that the model must be retrained for every new prompt (shown in the middle of Figure 1 below).

[Figure 1]

In this paper, researchers from Stanford University propose gisting (bottom of Figure 1 above), which compresses an arbitrary prompt into a smaller set of virtual "gist" tokens, similar to prefix-tuning. However, prefix-tuning requires learning a prefix for each task through gradient descent, whereas gisting takes a meta-learning approach and predicts the gist prefix from the prompt alone, without learning a prefix per task. This amortizes the cost of per-task prefix learning and allows generalization to unseen instructions without additional training.

In addition, since the gist tokens are much shorter than the full prompt, gisting allows prompts to be compressed, cached, and reused for computational efficiency.


Paper address: https://arxiv.org/pdf/2304.08467v1.pdf

The researchers propose a very simple way to train an instruction-following gist model: perform standard instruction fine-tuning, insert gist tokens after the prompt, and modify the attention mask so that tokens after the gist tokens cannot attend to tokens before them. This lets the model learn prompt compression and instruction following at the same time, at no additional training cost.

On decoder-only (LLaMA-7B) and encoder-decoder (FLAN-T5-XXL) LMs, gisting achieves up to 26x prompt compression while maintaining output quality similar to that of the original models. This results in a 40% reduction in FLOPs during inference, a 4.2% latency speedup, and significantly lower storage costs compared with traditional prompt caching.

Gisting

The researchers first describe gisting in the context of instruction fine-tuning. For an instruction-following dataset D = {(t_i, x_i, y_i)}_{i=1}^N, t represents the task encoded as a natural language prompt (e.g. translate this to French), x represents the (optional) input for the task (e.g. The cat), and y represents the desired output (e.g. Le chat). The purpose of instruction fine-tuning is to learn the distribution p_LM(y | t, x) by concatenating t and x and then having a (usually pre-trained) language model autoregressively predict y. At inference time, the model can be prompted with a new task t and input x and decoded to obtain a prediction.

However, this pattern of concatenating t and x has a drawback: Transformer-based LMs have a limited context window, bounded either by the architecture or by computing power. The latter is particularly hard to address because self-attention scales quadratically with input length. Very long prompts, especially those that are reused repeatedly, are therefore computationally inefficient. What options are available to reduce the cost of prompting?
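
To make the quadratic scaling concrete, the following minimal Python sketch counts the query-key pairs that full self-attention scores for a given sequence length. It is only an illustration: the function name is made up, the 20-token input length is a hypothetical value, and the count ignores the parts of a Transformer whose cost grows linearly with length.

```python
def attention_pairs(num_tokens: int) -> int:
    """Query-key pairs scored by full (unmasked) self-attention."""
    return num_tokens * num_tokens


# Hypothetical lengths: a 26-token instruction versus a single gist token,
# each followed by the same 20-token input.
INPUT_LEN = 20
full_prompt_pairs = attention_pairs(26 + INPUT_LEN)
gist_prompt_pairs = attention_pairs(1 + INPUT_LEN)

print("pairs with full prompt:", full_prompt_pairs)
print("pairs with gist token: ", gist_prompt_pairs)
print("ratio:", full_prompt_pairs / gist_prompt_pairs)
```

Because the count grows with the square of the length, removing 25 tokens from a 46-token sequence removes well over half of the attention pairs in this toy setting.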

A simple approach is to fine-tune the LM for a specific task t: given a dataset D^t = {(x_i, y_i)}_{i=1}^{N^t} containing input/output examples only under task t, one can learn a specialized LM p^t_LM(y | x), which is faster because it does not need to condition on t.

Even better, parameter-efficient fine-tuning methods such as prefix/prompt tuning or adapters can achieve the same goal at a much lower cost than full fine-tuning. However, a problem remains: at least a portion of the model weights must be stored for each task and, more importantly, for each task t the corresponding dataset of input/output pairs D^t must be collected and the model retrained.

Gisting is a different approach that amortizes two costs: (1) the inference-time cost of conditioning p_LM on t, and (2) the training-time cost of learning a new p^t_LM for each t. The idea is to learn a compressed version of t, G(t), during fine-tuning, such that inference from p_G(y | G(t), x) is faster than from p_LM(y | t, x).

In LM terms, G(t) is a set of "virtual" gist tokens that are fewer in number than the tokens in t but still induce similar behavior in the LM. The Transformer activations on G(t) (e.g. the key and value matrices) can then be cached and reused for computational efficiency. Importantly, the researchers want G to generalize to unseen tasks: given a new task t, the corresponding gist activations G(t) can be predicted and used without any additional training.
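
The cache-and-reuse idea can be sketched in a few lines of Python. This is only an illustration under assumptions: `GistCache` and `compress_fn` are hypothetical names, and the stand-in compressor below returns a placeholder tuple rather than real key/value activations.

```python
from typing import Any, Callable, Dict


class GistCache:
    """Memoize the (hypothetical) gist activations computed for each prompt."""

    def __init__(self, compress_fn: Callable[[str], Any]):
        # compress_fn stands in for running the model over the prompt plus
        # gist tokens and keeping only the activations at the gist positions.
        self._compress = compress_fn
        self._store: Dict[str, Any] = {}
        self.misses = 0

    def get(self, prompt: str) -> Any:
        if prompt not in self._store:
            self.misses += 1  # the expensive compression runs once per prompt
            self._store[prompt] = self._compress(prompt)
        return self._store[prompt]


# Placeholder compressor: a real system would return cached activations.
cache = GistCache(lambda prompt: ("gist-activations-for", prompt))
first = cache.get("translate this to French")
second = cache.get("translate this to French")  # served from the cache
```

The second lookup reuses the stored entry, which is the behavior the paragraph above relies on when a prompt is shared across many requests.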

Learning Gisting through masks

The above describes the general framework of gisting. Next comes an extremely simple way to learn such a model: use the LM itself as the gist predictor G. This not only leverages the knowledge already present in the LM, but also allows gisting to be learned by simply performing standard instruction fine-tuning while modifying the Transformer attention mask to enforce prompt compression. In other words, gisting incurs no training cost beyond that of standard instruction fine-tuning.

Specifically, a special gist token is added to the model vocabulary and embedding matrix, similar to the beginning/end-of-sentence tokens common in such models. Then, for a given (task, input) tuple (t, x), t and x are concatenated with a set of k consecutive gist tokens in between, giving (t, g_1, ..., g_k, x); with the running example and k = 2 this is (translate this to French, g_1, g_2, The cat). This sequence is fed into the model, with the restriction that input tokens after the gist tokens cannot attend to the preceding prompt tokens (but they can attend to the gist tokens). This forces the model to compress the information in the prompt into the gist tokens, since the input x (and the output y) cannot attend to the prompt t.
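
As a rough illustration of the sequence layout described above, the sketch below builds (t, g_1, ..., g_k, x) from already tokenized pieces. The token string `<GIST>`, the function name, and the word-level "tokens" are assumptions made for this example, not details taken from the paper's code.

```python
from typing import List, Tuple

GIST_TOKEN = "<GIST>"  # hypothetical spelling of the special gist token


def build_gist_sequence(
    task_tokens: List[str], input_tokens: List[str], k: int = 1
) -> Tuple[List[str], List[str]]:
    """Return the tokens of (t, g_1..g_k, x) and a segment label per token."""
    tokens = list(task_tokens) + [GIST_TOKEN] * k + list(input_tokens)
    segments = ["t"] * len(task_tokens) + ["g"] * k + ["x"] * len(input_tokens)
    return tokens, segments


tokens, segments = build_gist_sequence(
    ["translate", "this", "to", "French"], ["The", "cat"], k=2
)
print(tokens)
print(segments)
```

The segment labels are not needed by the model itself; they are a convenient way to decide later which positions a token may attend to.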

Figure 2 below shows the required changes. For decoder-only LMs such as GPT-3 or LLaMA, which typically employ autoregressive causal attention masks, one only needs to mask out the lower left corner of the triangle shown in Figure 2a. For an encoder-decoder LM with a bidirectional encoder and an autoregressive decoder, two modifications are required (shown in Figure 2b).
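
For the decoder-only case, the masking rule can be sketched as a boolean matrix in plain Python. This is a minimal sketch, assuming the sequence is laid out as task tokens, then gist tokens, then everything else (input and output); the function and helper names are made up for this example.

```python
from typing import List


def causal_gist_mask(num_task: int, num_gist: int, num_rest: int) -> List[List[bool]]:
    """mask[i][j] is True when position i may attend to position j."""
    n = num_task + num_gist + num_rest
    gist_end = num_task + num_gist  # index of the first token after the gist tokens
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):  # causal: only the current and earlier positions
            # Tokens after the gist tokens may not look back at the task tokens.
            blocked = i >= gist_end and j < num_task
            mask[i][j] = not blocked
    return mask


def render(mask: List[List[bool]]) -> str:
    """Draw the mask with '1' for visible and '.' for masked positions."""
    return "\n".join("".join("1" if v else "." for v in row) for row in mask)


mask = causal_gist_mask(num_task=3, num_gist=1, num_rest=2)
print(render(mask))
```

The rows for the task and gist tokens keep the ordinary causal pattern, while the rows for later tokens lose their view of the task columns, which corresponds to the masked lower-left region described above.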

First, in the encoder, which is normally unmasked, the input tokens x must be prevented from attending to the prompt tokens t. In addition, the prompt t and the gist tokens g_i must be prevented from attending to the input tokens x, otherwise the encoder would learn different gist representations depending on the input. Second, the decoder operates as usual, except that during cross-attention it must be prevented from attending to the prompt tokens t.
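
The encoder-decoder rules can be sketched the same way. The code below is an illustration under assumptions: the encoder sequence is laid out as task tokens, gist tokens, then input tokens; any attention not explicitly forbidden in the text is left allowed; and the function names are hypothetical.

```python
from typing import List


def _segments(num_task: int, num_gist: int, num_input: int) -> List[str]:
    return ["t"] * num_task + ["g"] * num_gist + ["x"] * num_input


def encoder_gist_mask(num_task: int, num_gist: int, num_input: int) -> List[List[bool]]:
    """mask[i][j] is True when encoder position i may attend to position j."""
    seg = _segments(num_task, num_gist, num_input)
    # Task and gist tokens never see the input; input tokens never see the task.
    allowed = {"t": {"t", "g"}, "g": {"t", "g"}, "x": {"g", "x"}}
    return [[key in allowed[query] for key in seg] for query in seg]


def cross_attention_mask(num_task: int, num_gist: int, num_input: int) -> List[bool]:
    """Encoder positions visible to every decoder step: everything but the task."""
    return [s != "t" for s in _segments(num_task, num_gist, num_input)]


enc = encoder_gist_mask(num_task=2, num_gist=1, num_input=2)
cross = cross_attention_mask(num_task=2, num_gist=1, num_input=2)
for row in enc:
    print("".join("1" if v else "." for v in row))
print("cross:", "".join("1" if v else "." for v in cross))
```

Blocking the task and gist rows from the input columns is what keeps the gist representation independent of x, so the same cached gist can be reused with different inputs.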

[Figure 2: gist attention masks for (a) decoder-only and (b) encoder-decoder LMs]

Experimental results

The ROUGE-L and ChatGPT evaluation results of LLaMA-7B and FLAN-T5-XXL for different numbers of gist tokens are shown in Figure 3 below.

[Figure 3: ROUGE-L and ChatGPT evaluation results for different numbers of gist tokens]

The models are generally insensitive to the number k of gist tokens: compressing prompts into a single token does not cause significant performance degradation. In fact, in some cases too many gist tokens hurt performance (e.g. LLaMA-7B with 10 gist tokens), possibly because the increased capacity overfits the training distribution. The researchers therefore report the specific values for the single-token models in Table 1 below, and use a single gist token in the remaining experiments.

[Table 1: results for the single-gist-token models]

On seen instructions, the gist models achieve almost the same ROUGE and ChatGPT performance as their corresponding positive control models, with win rates of 48.6% and 50.8% on LLaMA-7B and FLAN-T5-XXL, respectively. What the researchers are most interested in, however, is generalization to unseen tasks, which is measured on two other datasets.

On unseen prompts from the Alpaca training data set, the gist models generalize well: compared with the control group, they achieve win rates of 49.7% (LLaMA) and 46.2% (FLAN-T5). On the most challenging OOD Human split, the win rates drop slightly, to 45.8% (LLaMA) and 42.5% (FLAN-T5).

The goal of this work is for a gist model to closely mimic the original model, so one might ask how often a gist model is actually indistinguishable from the control. Figure 4 below illustrates this: for seen tasks (but unseen inputs), the gist model is on par with the control almost half the time. For unseen tasks, this figure drops to 20-25%, and for the OOD Human tasks it drops to 10%. Regardless, the quality of the gist model outputs remains very high.

[Figure 4: how often the gist model is on par with the control]

Overall, these results show that gist models can reliably compress prompts, even for some prompts outside the training distribution, especially in the case of decoder-only causal LMs such as LLaMA. Encoder-decoder models such as FLAN-T5 perform slightly worse. One possible reason is that the gist mask restricts the bidirectional attention flow in the encoder, which is more challenging than masking only part of the history in an autoregressive decoder. Further work is needed to investigate this hypothesis.

Computing, memory and storage efficiency

Finally, back to one of the core motivations of this work: What kind of efficiency improvements can gisting bring?

Table 2 below shows profiling results for a single forward pass of each model (i.e., one step of autoregressive decoding with a single input token), measured with the PyTorch 2.0 profiler and averaged over the 252 instructions in the Human eval split. Gist caching significantly improves efficiency compared with the unoptimized models: both models see FLOPs savings of 40% and wall-clock time reductions of 4-7%.

[Table 2: efficiency of a single forward pass with and without caching]

More importantly, however, the gist cache has key advantages over the instruction cache beyond latency: compressing 26 tokens into 1 frees up more space in the input context window, which is bounded by absolute position embeddings or GPU VRAM. For LLaMA-7B in particular, each token in the KV cache requires 1.05 MB of storage. Although the KV cache contributes little to the total memory required for LLaMA-7B inference at the tested prompt lengths, an increasingly common scenario is for developers to cache many prompts across a large number of users, in which case storage costs can grow quickly. With the same amount of storage, the gist cache can hold 26 times more prompts than the full instruction cache.
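
The storage figures above can be reproduced with simple arithmetic. The sketch below assumes LLaMA-7B's 32 layers, 32 attention heads, and head dimension of 128, and further assumes that cached values are stored in 32-bit floats; that last assumption is not stated in the text and was chosen because it yields the 1.05 MB figure. The storage budget is a hypothetical value.

```python
def kv_bytes_per_token(
    num_layers: int, num_heads: int, head_dim: int, bytes_per_value: int
) -> int:
    """Bytes needed to cache one token's keys and values across all layers."""
    # Factor 2: one key vector and one value vector per layer.
    return 2 * num_layers * num_heads * head_dim * bytes_per_value


def prompts_per_budget(budget_bytes: int, tokens_per_prompt: int, per_token: int) -> int:
    """How many cached prompts of a given length fit into a storage budget."""
    return budget_bytes // (tokens_per_prompt * per_token)


# Assumed LLaMA-7B shape with fp32 cache entries (4 bytes per value).
per_token = kv_bytes_per_token(num_layers=32, num_heads=32, head_dim=128, bytes_per_value=4)
print(f"{per_token / 1e6:.2f} MB per cached token")

budget = 26_000 * per_token  # hypothetical storage budget
full_prompts = prompts_per_budget(budget, tokens_per_prompt=26, per_token=per_token)
gist_prompts = prompts_per_budget(budget, tokens_per_prompt=1, per_token=per_token)
print(full_prompts, "full prompts vs", gist_prompts, "gist prompts")
```

Since storage is linear in the number of cached tokens, the 26x ratio between the two counts follows directly from the 26-to-1 compression and does not depend on the budget chosen.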


Statement:
This article is reproduced from 51cto.com. If there is any infringement, please contact admin@php.cn to have it removed.