Completely changing language models: the new TTT architecture surpasses the Transformer by replacing the RNN hidden state with an ML model

WBOY
2024-07-17 16:08:17

Performance improves across large models from 125M to 1.3B parameters.


Unbelievable, this finally happened.

A new large language model (LLM) architecture is expected to replace the Transformer, which has dominated the AI field so far, and it also performs better than Mamba. On Monday, a paper on Test-Time Training (TTT) became a hot topic in the artificial intelligence community.

Paper link: https://arxiv.org/abs/2407.04620

The authors of this study are from Stanford University, the University of California, Berkeley, the University of California, San Diego, and Meta. They designed a new architecture, TTT, that replaces the hidden state of an RNN with a machine learning model. This model compresses the context through actual gradient descent on the input tokens.

Karan Dalal, one of the authors of the study, said he believes this will fundamentally change the approach to language models.

In his words, TTT layers directly replace attention and unlock a linear-complexity architecture with expressive memory, allowing LLMs to be trained with millions (someday billions) of tokens in context.

The authors compared models ranging from 125M to 1.3B parameters and found that both TTT-Linear and TTT-MLP can match or beat the strongest Transformer and Mamba baselines.

As a new mechanism for information compression and model memory, the TTT layer can simply and directly replace the self-attention layer in the Transformer.

Compared with Mamba, TTT-Linear has lower perplexity, fewer FLOPs (left), and better utilization of long contexts (right):

It is not only linear in theoretical complexity; its actual running time is also faster.
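
To make the linear-versus-quadratic claim concrete, here is a back-of-the-envelope sketch. It assumes a single layer of hypothetical width d, counts only the dominant per-token work (attention scanning t cached entries versus a TTT-Linear step on a fixed d x d weight matrix), and ignores projections, constants, and hardware effects, so it illustrates scaling trends rather than the paper's FLOP accounting.

```python
def attention_cost(T: int, d: int) -> int:
    """Rough dominant work of causal self-attention over T tokens.

    Token t compares its query with all t cached keys and mixes t values,
    about t * d work, so the total grows like T^2 * d / 2.
    """
    return sum(t * d for t in range(1, T + 1))


def ttt_linear_cost(T: int, d: int) -> int:
    """Rough dominant work of a TTT-Linear layer over T tokens.

    The hidden state is a fixed d x d matrix, so each token's gradient step
    and output cost about d * d work regardless of its position.
    """
    return T * d * d


d = 1024  # hypothetical model width
for T in (2048, 8192, 32768):
    ratio = attention_cost(T, d) / ttt_linear_cost(T, d)
    print(f"context {T:>6}: attention / TTT-Linear ratio about {ratio:.1f}")
```

Under these assumptions the ratio is (T + 1) / (2d): the two costs are roughly equal near T = 2d, and attention's relative cost then grows linearly with context length.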

  • After the paper was released, the authors made the JAX code public for training and testing: https://github.com/test-time-training/ttt-lm-jax
  • There is also PyTorch inference code: https://github.com/test-time-training/ttt-lm-pytorch

Method introduction

The challenge of long context is inherent to the nature of RNN layers: unlike self-attention, an RNN layer must compress the context into a fixed-size hidden state, and its update rule must discover the underlying structure and relationships among thousands or even millions of tokens.

The research team first observed that self-supervised learning can compress a massive training set into the weights of a model such as an LLM, and that LLMs often exhibit a deep understanding of the semantic connections among their training data.

Inspired by this observation, the research team designed a new class of sequence modeling layers, where the hidden state is a model and the update rule is a step of self-supervised learning. Since the process of updating the hidden state on the test sequence is equivalent to training the model at test time, the research team calls this new layer the Test-Time Training (TTT) layer.
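
A minimal sketch of such a layer in plain Python, under loudly stated assumptions: the hidden state is a linear model W, the self-supervised loss is a reconstruction loss ||W k_t - v_t||^2, and each token arrives already split into a "key" view k_t, a "value" view v_t, and a "query" view q_t (in a real model these views would come from learned projections; here they are simply given). W starts at zero, the learning rate is a hypothetical choice, and the function name `ttt_linear_online` is ours, not from the released code.

```python
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]


def matvec(W: Mat, x: Vec) -> Vec:
    """Multiply matrix W by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def ttt_linear_online(keys: List[Vec], values: List[Vec], queries: List[Vec],
                      lr: float = 0.1) -> Tuple[List[Vec], Mat]:
    """Naive TTT layer: one gradient step per token on a linear hidden state.

    Illustrative assumptions: W_0 = 0, loss ||W k - v||^2, output W_t q.
    """
    d_out, d_in = len(values[0]), len(keys[0])
    W = [[0.0] * d_in for _ in range(d_out)]  # hidden state = model weights
    outputs = []
    for k, v, q in zip(keys, values, queries):
        # Gradient of ||W k - v||^2 with respect to W is 2 (W k - v) k^T.
        err = [p - t for p, t in zip(matvec(W, k), v)]
        # Update rule: a single step of gradient descent on this token.
        W = [[W[i][j] - lr * 2.0 * err[i] * k[j] for j in range(d_in)]
             for i in range(d_out)]
        # Output rule: apply the freshly updated model to the query view.
        outputs.append(matvec(W, q))
    return outputs, W


# Two identical 1-D tokens: the state moves toward mapping k = 1 to v = 2.
outs, W_final = ttt_linear_online([[1.0], [1.0]], [[2.0], [2.0]],
                                  [[1.0], [1.0]], lr=0.25)
print(outs)  # [[1.0], [1.5]]
```

Because each step only touches W, the per-token cost is set by the size of W rather than by the position t, which is where the linear overall complexity comes from.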

The research team introduces two simple instantiations: TTT-Linear and TTT-MLP, whose hidden states are a linear model and a two-layer MLP, respectively. TTT layers can be integrated into any network architecture and optimized end-to-end, similar to RNN layers and self-attention.
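
For reference, the layer's update and output rules and the two hidden-state models can be written compactly as follows. This is a simplified summary: η is the inner-loop learning rate, ℓ the self-supervised loss, and σ a nonlinearity such as GELU; any normalization or residual connections used in the full implementation are omitted.

```latex
\begin{aligned}
W_t &= W_{t-1} - \eta\,\nabla \ell(W_{t-1}; x_t), \qquad z_t = f(x_t; W_t),\\
\text{TTT-Linear:}\quad f(x; W) &= W x,\\
\text{TTT-MLP:}\quad f(x; W_1, W_2) &= W_2\,\sigma(W_1 x).
\end{aligned}
```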

To make the TTT layer efficient, the study adopts two techniques:

First, just as regular training takes gradient steps on mini-batches of sequences to obtain better parallelism, the study uses mini-batches of tokens during TTT.
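
A sketch of mini-batch TTT for the linear case, in plain Python. It assumes a reconstruction loss ||W k - v||^2 on given key/value/query views, W starting at zero, and a hypothetical function name. Following the description as we understand it, every gradient inside a mini-batch of b tokens is taken at the weights left by the previous mini-batch, so the b gradients do not depend on each other and can be computed in parallel; each token's output still uses the cumulative update up to its own position.

```python
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]


def matvec(W: Mat, x: Vec) -> Vec:
    """Multiply matrix W by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def ttt_linear_minibatch(keys: List[Vec], values: List[Vec], queries: List[Vec],
                         lr: float = 0.1, b: int = 16) -> Tuple[List[Vec], Mat]:
    """Mini-batch TTT with a linear hidden state (illustrative sketch)."""
    d_out, d_in = len(values[0]), len(keys[0])
    W = [[0.0] * d_in for _ in range(d_out)]
    outputs = []
    for start in range(0, len(keys), b):
        W_base = W  # all gradients in this mini-batch are taken at W_base
        grad_sum = [[0.0] * d_in for _ in range(d_out)]
        for k, v, q in zip(keys[start:start + b], values[start:start + b],
                           queries[start:start + b]):
            # Gradient of ||W_base k - v||^2: independent of the other tokens
            # in the mini-batch, hence parallelisable.
            err = [p - t for p, t in zip(matvec(W_base, k), v)]
            for i in range(d_out):
                for j in range(d_in):
                    grad_sum[i][j] += 2.0 * err[i] * k[j]
            # Token t uses the sum of the gradients for tokens start..t.
            W = [[W_base[i][j] - lr * grad_sum[i][j] for j in range(d_in)]
                 for i in range(d_out)]
            outputs.append(matvec(W, q))
    return outputs, W


toy = ([[1.0], [1.0]], [[2.0], [2.0]], [[1.0], [1.0]])
online_like, _ = ttt_linear_minibatch(*toy, lr=0.25, b=1)
batched, _ = ttt_linear_minibatch(*toy, lr=0.25, b=2)
print(online_like)  # [[1.0], [1.5]]
print(batched)      # [[1.0], [2.0]]
```

With b = 1 this reduces to ordinary per-token TTT. Larger b trades some sequential signal (here the second token's gradient no longer sees the first token's update) for parallelism.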

Second, the study develops a dual form for the operations within each TTT mini-batch to make better use of modern GPUs and TPUs. The dual form produces the same output as the naive implementation, but trains more than 5 times faster. As shown in Figure 3, at 8k context TTT-Linear is faster than the Transformer and comparable to Mamba.
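
For TTT-Linear, the idea behind a dual form can be illustrated by expanding the mini-batch update algebraically. Assuming a reconstruction loss ||W k - v||^2 and base weights W0 for the mini-batch, token t's output is W0 q_t - 2·lr·Σ_{s≤t} (W0 k_s - v_s)(k_s · q_t). The outputs can therefore be computed from shared error vectors and a causally masked matrix of key-query dot products, without materialising any intermediate W_t; only the end-of-batch weights are formed. This plain-Python sketch (hypothetical function names, loops standing in for the matrix multiplications real hardware would run) checks numerically that the two forms agree.

```python
import random
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]


def matvec(W: Mat, x: Vec) -> Vec:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def primal_batch(W0: Mat, K: List[Vec], V: List[Vec], Q: List[Vec],
                 lr: float) -> Tuple[List[Vec], Mat]:
    """Naive (primal) form: materialise W_t for every token in the batch."""
    n_out, n_in = len(W0), len(W0[0])
    grad_sum = [[0.0] * n_in for _ in range(n_out)]
    outs, W_t = [], W0
    for k, v, q in zip(K, V, Q):
        err = [p - t for p, t in zip(matvec(W0, k), v)]  # gradient taken at W0
        for i in range(n_out):
            for j in range(n_in):
                grad_sum[i][j] += 2.0 * err[i] * k[j]
        W_t = [[W0[i][j] - lr * grad_sum[i][j] for j in range(n_in)]
               for i in range(n_out)]
        outs.append(matvec(W_t, q))
    return outs, W_t


def dual_batch(W0: Mat, K: List[Vec], V: List[Vec], Q: List[Vec],
               lr: float) -> Tuple[List[Vec], Mat]:
    """Dual form: never builds the intermediate W_t matrices."""
    n_out, n_in = len(W0), len(W0[0])
    E = [[p - t for p, t in zip(matvec(W0, k), v)] for k, v in zip(K, V)]  # W0 K - V
    outs = []
    for t, q in enumerate(Q):
        base = matvec(W0, q)
        mask_row = [dot(K[s], q) for s in range(t + 1)]  # causal: only s <= t
        outs.append([base[i] - 2.0 * lr * sum(c * E[s][i] for s, c in enumerate(mask_row))
                     for i in range(n_out)])
    # Only the end-of-batch weights are formed: W_b = W0 - 2 lr E K^T.
    W_b = [[W0[i][j] - 2.0 * lr * sum(E[s][i] * K[s][j] for s in range(len(K)))
            for j in range(n_in)] for i in range(n_out)]
    return outs, W_b


def rand_vec(n: int) -> Vec:
    return [random.uniform(-1.0, 1.0) for _ in range(n)]


random.seed(0)
d, b = 3, 4  # hypothetical width and mini-batch size
W0 = [rand_vec(d) for _ in range(d)]
K = [rand_vec(d) for _ in range(b)]
V = [rand_vec(d) for _ in range(b)]
Q = [rand_vec(d) for _ in range(b)]

zp, Wp = primal_batch(W0, K, V, Q, lr=0.1)
zd, Wd = dual_batch(W0, K, V, Q, lr=0.1)
max_diff = max(abs(a - c) for za, zc in zip(zp + Wp, zd + Wd) for a, c in zip(za, zc))
print(f"max |primal - dual| = {max_diff:.2e}")
```

In matrix form the dual computation is dominated by products like E · mask(KᵀQ), which map well onto matrix-multiply hardware; that is the intuition for why a dual form can train faster while producing the same output.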

The research team argues that all sequence modeling layers can be viewed as storing historical context in a hidden state, as shown in Figure 4.

For example, RNN layers such as LSTM, RWKV, and Mamba compress context into a fixed-size state across time. This compression has two consequences. On the one hand, mapping input tokens x_t to output tokens z_t is efficient, because the update rule and output rule take constant time per token. On the other hand, the performance of an RNN layer in long contexts is limited by the expressiveness of its hidden state s_t.
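
The update-rule/output-rule view of an RNN layer can be sketched with a deliberately simple elementwise linear recurrence. This is a hypothetical toy, not LSTM, RWKV, or Mamba themselves; it shows that the hidden state s_t keeps the same size no matter how many tokens have been read, so each step costs constant time and everything the layer remembers must fit in that fixed-size vector.

```python
from typing import List, Tuple

Vec = List[float]


def toy_linear_rnn(xs: List[Vec], decay: float = 0.9) -> Tuple[List[Vec], int]:
    """Toy RNN layer with a fixed-size state (illustrative, not a real model).

    Update rule: s_t = decay * s_{t-1} + x_t (elementwise).
    Output rule: z_t = s_t.
    """
    state = [0.0] * len(xs[0])  # s_0: same size as one token, never grows
    outputs = []
    for x in xs:
        state = [decay * s + xi for s, xi in zip(state, x)]  # O(d) per token
        outputs.append(list(state))
    return outputs, len(state)


zs, state_size = toy_linear_rnn([[1.0], [0.0], [0.0]], decay=0.5)
print(zs, state_size)  # [[1.0], [0.5], [0.25]] 1
```

After three tokens the first input survives only as a decayed trace, which illustrates why a fixed-size state limits what an RNN layer can recall over long contexts.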

Self-attention can also be viewed from this perspective, except that its hidden state (commonly known as the key-value cache) is a list that grows linearly with t. Its update rule simply appends the current KV tuple to this list, while its output rule scans all tuples up to t to form the attention matrix. The hidden state explicitly stores all historical context without compression, which makes self-attention more expressive than RNN layers for long contexts. However, the time needed to scan this linearly growing hidden state also grows linearly. To remain both efficient and expressive in long contexts, researchers need a better compression heuristic. Specifically, thousands or possibly millions of tokens need to be compressed into a hidden state that effectively captures their underlying structure and relationships. This may sound like a tall order, but many people are already familiar with such a heuristic: the self-supervised learning described above.
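
Viewed the same way, causal self-attention can be sketched with an explicit key-value cache as its hidden state. This minimal plain-Python version (single head, no batching, scaled dot-product scores; the function name is hypothetical) shows the update rule appending one (k, v) tuple per token and the output rule scanning every cached tuple, so both the state and the per-token work grow linearly with t.

```python
import math
from typing import List, Tuple

Vec = List[float]


def attention_with_kv_cache(queries: List[Vec], keys: List[Vec],
                            values: List[Vec]) -> Tuple[List[Vec], int]:
    cache: List[Tuple[Vec, Vec]] = []  # hidden state: grows by one tuple per token
    outputs = []
    for q, k, v in zip(queries, keys, values):
        cache.append((k, v))  # update rule: append the current KV tuple
        # Output rule: score q against every cached key (O(t) work at step t).
        scale = 1.0 / math.sqrt(len(q))
        scores = [scale * sum(qi * ki for qi, ki in zip(q, ck)) for ck, _ in cache]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        outputs.append([sum(w * cv[i] for w, (_, cv) in zip(weights, cache)) / total
                        for i in range(len(v))])
    return outputs, len(cache)


# With all-zero queries every score is 0, so each output averages the values so far.
outs, cache_len = attention_with_kv_cache([[0.0, 0.0]] * 3, [[1.0, 0.0]] * 3,
                                          [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(outs, cache_len)  # [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]] 3
```

Nothing is compressed here, which is what makes attention expressive; the price is that the cache, and the scan over it, keeps growing with the context.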

Backbone architecture. The cleanest way to integrate any RNN layer into a larger architecture is to directly replace the self-attention in a Transformer, referred to here as the backbone. However, existing RNNs (such as Mamba and Griffin) use backbones that differ from the Transformer's. Most notably, their backbone layers contain temporal convolutions before the RNN layer, which may help gather local information across time. After experimenting with the Mamba backbone, the researchers found that it also improves the perplexity of TTT layers, so they incorporated it into the proposed method, as shown in Figure 16.
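
A temporal convolution of the kind described here can be sketched as a short causal, depthwise 1-D convolution: each channel is filtered independently, and the output at step t mixes only the current token and a few preceding ones. The kernel width and weights are hypothetical, and this is not claimed to match the exact layer in the Mamba backbone; it only illustrates how such a layer gathers local information across time before the sequence modeling layer sees it.

```python
from typing import List

Vec = List[float]


def causal_depthwise_conv(xs: List[Vec], kernels: List[Vec]) -> List[Vec]:
    """kernels[c][j] weights channel c of token t - j (j = 0 is the current token)."""
    width = len(kernels[0])
    out = []
    for t in range(len(xs)):
        row = []
        for c in range(len(xs[0])):
            acc = 0.0
            for j in range(width):
                if t - j >= 0:  # causal: implicit zero padding before the sequence start
                    acc += kernels[c][j] * xs[t - j][c]
            row.append(acc)
        out.append(row)
    return out


# One channel, kernel width 2: each output sums the current and previous token.
print(causal_depthwise_conv([[1.0], [2.0], [3.0]], [[1.0, 1.0]]))  # [[1.0], [3.0], [5.0]]
```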

Experimental results

In the experiments, the researchers compared TTT-Linear and TTT-MLP with two baselines: Transformer and Mamba.

Short context: Pile

From Figure 11 we can draw the following conclusions:

  • At 2k context, TTT-Linear (M), Mamba, and Transformer perform comparably, as their lines mostly overlap. TTT-MLP (M) performs slightly worse at larger FLOP budgets. Although TTT-MLP has better perplexity than TTT-Linear at every model size, its extra FLOPs offset this advantage.
  • At 8k context, both TTT-Linear (M) and TTT-MLP (M) perform significantly better than Mamba, in contrast to the observation at 2k. Even TTT-MLP (T), which uses the Transformer backbone, is slightly better than Mamba at around 1.3B. A notable trend is that as the context length increases, the advantage of TTT layers over the Mamba layer widens.
  • At 8k context, the Transformer still achieves good perplexity at every model size, but it is no longer competitive because of its FLOP cost.

The results above show the impact of switching the TTT layer from the Mamba backbone network to the Transformer backbone network. The researchers hypothesized that temporal convolutions in the Mamba backbone network are more helpful when the hidden states of the sequence modeling layer are less expressive. Linear models are less expressive than MLPs and therefore benefit more from convolutions.

Long context: Books

To evaluate long-context capability, the researchers used Books3, a popular subset of the Pile, with context lengths from 1k to 32k in 2x increments. The training recipe is the same as for the Pile, and all experiments for the TTT layers are performed in a single training run. From the subset of results in Figure 12, they made the following observations:

At Books 2k, all the observations for Pile 2k still hold, except that Mamba now performs slightly better than TTT-Linear (whereas their lines roughly overlapped at Pile 2k).

In the 32k context, both TTT-Linear (M) and TTT-MLP (M) perform better than Mamba, similar to the observations for Pile 8k. Even TTT-MLP (T) with Transformer backbone performs slightly better than Mamba in 32k context.

At the 1.3B scale, TTT-MLP (T) is only slightly worse than TTT-MLP (M). As the paper notes, it is difficult to derive an empirical scaling law because there is no clean linear fit. However, the strong trend for TTT-MLP (T) suggests that the Transformer backbone may be better suited to larger models and longer contexts beyond the scope of this evaluation.

Wall-clock time

LLM training and inference can be decomposed into forward, backward, and generation. Prompt processing during inference (also called prefill) is the same as the forward pass during training, except that intermediate activations do not need to be stored for a backward pass.

Since the forward pass (during training and inference) and the backward pass can both be parallelized, the dual form is used for them. Generating new tokens (also called decoding) is inherently sequential, so the primal form is used there.

The researchers noted that, due to resource constraints, the experiments in the paper were written in JAX and run on TPUs. On a v5e-256 TPU pod, the Transformer baseline takes 0.30 seconds per iteration to train with 2k context, while TTT-Linear takes 0.27 seconds per iteration, 10% faster without any systems optimization. Since Mamba (implemented in PyTorch, Triton, and CUDA) can only run on GPUs, the researchers carried out preliminary systems optimization of their method so that it could also run on GPUs, allowing a fair comparison.
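
The "10% faster" figure corresponds to the reduction in time per iteration. A quick check with the two reported numbers also gives the equivalent gain in iterations per second:

```python
transformer_s_per_iter = 0.30  # Transformer baseline, 2k context (reported)
ttt_linear_s_per_iter = 0.27   # TTT-Linear, 2k context (reported)

# Fraction of wall-clock time saved per training iteration.
time_saved = (transformer_s_per_iter - ttt_linear_s_per_iter) / transformer_s_per_iter
# Ratio of iterations per second (throughput speedup).
speedup = transformer_s_per_iter / ttt_linear_s_per_iter

print(f"time saved per iteration: {time_saved:.0%}")  # 10%
print(f"throughput speedup: {speedup:.2f}x")          # 1.11x
```

In other words, 10% less time per iteration corresponds to roughly 11% more iterations per second.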

The left side of Figure 15 shows the forward-kernel latency of each model at a batch size of 16. All models are 1.3B (Mamba is 1.4B). Notably, the Transformer baseline here is much faster than the one in the Mamba paper, because it uses vLLM instead of HuggingFace Transformers.

The researchers also wrote a separate GPU kernel for generation and benchmarked its speed at a batch size of 512, shown on the right side of Figure 15. Another commonly used wall-clock metric is throughput, which accounts for the potential benefit of larger batch sizes. For throughput, all of the above observations and the ordering between methods still hold.

Lead authors

After the TTT paper was posted, Xiaolong Wang, an assistant professor at UCSD and one of its authors, tweeted his congratulations. He said the TTT project had taken a year and a half, but that five years had passed since the idea of test-time training (TTT) was born, even though the original idea and the current results are completely different.

The three main authors of the TTT paper are from Stanford, UC Berkeley and UCSD respectively.

Yu Sun is a postdoctoral fellow at Stanford University. He received his PhD from UC Berkeley EECS, and TTT has long been his research focus.

Xinhao Li is a PhD candidate at UCSD. He graduated from the University of Electronic Science and Technology of China.

Karan Dalal is a PhD candidate at UC Berkeley who co-founded a veterinary telemedicine startup called Otto while in high school.

All three list test-time training first when describing their research interests on their personal websites.

For more research details, please refer to the original paper.
