搜尋
首頁科技週邊人工智慧使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

Jan 25, 2024 pm 12:21 PM
人工智慧大型語言模型

2024年是大型語言模型(LLM)快速發展的一年。在LLM的訓練中,對齊方法是一個重要的技術手段,其中包括監督微調(SFT)和依賴人類偏好的人類回饋強化學習(RLHF)。這些方法在LLM的發展中起到了至關重要的作用,但是對齊方法需要大量的人工註釋資料。面對這項挑戰,微調成為一個充滿活力的研究領域,研究人員積極致力於開發能夠有效利用人類資料的方法。因此,對齊方法的發展將推動LLM技術的進一步突破。

使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

加州大學最近進行了一項研究,介紹了一種名為SPIN(Self Play fIne tuNing)的新技術。 SPIN借鑒了AlphaGo Zero和AlphaZero等遊戲中成功的自我對弈機制,使LLM(Language Learning Model)能夠參與自我遊戲。這項技術消除了對專業註釋者的需求,無論是人類還是更高級的模型(如GPT-4)。 SPIN的訓練過程包括訓練一個新的語言模型,並透過一系列迭代來區分它自己產生的反應和人類生成的反應。其最終目標是發展出一種語言模型,使其產生的回答與人類的回答沒有區別。這項研究的目的在於進一步提升語言模型的自我學習能力,使其更接近人類的表達和思考方式。這項研究的成果有望為自然語言處理領域的發展帶來新的突破。

自我博弈

自我博弈是一種學習技術,透過對抗自身副本來增加學習環境的挑戰性和複雜性。這種方法允許代理與自己的不同版本進行交互,從而提高自身的能力。 AlphaGo Zero是一個成功的自我遊戲案例。

使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

自我博弈在多智能體強化學習(MARL)中已被證實是有效的方法。然而,將其應用於大型語言模型(LLM)的增強是一種新的方法。透過在大型語言模型中應用自我博弈,可以進一步提高它們的能力,使其產生更連貫、資訊豐富的文本。此方法可望推動語言模型的進一步發展和提升。

自我遊戲可應用於競爭或合作環境。競爭中,演算法副本相互競爭達到目標;合作中,副本一起工作以實現共同目標。可與監督學習、強化學習等技術結合,提升性能。

SPIN

SPIN就像是雙人遊戲。在這個遊戲中:

主模型(新LLM)的角色是學習區分語言模型(LLM)產生的回應和人類創建的回應。每次迭代中,主模型都在積極訓練LLM以提高其識別和區分反應的能力。

對手模型(舊LLM)的任務是產生與人類產生的反應相似的結果。它是透過上一輪迭代的LLM產生的,利用自我博弈機制根據過去的知識來產生輸出。對手模型的目標是創造逼真的反應,以至於新的LLM無法確定它是由機器產生的。

這個流程是不是很像GAN,但還是不太一樣

SPIN的動態涉及使用監督微調(SFT)資料集,該資料集由輸入(x)和輸出(y )對組成。這些範例由人工註釋,並作為訓練主模型識別類人反應的基礎。一些公開的SFT資料集包括Dolly15K、Baize、Ultrachat等。

主模型的訓練

為了訓練主模型區分語言模型(LLM)和人類反應,SPIN使用了一個目標函數。這個函數測量真實數據和對手模型產生的反應之間的預期值差距。主模型的目標是最大化這一期望值差距。這包括將高值分配給與真實數據的回應配對的提示,並將低值分配給由對手模型產生的回應配對。這個目標函數被表述為最小化問題。

主模型的工作是最小化損失函數,即衡量來自真實資料的配對分配值與來自對手模型反應的配對分配值之間的差異。在整個訓練過程中,主模型調整其參數以最小化該損失函數。這個迭代過程一直持續下去,直到主模型能夠熟練地有效區分LLM的反應和人類的反應。

對手模型的更新

