PromptPG: When reinforcement learning meets large-scale language models
Mathematical reasoning is a core ability of human intelligence, but abstract thinking and logical reasoning remain a major challenge for machines. Large pre-trained language models such as GPT-3 and GPT-4 have made significant progress on text-based mathematical reasoning, for example math word problems. However, it is still unclear whether these models can handle more complex problems that involve heterogeneous information such as tabular data. To fill this gap, researchers from UCLA and the Allen Institute for Artificial Intelligence (AI2) released Tabular Math Word Problems (TabMWP), a dataset of 38,431 open-domain problems that require mathematical reasoning over both text and tabular data to reach the correct answer. Each question in TabMWP is paired with a tabular context that is provided as an image, as text, or as a table in a structured format.
The researchers evaluated several pre-trained models on TabMWP, including few-shot GPT-3. As earlier work has found, few-shot GPT-3 depends heavily on which in-context examples are selected, so its performance is quite unstable when the examples are chosen at random. The instability is even more severe on complex reasoning problems like those in TabMWP. To address this, the authors propose PromptPG, which casts example selection as a contextual bandit problem in reinforcement learning and uses policy gradient to train a policy network that learns to select the best in-context examples from a small amount of training data. Experimental results show that PromptPG exceeds the best baseline (Few-shot-CoT GPT-3) by 5.31 percentage points in answer accuracy, and that it significantly reduces the variance of predictions compared with randomly selected in-context examples, which makes this kind of method more stable.
The following are two examples from the TabMWP dataset. One is a free-text question with a numerical answer, and the other is a multiple-choice question with a text answer. As you can see, each question comes with a solution that includes step-by-step reasoning. To solve problems in TabMWP, a system must be capable of both table lookup and multi-step mathematical reasoning. Take the example in the picture below: to answer "how much will she spend" (if Tracy buys three kinds of bread), we first need to find the prices of the three kinds of bread in the table, then calculate the cost of buying each kind, and finally sum those costs to get the total.
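The two steps in the bread example (table lookup, then arithmetic) can be sketched in a few lines of Python. The item names and prices below are invented for illustration and are not taken from TabMWP.

```python
# Hypothetical price table; item names and prices are invented for illustration.
price_table = {
    "sourdough": 3.50,
    "rye": 4.25,
    "baguette": 2.75,
}


def total_cost(table, items):
    """Look up each item's price in the table, then sum the prices."""
    prices = [table[item] for item in items]  # step 1: table lookup
    return sum(prices)                        # step 2: arithmetic


cost = total_cost(price_table, ["sourdough", "rye", "baguette"])
print(f"{cost:.2f}")
```

A model that retrieves even one wrong cell in step 1 will get the wrong total, however good its arithmetic is.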
As shown in the statistics in the table below, the TabMWP dataset contains 38,431 tabular math problems, of which 74.7% are free-text questions and 25.3% are multiple-choice questions. TabMWP has 28,876 unique questions, 6,153 unique answers, and 35,442 unique solutions, which indicates a diverse question distribution. The average question length is 22.1 words and the average solution length is 49.5 words, which indicates the lexical richness of TabMWP. A distinguishing feature of TabMWP is that each problem is accompanied by a table context, without which the problem cannot be solved. TabMWP contains 37,644 different tables, with an average table size of 5.9 rows and 2.2 columns (12.9 cells) and a maximum of 54 cells. These statistics show that the tables in TabMWP are diverse as well.
The TabMWP dataset has two question types (free-text and multiple-choice) and five answer types.
Every question in TabMWP has a tabular context, which is provided in three formats: image, semi-structured text, and structured table. This makes it possible to develop different types of reasoning models.
Compared with existing datasets, TabMWP requires both table understanding and mathematical reasoning to answer questions. In addition, TabMWP provides a detailed multi-step solution for each question, and it compares favorably with existing datasets in dataset size, table types, question types, and answer types. To the best of the authors' knowledge, TabMWP is the first mathematical reasoning dataset in an open-domain tabular setting.
Given the success of large pre-trained models such as GPT-3 at solving math word problems, the authors first established a benchmark on TabMWP using few-shot GPT-3. They randomly select a few in-context examples from the training set and combine them with the test example to form a prompt that asks GPT-3 to predict the answer. However, recent research shows that few-shot learning based on random selection can be very unstable across different choices of in-context examples. Random selection may be even less effective on complex reasoning problems like those in TabMWP, which involve tables of different types and formats.
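The random-selection baseline can be sketched as follows. This is a minimal illustration rather than the authors' code: the prompt layout, the field names (`table`, `question`, `answer`), and the helper names are assumptions, and the resulting string would still have to be sent to a language model.

```python
import random


def format_example(example, with_answer=True):
    """Render one problem as text: table, question, and optionally the answer.
    The layout is an assumed format, not the one used in the paper."""
    text = f"Table:\n{example['table']}\nQuestion: {example['question']}\nAnswer:"
    if with_answer:
        text += f" {example['answer']}"
    return text


def build_random_prompt(train_set, test_example, n_shots=2, seed=None):
    """Pick n_shots training examples at random and prepend them to the test problem."""
    rng = random.Random(seed)
    shots = rng.sample(train_set, n_shots)
    parts = [format_example(s) for s in shots]
    parts.append(format_example(test_example, with_answer=False))
    return "\n\n".join(parts)


# Toy data, invented for illustration.
train_set = [
    {"table": "apple | $1\npear | $2", "question": "How much are both?", "answer": "3"},
    {"table": "Mon | 4\nTue | 6", "question": "What is the mean?", "answer": "5"},
    {"table": "red | 7\nblue | 9", "question": "Which is larger?", "answer": "blue"},
]
test_example = {"table": "cat | 2\ndog | 5", "question": "What is the total?"}

print(build_random_prompt(train_set, test_example, n_shots=2, seed=0))
```

Changing `seed` changes which examples end up in the prompt, which is the source of the instability described above.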
To address this, the authors propose an improved method called PromptPG (prompt learning through policy gradient), which learns to select in-context examples from a small amount of training data. As shown in Figure 2, the policy network learns to find the best in-context examples in a candidate pool, and its optimization objective is to maximize the prediction reward for a given training example when interacting with the GPT-3 environment. The policy network for selecting examples consists of a BERT language model with fixed parameters followed by a single-layer neural network with learnable parameters. Once trained, PromptPG can dynamically select different examples from the candidate pool for different test questions, thereby maximizing the reasoning performance of GPT-3.
The following is the learning algorithm of PromptPG.
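As a rough illustration of the idea, here is a toy REINFORCE-style training loop in plain Python. It is a sketch under several assumptions, not the authors' algorithm: fixed two-dimensional vectors stand in for frozen BERT embeddings, a single learnable matrix `W` with a bilinear score stands in for the single-layer network, one example is sampled per step, `mock_reward` replaces the GPT-3 call, and the reward values of +1 and -1 are assumed.

```python
import math
import random


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def policy_probs(W, problem_vec, cand_vecs):
    """pi(candidate | problem): softmax over bilinear scores p^T W c."""
    scores = [
        sum(problem_vec[i] * W[i][j] * c[j]
            for i in range(len(problem_vec)) for j in range(len(c)))
        for c in cand_vecs
    ]
    return softmax(scores)


def mock_reward(problem_type, cand_type):
    """Stand-in for the GPT-3 environment (assumption): +1 when the selected
    example matches the problem type, -1 otherwise."""
    return 1.0 if problem_type == cand_type else -1.0


def reinforce_step(W, problem_vec, problem_type, cand_vecs, cand_types, lr, rng):
    """Sample one in-context example, observe the reward, update W in place."""
    probs = policy_probs(W, problem_vec, cand_vecs)
    k = rng.choices(range(len(cand_vecs)), weights=probs)[0]
    reward = mock_reward(problem_type, cand_types[k])
    dim = len(cand_vecs[0])
    # Expected candidate vector under the current policy.
    mean_c = [sum(p * c[j] for p, c in zip(probs, cand_vecs)) for j in range(dim)]
    for i in range(len(problem_vec)):
        for j in range(dim):
            # d log pi(k) / d W[i][j] for a softmax over bilinear scores.
            grad_log_pi = problem_vec[i] * (cand_vecs[k][j] - mean_c[j])
            W[i][j] += lr * reward * grad_log_pi
    return reward


# Toy setup: two problem types, four candidate examples (all invented).
TYPE_VECS = {"A": [1.0, 0.0], "B": [0.0, 1.0]}  # stand-ins for frozen embeddings
cand_types = ["A", "B", "A", "B"]
cand_vecs = [TYPE_VECS[t] for t in cand_types]

W = [[0.0, 0.0], [0.0, 0.0]]
rng = random.Random(0)
for epoch in range(100):
    for p_type in ["A", "B"]:
        reinforce_step(W, TYPE_VECS[p_type], p_type, cand_vecs, cand_types, lr=0.5, rng=rng)

print([round(p, 3) for p in policy_probs(W, TYPE_VECS["A"], cand_vecs)])
```

After training, the policy should place most of its probability on candidates of the same type as the problem. In the real setting the reward would come from checking GPT-3's answer against the ground truth, which makes every training step an API call.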
Pre-training and fine-tuning
Table 3 compares the results of PromptPG and different baselines on the TabMWP dataset. TAPEX outperforms UnifiedQA at a similar parameter count, thanks to its pre-training on tabular data. For both TAPEX and UnifiedQA, increasing the number of model parameters improves prediction accuracy. In addition, fine-tuning the model on TabMWP greatly improves prediction accuracy.
Large-scale language model
Without any fine-tuning, GPT-3 (Zero-shot GPT-3) achieves accuracy similar to that of the fine-tuned UnifiedQA and TAPEX models. When two randomly selected in-context examples are added to the GPT-3 prompt (Few-shot GPT-3), accuracy improves by a further 0.17 percentage points over Zero-shot GPT-3. By having Few-shot GPT-3 generate multiple intermediate steps before producing the final answer (Few-shot-CoT GPT-3), the researchers obtained the best baseline model, with an accuracy of 62.92%.
PromptPG
Instead of selecting in-context examples at random, PromptPG trains a policy network through policy gradient to select more appropriate in-context examples, and it achieves the highest accuracy on TabMWP (68.23%). Its average accuracy exceeds the best baseline model (Few-shot-CoT GPT-3) by 5.31 percentage points. Notably, PromptPG is more accurate for almost all question types, answer types, and question difficulties. Even so, PromptPG still falls well short of the human performance of 90.22%, which leaves a lot of room for improvement.
Ablation experiment
Table 4 shows that all input elements of TabMWP (question text, table information, and option information) are critical to answering the question correctly. Zero-shot GPT-3 achieved its highest average prediction accuracy (59.50%) only when all problem elements were given as input.
Different sample selection
As a comparative experiment, the researchers also compared other example-selection methods. As shown in Table 5, choosing examples with the same question type or answer type as the test question helps the model find more relevant examples and improves answer accuracy. Choosing the most complex examples does not consistently improve answer accuracy. Always using the two best examples among the candidates slightly improves accuracy and reduces variance. Selecting the examples that are semantically closest to the test problem comes closest to the accuracy of PromptPG. Overall, PromptPG shows clear advantages in both improving prediction accuracy and reducing prediction variance.
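The "semantically closest" baseline amounts to nearest-neighbor retrieval. The sketch below assumes bag-of-words vectors and cosine similarity as a stand-in for whatever sentence representation the researchers actually used; the function names and the toy questions are invented for illustration.

```python
import math
from collections import Counter


def bow(text):
    """Bag-of-words vector; a crude stand-in for a sentence embedding."""
    return Counter(text.lower().split())


def cosine(a, b):
    dot = sum(a[w] * b[w] for w in a if w in b)
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0


def nearest_examples(test_question, candidates, k=2):
    """Return the k candidate questions most similar to the test question."""
    q = bow(test_question)
    ranked = sorted(candidates, key=lambda c: cosine(q, bow(c)), reverse=True)
    return ranked[:k]


# Toy candidate pool, invented for illustration.
candidates = [
    "What is the mean of the numbers in the table?",
    "How much will she spend on three items?",
    "Which bus departs next after 9:00?",
]
print(nearest_examples("How much will Tracy spend on bread?", candidates, k=1))
```

Unlike PromptPG, this selection rule is fixed in advance and never receives feedback on whether the chosen examples actually helped the model answer correctly.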
The following figure shows an example of PromptPG selection and the final prediction result. It can be seen that the PromptPG method can improve the inference performance of Few-shot GPT-3 by selecting examples with similar mathematical abilities to the test questions.
Example of successful prediction
The following shows a correct PromptPG answer to a free-text question. This question requires adding up eight numbers in a table and dividing the sum to find the average.
In the following example, the model is asked to understand a tax report and calculate the salary after tax deductions.
The following shows PromptPG’s correct predictions for multiple-choice questions. The given table has a total of 9 rows and 6 columns. The model successfully locates the target cell in the table and performs multi-step inference to predict the correct answer.
In the following example, the model needs to compare the budget and total costs to verify whether Ariana has enough money.
Example of prediction failure
The following shows an incorrect PromptPG prediction for a free-text question. The model retrieved the wrong price for rose quartz and therefore miscalculated the total cost of the three items.
In the following example, the question provides an abstract stem-and-leaf table. The model was unable to understand this domain-specific table and lacked the advanced logical reasoning needed, so it produced the wrong answer.
The following example shows that existing models do not seem to be able to sort numbers.
In the following example, the table does not contain a time that exactly matches the current time mentioned in the question, so the model cannot accurately locate the departure time for the next stop.
In the following example, the model has difficulty accurately completing arithmetic operations on a long series of numbers.
The authors propose TabMWP, the first large-scale dataset for mathematical problem solving in a tabular context. TabMWP contains 38,431 open-domain questions covering two question types and five answer types, and each question is annotated with a multi-step solution. The authors ran comprehensive experiments on TabMWP with state-of-the-art QA and TableQA methods in pre-trained and fine-tuned settings, and also evaluated the large pre-trained language model GPT-3. They further propose a new reinforcement learning method, PromptPG, which uses policy gradient to learn to select the best in-context examples from the training data for prompting GPT-3. Experimental results show that PromptPG significantly outperforms existing baselines and reduces the instability of predictions compared with random selection.