首頁  >  文章  >  科技週邊  >  PromptPG:當強化學習遇見大規模語言模型

PromptPG:當強化學習遇見大規模語言模型

王林
王林轉載
2023-04-07 14:51:031195瀏覽

数学推理是人类智能的一项核心能力,但对于机器来说,抽象思维和逻辑推理仍然是一个很大的挑战。大规模预训练语言模型,如 GPT-3 和 GPT-4,在文本形式的数学推理(如数学应用题)上已经取得了显著的进展。然而,目前我们还不清楚这些模型能否处理涉及到异构信息(如表格数据)的更复杂的问题。为了填补这一空白,来自 UCLA 和艾伦人工智能研究院(AI2) 的研究人员推出了 Tabular Math Word Problems (TabMWP) ,这是一个包含了 38,431 个开放领域问题的数据集,需要同时在文本和表格数据上进行数学推理得到正确答案。TabMWP 中的每个问题都与一个上下文相关联,这个上下文包含PromptPG:當強化學習遇見大規模語言模型、文本或结构化格式的表格。

研究人员在 TabMWP 上评估了包括 Few-shot GPT-3 等不同的预训练模型。正如已有的研究发现,Few-shot GPT-3 很依赖 in-context 示例的选择,这导致其在随机选择示例的情况下性能相当不稳定。这种不稳定在处理像 TabMWP 这样复杂的推理问题时表现得更加严重。为了解决这一问题,作者提出了 PromptPG 方法,这种方法将示例的选择转化成强化学习中的 contextual bandit 问题,并且利用 Policy Gradient 训练一个策略网络来学习从少量的训练数据中选择最优的 in-context 示例。实验结果表明,他们提出的 PromptPG 方法在回答问题的准确性上超过最优基准(Few-shot CoT GPT-3)5.31%,并且相对于随机选择的 in-context examples,他们的方法显著降低了预测的方差,提升了这类方法的稳定性。

PromptPG:當強化學習遇見大規模語言模型


  • 论文链接:https://arxiv.org/abs/2209.14610
  • 代码链接:https://github.com/lupantech/PromptPG
  • 项目主页:https://promptpg.github.io
  • 数据可视化:https://promptpg.github.io/explore

1、TabMWP 数据集

下面是来自 TabMWP 数据集的两个例子。其中一个是答案为数值类型的自由文本问题(free-text),另一个是答案为文本类型的多项选择题(multi-choice)。可以看到,每个问题都提供了一个包含分步推理的解答。要解决 TabMWP 中的问题,系统必须同时具备查表和多步数学推理的能力。举下图中的例子来说,要回答 “how much will she spend (if Tracy buys three kinds of breads)”,我们需要先在表格中查找出三种面包对应的价格,再计算购买每种面包的费用,并对它们求和已得到最终的费用。

PromptPG:當強化學習遇見大規模語言模型

如下表的统计所示,TabMWP 数据集包含 38,431 个表格数学问题。其中 74.7% 的问题属于自由文本问题,25.3% 的问题属于多选题。TabMWP 共有 28,876 个不同的问题,6,153 个不同的答案和 35,442 个不同的解答,表明其在问题分布方面具有丰富的多样性。这些问题平均长度为 22.1 个单词,解答平均长度为 49.5 个单词,这表明 TabMWP 具有词汇的丰富性。TabMWP 的一个显著特点是,每个问题都附带有一个表格上下文,如果没有表格,问题将无法解决。TabMWP 总共有 37,644 个不同的表格,表格平均有 5.9 行和 2.2 列,12.9 个单元格,最大可达 54 个单元格。这些统计数据表明,TabMWP 中的表格也具有丰富的多样性。

PromptPG:當強化學習遇見大規模語言模型

TabMWP 数据集有两种不同的问题类型以及五种不同的答案类型:

PromptPG:當強化學習遇見大規模語言模型

TabMWP 中的每個問題都有一個表格上下文,它以圖像、半結構化文字和結構化三種格式表示。這為開發不同類型的推理模型提供了可能性。

PromptPG:當強化學習遇見大規模語言模型

比起已有的資料集,TabMWP 同時需要表格理解和數學推理能力來回答問題。另外,TabMWP 每題都有詳細的多步驟推理過程,在資料集大小、表格類型、問題類型和答案類型上有明顯的優勢。據本文所知,TabMWP 是第一個在開放領域表格場景下的數學推理資料集。

PromptPG:當強化學習遇見大規模語言模型

2、PromptPG 方法

#考慮到大規模預訓練模型例如GPT-3 在解決數學應用問題方面取得的成功,作者首先使用Few-shot GPT-3 在TabMWP 上建立了一個基準。他們從訓練集中隨機選擇一些上下文範例以及測試樣本構成提示(prompt),提示 GPT-3 預測答案。然而,最近的研究表明,這種基於隨機選擇的 few-shot 學習在不同的上下文範例選擇上可能會表現得非常不穩定。在處理類似 TabMWP 這樣的複雜推理問題時,隨機選擇的效果可能會更差,因為其問題涉及不同類型和格式的表格。

為了解決這個問題,作者提出了一種改進方法:透過Policy Gradient 進行提示學習,從少量的訓練資料中學習選擇上下文範例,稱為PromptPG。如圖2 所示,策略網路學習從候選池(candidate examples)中找到最佳的in-context example,其最佳化目標是在與GPT-3 環境互動時最大化給定訓練範例(training example)的預測獎勵。選擇範例的策略網路是一個基於固定參數的 BERT 語言模型和一個參數可學習的單層神經網路。在完成最佳化學習後,PromptPG 可以對不同的測試題目,動態地從候選範例中選出不同的最優範例,從而最大化提高 GPT-3 的推理表現。

