進入正文之前,先考慮像ChatGPT 這樣的Transformer 語言模型(LM)的prompt:
隨著每天產生數百萬用戶和查詢,ChatGPT 使用自註意力機制對prompt 進行反覆編碼,其時間和記憶體複雜度隨輸入長度呈二次方增長。快取 prompt 的 transformer 啟動可以防止部分重新計算,但隨著快取 prompt 數量的增加,這種策略仍然會產生很大的記憶體和儲存成本。在大規模情況下,即使 prompt 長度稍微減少一點,也可能會帶來運算、記憶體和儲存空間的節省,同時還可以讓使用者將更多內容放入 LM 有限的上下文視窗中。
那麼。應該如何降低 prompt 的成本呢?典型的方法是微調或蒸餾模型,使其在沒有 prompt 的情況下表現得與原始模型相似,或許還可以使用參數高效的自適應方法。然而,這種方法的一個基本缺點是每次需要為新的 prompt 重新訓練模型(下圖 1 中間所示)。
#本文中,史丹佛大學的研究者提出了gisting 模型(上圖1 底部),它將任意prompt 壓縮成一組更小的虛擬“Gist” token,類似於前綴微調。然而,前綴微調需要透過梯度下降為每個任務學習 prefix,而 Gisting 採用元學習方法,僅透過 prompt 預測 Gist prefix,而不需要為每個任務進行 prefix 學習。這樣可以攤銷每個任務 prefix 學習的成本,使得在沒有額外訓練的情況下泛化到未知的指令。
此外,由於「Gist」token 比完整 prompt 要短得多,因此 Gisting 允許 prompt 被壓縮、快取和重複使用,以提高運算效率。
#論文網址:https://arxiv.org/pdf/2304.08467 v1.pdf
研究者提出了一個非常簡單的方法來學習指令遵循的gist 模型:簡單地進行指令微調,在prompt 後插入gish token,修改後的注意力掩膜阻止gist token 後的token 參考gist token 前的token。這使得模型同時學習 prompt 壓縮和指令遵循,而無需額外的訓練成本。
在decodr-only(LLaMA-7B)和encoder-decoder(FLAN-T5-XXL)LM 上,gisting 可實現高達26 倍的即時壓縮率,同時保持與原始模型相似的輸出品質。這使得推理過程中 FLOPs 減少了 40%,延遲加速了 4.2%,與傳統的 prompt 快取方法相比,儲存成本大大降低。
研究者首先在指令微調的背景下描述 gisting。對於指令遵循資料集,t 表示用自然語言prompt 編碼的任務(例如將此翻譯成法語),x 表示任務的(可選)輸入(例如The cat),y 表示期望的輸出(例如Le chat)。指令微調的目的是透過連接 t 和 x,然後讓通常預先訓練的語言模型自回歸地預測 y,從而學習分佈 pLM(y | t,x)。推理時可以使用新的任務 t 和輸入 x 進行 prompt,從模型中解碼以獲得預測結果。
然而,連接 t 和 x 的這種模式具有缺點:基於 Transformer 的 LM 具有有限的上下文窗口,其受架構或計算能力所限。後者特別難解決,因為自註意力隨輸入長度呈二次方擴展。因此很長的 prompt,尤其是那些被反覆重用的 prompt,計算效率低。有哪些選項可以用來降低 prompt 的成本呢?
一種簡單的方法是針對特定任務t 進行LM 微調,即給定包含僅在任務t 下的輸入/ 輸出範例的資料集,可以學習一個專門的,它更快,因為不需要考慮t。
更好的是,prefix/prompt 微調或 adapter 等參數高效微調方法能夠以比全面微調低得多的成本實現相同的目的。然而仍然存在問題:必須至少儲存每個任務的一部分模型權重,更重要的是,對於每個任務 t,必須收集相應的輸入 / 輸出對資料集 D^t 並重新訓練模型。
Gisting 是一種不同的方法,它攤銷了兩部分成本:(1)在t 上條件化p_LM 的推理時間成本,(2)學習每個t 的新p^t_LM 的訓練時間成本。其想法是在微調期間學習 t 的壓縮版本 G (t),使得從 p_G (y | G (t),x) 進行推理比從 p_LM (y|t,x) 更快。
在LM 術語中,G (t) 將是一組「虛擬」的Gist token,其數量比t 中的token 少,但仍會在LM 中引起類似的行為。接著可以快取並重複使用 G (t) 上的 transformer 啟動(例如鍵和值矩陣)以提高計算效率。重要的是,研究者希望 G 可以泛化到未見過的任務:給定一個新任務 t,則可以預測並使用相應的 Gist 激活 G (t) 而無需進行任何額外訓練。
上文描述了Gisting 的一般框架,接下來將探討學習此類模型的極簡單方法:使用LM 本身用作Gist 預測器G。這不僅利用了 LM 中的預先存在知識,而且允許透過簡單地執行標準指令微調來學習 gisting 並修改 Transformer 注意力掩膜來增強 prompt 壓縮。這意味著 Gisting 不會產生額外訓練成本,只需要基於標準指示微調即可!
具體來說,在模型詞彙表和嵌入矩陣中添加一個特殊的 gist token,類似於此類模型中常見的句子開頭 / 結尾 token。然後對於給定的(任務,輸入)元組(t,x),使用(t, g_1, . . . , g_k, x) 中一組k 個連續的gist token 將t 和x 連接在一起,例如。這個序列被輸入到模型中,有一個限制,即在 gist token 之後的輸入 token 不能參考先前的 prompt token(但它們可以參考 gist token)。這會強制模型將 prompt 中的信息壓縮成 gist token,因為輸入 x (輸出 y) 無法處理 prompt t。
下圖 2 展示了所需的變更。對於 GPT-3 或 LLaMA 等通常採用自回歸因果注意力掩膜的 decoder-only LM,只需 mask out 圖 2a 所示的三角形左下角。對於具有雙向編碼器和自回歸解碼器的 encoder-decoder LM,則需要進行兩項修改(圖 2b 所示)。
首先,在通常沒有遮罩的編碼器中,阻止輸入 token x 參考 prompt token t。但也必須防止 prompt t 和 gist token g_i 參考輸入 token x,否則編碼器將根據輸入學習不同的 gist 表示。最後解碼器正常運行,除了在交叉注意力期間,這時需要阻止解碼器參考 prompt token t。
#對於不同數量的gist token,LLaMA- 7B 和FLAN-T5-XXL 的ROUGE-L 和ChatGPT 評估結果如下圖3 所示。
模型通常对 gist token 的数量 k 不敏感:将 prompt 压缩到单个 token 并不会导致显著性能下降。事实上,在某些情况下,过多的 gist token 会损害性能 (例如 LLaMA-7B, 10 gist tokens),这可能是因为增加的容量使训练分布过拟合。因此,研究者在下表 1 中给出了单 token 模型的具体数值,并在剩余实验中使用单个 gist 模型。
在见过的指令上,gist 模型获得了与其对应阳性对照模型几乎相同的 ROUGE 和 ChatGPT 性能,在 LLaMA-7B FLANT5-XXL 上的胜率分别为 48.6% 和 50.8%。这里研究者最感兴趣的是它们在未见过任务上的泛化能力,这需要通过另外两个数据集来衡量的。
在 Alpaca 训练数据集中未见过的 prompt 中,可以看到 gist 模型在未见过 prompt 上有着强大的泛化能力:与对照组相比,分别有 49.7%(LLaMA)和 46.2%(FLAN-T5)的胜率。在最具挑战性的 OOD Human split 上,gist 模型的胜率略微下降,分别为 45.8%(LLaMA)和 42.5%(FLANT5)。
本文的目的是让 gist 模型紧密地模仿原始模型的功能,因此有人可能会问究竟什么时候 gist 模型与对照组无差别。下图 4 说明了这种情况发生的频率:对于已见过任务(但是未见过的输入),gist 模型几乎有一半的时间与对照组不相上下。对于未见过的任务,这一数字下降到了 20-25%。对于 OOD Human 任务,这一数字又下降到 10%。无论如何,gist 模型输出的质量是很高的。
总的来说,这些结果表明,gist 模型可以可靠地压缩 prompt,甚至在训练分布之外的某些 prompt 上也可以做到这一点,特别是像 LLaMA 这样的 decoder-only 因果 LM。FLAN-T5 等 encoder-decoder 模型表现略差,一个可能的原因是 gist 掩膜抑制了编码器中的双向注意力流,这比仅 mask 自回归解码器的一部分 history 更具挑战性。未来需要进一步的工作来研究这个假设。
最后,回到这项工作的核心动机之一:gisting 可以带来什么样的效率提升?
下表 2 展示了使用 PyTorch 2.0 分析器对模型进行单次前向传递的结果(即使用单个输入 token 的自回归解码的一步),并对 Human eval split 中的 252 个指令取平均值。与未经优化的模型相比,gist 缓存显著提高了效率。两种模型的 FLOPs 节约率达到了 40%,时钟时间降低了 4-7%。
然而更重要的是,与指令缓存相比,gist 缓存有着除延迟之外的关键优势:将 26 个 token 压缩为 1 个可以在输入上下文窗口中腾出更多空间,这受到绝对位置嵌入或者 GPU VRAM 的限制。特别是对于 LLaMA-7B,KV 缓存中的每个 token 需要 1.05MB 的存储空间。尽管在测试的 prompt 长度下,KV 缓存相对于 LLaMA-7B 推断所需的内存总贡献微不足道,但一个越来越常见的场景是开发人员在大量用户之间缓存许多 prompt,存储成本很快就会增加。在存储空间相同的情况下,gist 缓存能比完整指令缓存多 26 倍的 prompt。
以上是將26個token壓縮成1個新方法,極致節省ChatGPT輸入框空間的詳細內容。更多資訊請關注PHP中文網其他相關文章!