更新對手模型涉及改進主模型的能力,他們在訓練時已經學會區分真實資料和語言模型反應。隨著主模型的改進及其對特定函數類別的理解,我們還需要更新如對手模型的參數。當主玩家面對相同的提示時,它便會使用學習所得到的辨別能力來評估它們的價值。

對手模型玩家的目標是增強語言模型,使其反應與主玩家的真實數據無法區分。這就需要設定一個流程來調整語言模型的參數。目的是在保持穩定性的同時,最大限度地提高主模型對語言模型反應的評估。這涉及到一種平衡行為,確保改進不會偏離原始語言模型太遠。

聽著有點亂,我們簡單總結下:

訓練的時候只有一個模型,但是將模型分為前一輪的模型(舊LLM/對手模型)和主模型(正在訓練的),使用正在訓練的模型的輸出與上一輪模型的輸出作為對比,來優化目前模型的訓練。但這裡就要求我們必須要有一個訓練好的模型作為對手模型,所以SPIN演算法只適合在訓練結果上微調。

SPIN演算法

SPIN從預先訓練的模型產生合成資料。然後使用這些合成資料對新任務上的模型進行微調。

使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

上面時原始論文中Spin演算法的偽代碼,看著有點難理解,我們透過Python來復現更好地解釋它是如何運作的。

1、初始化參數與SFT資料集

原文採用Zephyr-7B-SFT-Full作為基本模型。對於資料集,他們使用了更大的Ultrachat200k語料庫的子集,該語料庫由使用OpenAI的Turbo api生成的大約140萬個對話組成。他們隨機抽取了50k個提示,並使用基本模型來產生合成響應。

# Import necessary libraries from datasets import load_dataset import pandas as pd  # Load the Ultrachat 200k dataset ultrachat_dataset = load_dataset("HuggingFaceH4/ultrachat_200k")  # Initialize an empty DataFrame combined_df = pd.DataFrame()  # Loop through all the keys in the Ultrachat dataset for key in ultrachat_dataset.keys():# Convert each dataset key to a pandas DataFrame and concatenate it with the existing DataFramecombined_df = pd.concat([combined_df, pd.DataFrame(ultrachat_dataset[key])])  # Shuffle the combined DataFrame and reset the index combined_df = combined_df.sample(frac=1, random_state=123).reset_index(drop=True)  # Select the first 50,000 rows from the shuffled DataFrame ultrachat_50k_sample = combined_df.head(50000)

作者的提示範本「

Instruction: {prompt}\n\n

Response:」

# for storing each template in a list templates_data = []  for index, row in ultrachat_50k_sample.iterrows():messages = row['messages'] # Check if there are at least two messages (user and assistant)if len(messages) >= 2:user_message = messages[0]['content']assistant_message = messages[1]['content'] # Create the templateinstruction_response_template = f"### Instruction: {user_message}\n\n### Response: {assistant_message}" # Append the template to the listtemplates_data.append({'Template': instruction_response_template})  # Create a new DataFrame with the generated templates (ground truth) ground_truth_df = pd.DataFrame(templates_data)

然後得到了類似下面的資料:使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

SPIN演算法透過迭代更新語言模型(LLM)的參數使其與地面真實響應保持一致。這個過程一直持續下去,直到很難區分生成的反應和真實情況,從而實現高水準的相似性(降低損失)。