PromptPG:當強化學習遇見大規模語言模型

以下為 PromptPG 的學習演算法。

PromptPG:當強化學習遇見大規模語言模型

3、實驗與分析

PromptPG:當強化學習遇見大規模語言模型

預訓練與微調

表3 比較了PromptPG 和不同基準在TabMWP 資料集上的結果。可以看到,TAPEX 由於在表格資料上進行了預訓練,在相似參數量的前提下,其比 UnifiedQA 的表現更好。對於 TAPEX 和 UnifiedQA 來說,提高模型的參數量都可以提高預測的準確性。此外,在 TabMWP 上進行模型的微調也可以大大提升預測的準確性。

大規模語言模型

#GPT-3 在沒有任何微調的情況下(Zero-shot GPT- 3),可以取得與微調過的UnifiedQA 以及TAPEX 模型相近的準確性。如果 Few-shot GPT-3 模型隨機選擇兩個 in-context 範例作為 GPT-3 的提示,其相比 Zero-shot GPT-3 可以進一步提升 0.17%。透過讓 Few-shot GPT-3 在產生最終答案前產生多步驟的中間步驟(Few-shot-CoT GPT-3),研究人員可以獲得最優的基準模型,其準確率達到了 62.92%。

PromptPG

區別於隨機選擇in-context 範例,本文提出的PromptPG 透過Policy Gradient 訓練一個策略網路來選擇更合適的in-context 範例,在TabMWP 上取得了最高的預測結果(68.23%),其平均預測準確率超過最佳基準模型(Few-shot-CoT GPT-3)5.31%。值得注意的是,對於幾乎所有的問題類型、答案類型和問題難度,PromptPG 都展現了其在預測準確率上的優勢。儘管如此,PromptPG 距離人類 90.22% 的表現還有很大的提升空間。

消融實驗

PromptPG:當強化學習遇見大規模語言模型

#表4 表明,TabMWP 的所有輸入元素(問題文字、表格資訊、選項資訊)都對正確回答問題至關重要。只有所有的問題元素作為輸入訊息,Zero-shot GPT-3 才取得了其相對最高的平均預測準確率(59.50%)。

不同的範例選擇

PromptPG:當強化學習遇見大規模語言模型

#作為對比實驗,研究者還比較了其他不同範例選擇的方法。如表 5 所示,選擇與測驗問題相同的題型或答案類型可以幫助模型找到更相關的範例,並提高答案的準確性。選擇最複雜的範例則並不能穩定地提高回答準確性。在候選範例中固定選擇兩個最好的範例,可以小幅度提高準確性,並降低變異數。選擇語意上最接近測試問題的範例可以達到最接近 PromptPG 方法的準確性。整體來說,PromptPG 全面展現了其在提升預測準確度和降低預測變異數的優勢。

下圖展示了 PromptPG 選擇的範例以及最終的預測結果。可以看到,PromptPG 方法可以選擇與測試題目具有類似的數學能力的範例,從而提高 Few-shot GPT-3 的推理表現。

PromptPG:當強化學習遇見大規模語言模型

#預測成功的範例

#以下展示了PromptPG 對一個自由文本問題的正確答案。這個問題要求將表格中的八個數字分別進行加法和除法計算以獲得平均值。

PromptPG:當強化學習遇見大規模語言模型

在如下的例子中,模型被要求理解一個稅務報告,並計算扣稅後的工資。

PromptPG:當強化學習遇見大規模語言模型

以下展示了 PromptPG 對多選題問題的正確預測。給定的表格一共有 9 行和 6 列。模型成功地定位到了表格中的目標單元格,並進行多步驟推理以預測正確答案。

PromptPG:當強化學習遇見大規模語言模型

在以下的例子中,模型需要比較預算和總成本,以驗證 Ariana 是否有足夠的錢。

PromptPG:當強化學習遇見大規模語言模型

#預測失敗的範例

以下展示了PromptPG 對自由文本問題的錯誤預測。模型檢索到了錯誤的玫瑰石英價格,從而錯誤計算了三個物品的成本總和。

PromptPG:當強化學習遇見大規模語言模型

在以下的例子中,問題提供了一個抽象的莖葉表。模型無法理解這個特定領域的表格,並且缺乏高級邏輯推理能力從而得到了錯誤的答案。

PromptPG:當強化學習遇見大規模語言模型

以下的例子表明,現有的模型似乎不具有對數字排序的能力。

PromptPG:當強化學習遇見大規模語言模型

在以下的例子中,表格中沒有出現與問題提到的當前時間完全一致的時間,因此模型無法準確定位到下一站的出發時間。在

PromptPG:當強化學習遇見大規模語言模型

以下的範例中,模型很難準確完成一長串數字的算術運算。

PromptPG:當強化學習遇見大規模語言模型

4、結論與展望

#作者提出了TabMWP,這是第一個針對表格情境的數學問題求解的大規模資料集。 TabMWP 包含了 38,431 個開放領域的問題,其中包括兩種問題類型和五種答案類型,每個問題都標註了多步驟的解答過程。作者使用了最先進的 QA 和 TableQA 方法,在預訓練和微調設定下對 TabMWP 進行了全面的實驗,以及使用大型預訓練語言模型 GPT-3 進行評估。作者進一步提出了一種全新的強化學習方法 PromptPG,該方法利用 Policy Gradient 學習從訓練資料中選擇最優的實例用於提示用於 GPT-3 模型。實驗結果表明,與隨機選擇相比,PromptPG 的性能明顯優於現有的基線,並且減少了預測中的性能不穩定性。

以上是PromptPG:當強化學習遇見大規模語言模型的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:51cto.com。如有侵權,請聯絡admin@php.cn刪除