SPIN演算法有兩個迴圈。內部循環基於我們正在使用的樣本數量運行,外部循環總共運行了3次迭代,因為作者發現模型的性能在此之後沒有變化。採用Alignment Handbook庫作為微調方法的程式碼庫,結合DeepSpeed模組,降低了訓練成本。他們用RMSProp優化器訓練Zephyr-7B-SFT-Full,所有迭代都沒有權重衰減,就像通常用於微調llm一樣。全域批次大小設定為64,使用bfloat16精度。迭代0和1的峰值學習率設定為5e-7,迭代2和3的峰值學習率隨著循環接近自播放微調的結束而衰減為1e-7。最後選擇β = 0.1,最大序列長度設定為2048個標記。以下就是這些參數

 # Importing the PyTorch library import torch  # Importing the neural network module from PyTorch import torch.nn as nn  # Importing the DeepSpeed library for distributed training import deepspeed  # Importing the AutoTokenizer and AutoModelForCausalLM classes from the transformers library from transformers import AutoTokenizer, AutoModelForCausalLM  # Loading the zephyr-7b-sft-full model from HuggingFace tokenizer = AutoTokenizer.from_pretrained("alignment-handbook/zephyr-7b-sft-full") model = AutoModelForCausalLM.from_pretrained("alignment-handbook/zephyr-7b-sft-full")  # Initializing DeepSpeed Zero with specific configuration settings deepspeed_config = deepspeed.config.Config(train_batch_size=64, train_micro_batch_size_per_gpu=4) model, optimizer, _, _ = deepspeed.initialize(model=model, config=deepspeed_config, model_parameters=model.parameters())  # Defining the optimizer and setting the learning rate using RMSprop optimizer = deepspeed.optim.RMSprop(optimizer, lr=5e-7)  # Setting up a learning rate scheduler using LambdaLR from PyTorch scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.2 ** epoch)  # Setting hyperparameters for training num_epochs = 3 max_seq_length = 2048 beta = 0.1

2、產生合成資料(SPIN演算法內循環)

這個內部迴圈負責產生需要與真實資料保持一致的回應,也就是一個訓練批次的程式碼

# zephyr-sft-dataframe (that contains output that will be improved while training) zephyr_sft_output = pd.DataFrame(columns=['prompt', 'generated_output'])  # Looping through each row in the 'ultrachat_50k_sample' dataframe for index, row in ultrachat_50k_sample.iterrows():# Extracting the 'prompt' column value from the current rowprompt = row['prompt'] # Generating output for the current prompt using the Zephyr modelinput_ids = tokenizer(prompt, return_tensors="pt").input_idsoutput = model.generate(input_ids, max_length=200, num_beams=5, no_repeat_ngram_size=2, top_k=50, top_p=0.95) # Decoding the generated output to human-readable textgenerated_text = tokenizer.decode(output[0], skip_special_tokens=True) # Appending the current prompt and its generated output to the new dataframe 'zephyr_sft_output'zephyr_sft_output = zephyr_sft_output.append({'prompt': prompt, 'generated_output': generated_text}, ignore_index=True)

這是一個提示的真實值和模型輸出的範例。 使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

新的df zephyr_sft_output,其中包含提示及其透過基本模型Zephyr-7B-SFT-Full產生的對應輸出。

3、更新規則

在編碼最小化問題之前,理解如何計算llm產生的輸出的條件機率分佈是至關重要的。原論文使用馬可夫過程,其中條件機率分佈pθ (y∣x)可透過分解表示為:使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

這種分解意味著給定輸入序列的輸出序列的機率可以透過將給定輸入序列的每個輸出標記與前一個輸出標記的機率相乘來計算。例如輸出序列為“I enjoy reading books”,輸入序列為“I enjoy”,則在給定輸入序列的情況下,輸出序列的條件機率可以計算為:使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

馬可夫製程條件機率將用於計算真值和Zephyr LLM響應的機率分佈,然後用於計算損失函數。但首先我們需要對條件機率函數進行編碼。

 # Conditional Probability Function of input text def compute_conditional_probability(tokenizer, model, input_text):# Tokenize the input text and convert it to PyTorch tensorsinputs = tokenizer([input_text], return_tensors="pt") # Generate text using the model, specifying additional parametersoutputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True) # Assuming 'transition_scores' is the logits for the generated tokenstransition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True) # Get the length of the input sequenceinput_length = inputs.input_ids.shape[1] # Assuming 'transition_scores' is the logits for the generated tokenslogits = torch.tensor(transition_scores) # Apply softmax to obtain probabilitiesprobs = torch.nn.functional.softmax(logits, dim=-1) # Extract the generated tokens from the outputgenerated_tokens = outputs.sequences[:, input_length:] # Compute conditional probabilityconditional_probability = 1.0for prob in probs[0]:token_probability = prob.item()conditional_probability *= token_probability return conditional_probability

損失函數它包含四個重要的條件機率變數。這些變數中的每一個都取決於基礎真實資料或先前創建的合成資料。 使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

######

而lambda是一个正则化参数,用于控制偏差。在KL正则化项中使用它来惩罚对手模型的分布与目标数据分布之间的差异。论文中没有明确提到lambda的具体值,因为它可能会根据所使用的特定任务和数据集进行调优。

 def LSPIN_loss(model, updated_model, tokenizer, input_text, lambda_val=0.01):# Initialize conditional probability using the original model and input textcp = compute_conditional_probability(tokenizer, model, input_text) # Update conditional probability using the updated model and input textcp_updated = compute_conditional_probability(tokenizer, updated_model, input_text) # Calculate conditional probabilities for ground truth datap_theta_ground_truth = cp(tokenizer, model, input_text)p_theta_t_ground_truth = cp(tokenizer, model, input_text) # Calculate conditional probabilities for synthetic datap_theta_synthetic = cp_updated(tokenizer, updated_model, input_text)p_theta_t_synthetic = cp_updated(tokenizer, updated_model, input_text) # Calculate likelihood ratioslr_ground_truth = p_theta_ground_truth / p_theta_t_ground_truthlr_synthetic = p_theta_synthetic / p_theta_t_synthetic # Compute the LSPIN lossloss = lambda_val * torch.log(lr_ground_truth) - lambda_val * torch.log(lr_synthetic) return loss

如果你有一个大的数据集,可以使用一个较小的lambda值,或者如果你有一个小的数据集,则可能需要使用一个较大的lambda值来防止过拟合。由于我们数据集大小为50k,所以可以使用0.01作为lambda的值。

4、训练(SPIN算法外循环)

这就是Pytorch训练的一个基本流程,就不详细解释了:

# Training loop for epoch in range(num_epochs): # Model with initial parametersinitial_model = AutoModelForCausalLM.from_pretrained("alignment-handbook/zephyr-7b-sft-full") # Update the learning ratescheduler.step() # Initialize total loss for the epochtotal_loss = 0.0 # Generating Synthetic Data (Inner loop)for index, row in ultrachat_50k_sample.iterrows(): # Rest of the code ... # Output == prompt response dataframezephyr_sft_output # Computing loss using LSPIN functionfor (index1, row1), (index2, row2) in zip(ultrachat_50k_sample.iterrows(), zephyr_sft_output.iterrows()):# Assuming 'prompt' and 'generated_output' are the relevant columns in zephyr_sft_outputprompt = row1['prompt']generated_output = row2['generated_output'] # Compute LSPIN lossupdated_model = model # It will be replacing with updated modelloss = LSPIN_loss(initial_model, updated_model, tokenizer, prompt) # Accumulate the losstotal_loss += loss.item() # Backward passloss.backward() # Update the parametersoptimizer.step() # Update the value of betaif epoch == 2:beta = 5.0

我们运行3个epoch,它将进行训练并生成最终的Zephyr SFT LLM版本。官方实现还没有在GitHub上开源,这个版本将能够在某种程度上产生类似于人类反应的输出。我们看看他的运行流程

使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

表现及结果

SPIN可以显著提高LLM在各种基准测试中的性能,甚至超过通过直接偏好优化(DPO)补充额外的GPT-4偏好数据训练的模型。

使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

当我们继续训练时,随着时间的推移,进步会变得越来越小。这表明模型达到了一个阈值,进一步的迭代不会带来显著的收益。这是我们训练数据中样本提示符每次迭代后的响应。

使用SPIN技術進行自我博弈微調訓練的LLM的最佳化

以上是使用SPIN技術進行自我博弈微調訓練的LLM的最佳化的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述
本文轉載於:51CTO.COM。如有侵權,請聯絡admin@php.cn刪除
及時工程中的思想圖是什麼及時工程中的思想圖是什麼Apr 13, 2025 am 11:53 AM

介紹 在迅速的工程中,“思想圖”是指使用圖理論來構建和指導AI的推理過程的新方法。與通常涉及線性S的傳統方法不同

優化您的組織與Genai代理商的電子郵件營銷優化您的組織與Genai代理商的電子郵件營銷Apr 13, 2025 am 11:44 AM

介紹 恭喜!您經營一家成功的業務。通過您的網頁,社交媒體活動,網絡研討會,會議,免費資源和其他來源,您每天收集5000個電子郵件ID。下一個明顯的步驟是

Apache Pinot實時應用程序性能監視Apache Pinot實時應用程序性能監視Apr 13, 2025 am 11:40 AM

介紹 在當今快節奏的軟件開發環境中,確保最佳應用程序性能至關重要。監視實時指標,例如響應時間,錯誤率和資源利用率可以幫助MAIN

Chatgpt擊中了10億用戶? Openai首席執行官說:'短短幾週內翻了一番Chatgpt擊中了10億用戶? Openai首席執行官說:'短短幾週內翻了一番Apr 13, 2025 am 11:23 AM

“您有幾個用戶?”他扮演。 阿爾特曼回答說:“我認為我們上次說的是每週5億個活躍者,而且它正在迅速增長。” “你告訴我,就像在短短幾週內翻了一番,”安德森繼續說道。 “我說那個私人

pixtral -12b:Mistral AI'第一個多模型模型 - 分析Vidhyapixtral -12b:Mistral AI'第一個多模型模型 - 分析VidhyaApr 13, 2025 am 11:20 AM

介紹 Mistral發布了其第一個多模式模型,即Pixtral-12b-2409。該模型建立在Mistral的120億參數Nemo 12B之上。是什麼設置了該模型?現在可以拍攝圖像和Tex

生成AI應用的代理框架 - 分析Vidhya生成AI應用的代理框架 - 分析VidhyaApr 13, 2025 am 11:13 AM

想像一下,擁有一個由AI驅動的助手,不僅可以響應您的查詢,還可以自主收集信息,執行任務甚至處理多種類型的數據(TEXT,圖像和代碼)。聽起來有未來派?在這個a

生成AI在金融部門的應用生成AI在金融部門的應用Apr 13, 2025 am 11:12 AM

介紹 金融業是任何國家發展的基石,因為它通過促進有效的交易和信貸可用性來推動經濟增長。交易的便利和信貸

在線學習和被動攻擊算法指南在線學習和被動攻擊算法指南Apr 13, 2025 am 11:09 AM

介紹 數據是從社交媒體,金融交易和電子商務平台等來源的前所未有的速度生成的。處理這種連續的信息流是一個挑戰,但它提供了

See all articles

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

AI Hentai Generator

AI Hentai Generator

免費產生 AI 無盡。

熱門文章

R.E.P.O.能量晶體解釋及其做什麼(黃色晶體)
3 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳圖形設置
3 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您聽不到任何人,如何修復音頻
3 週前By尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解鎖Myrise中的所有內容
4 週前By尊渡假赌尊渡假赌尊渡假赌

熱工具

MinGW - Minimalist GNU for Windows

MinGW - Minimalist GNU for Windows

這個專案正在遷移到osdn.net/projects/mingw的過程中,你可以繼續在那裡關注我們。 MinGW:GNU編譯器集合(GCC)的本機Windows移植版本,可自由分發的導入函式庫和用於建置本機Windows應用程式的頭檔;包括對MSVC執行時間的擴展,以支援C99功能。 MinGW的所有軟體都可以在64位元Windows平台上運作。

MantisBT

MantisBT

Mantis是一個易於部署的基於Web的缺陷追蹤工具,用於幫助產品缺陷追蹤。它需要PHP、MySQL和一個Web伺服器。請查看我們的演示和託管服務。

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用