搜尋
首頁科技週邊人工智慧持續學習常用六種方法總結:使ML模型適應新資料的同時保持舊資料的效能

持續學習是指在不忘記從前面的任務中獲得的知識的情況下,按順序學習大量任務的模型。這是一個重要的概念,因為在監督學習的前提下,機器學習模型被訓練為針對給定資料集或資料分佈的最佳函數。而在現實環境中,資料很少是靜態的,可能會改變。當面對不可見的資料時,典型的ML模型可能會效能下降。這種現像被稱為災難性遺忘。

持續學習常用六種方法總結:使ML模型適應新資料的同時保持舊資料的效能

解決這類問題的常用方法是在包含新舊資料的新的更大資料集上對整個模型進行再訓練。但是這種做法往往代價高昂。所以有一個ML研究領域正在研究這個問題,基於該領域的研究,本文將討論6種方法,使模型可以在保持舊的性能的同時適應新數據,並避免需要在整個數據集(舊新)上進行重新訓練。

Prompt

Prompt 想法源自於GPT 3的提示(短序列的單字)可以幫助驅動模型更好地推理和回答。所以在本文中將Prompt 翻譯為提示。提示調優是指使用小型可學習的提示,並將其與實際輸入一起作為模型的輸入。這允許我們只在新資料上訓練提供提示的小模型,而無需再訓練模型權重。

具體來說,我選擇了使用提示進行基於文本的密集檢索的例子,這個例子改編自Wang的文章《Learning to Prompt for continuous Learning》。

論文的作者使用下圖描述了他們的想法:

持續學習常用六種方法總結:使ML模型適應新資料的同時保持舊資料的效能

#實際編碼的文字輸入用作從提示池中識別最小匹配對的key。在將這些標識的提示輸入到模型之前,首先將它們新增至未編碼的文字嵌入。這樣做的目的是訓練這些提示來表示新的任務,同時保持舊的模型不變,這裡提示的很小,大概每個提示只有20個令牌。

class PromptPool(nn.Module):
def __init__(self, M = 100, hidden_size = 768, length = 20, N=5):
super().__init__()
self.pool = nn.Parameter(torch.rand(M, length, hidden_size), requires_grad=True).float()
self.keys = nn.Parameter(torch.rand(M, hidden_size), requires_grad=True).float()
 
self.length = length
self.hidden = hidden_size
self.n = N
 
nn.init.xavier_normal_(self.pool)
nn.init.xavier_normal_(self.keys)
 
def init_weights(self, embedding):
pass
 
# function to select from pool based on index
def concat(self, indices, input_embeds):
subset = self.pool[indices, :] # 2, 2, 20, 768
 
subset = subset.to("cuda:0").reshape(indices.size(0),
self.n*self.length,
self.hidden) # 2, 40, 768
 
return torch.cat((subset, input_embeds), 1)
 
# x is cls output
def query_fn(self, x):
 
# encode input x to same dim as key using cosine
x = x / x.norm(dim=1)[:, None]
k = self.keys / self.keys.norm(dim=1)[:, None]
 
scores = torch.mm(x, k.transpose(0,1).to("cuda:0"))
 
# get argmin
subsets = torch.topk(scores, self.n, 1, False).indices # k smallest
 
return subsets
 
 pool = PromptPool()

然後我們使用的經過訓練的舊數據模型,訓練新的數據,這裡只訓練提示部分的權重。

def train():
count = 0
print("*********** Started Training *************")
 
start = time.time()
for epoch in range(40):
model.eval()
pool.train()
 
optimizer.zero_grad(set_to_none=True)
lap = time.time()
 
for batch in iter(train_dataloader):
count += 1
q, p, train_labels = batch
 
queries_emb = model(input_ids=q['input_ids'].to("cuda:0"),
attention_mask=q['attention_mask'].to("cuda:0"))
passage_emb = model(input_ids=p['input_ids'].to("cuda:0"),
attention_mask=p['attention_mask'].to("cuda:0"))
 
# pool
q_idx = pool.query_fn(queries_emb)
raw_qembedding = model.model.embeddings(input_ids=q['input_ids'].to("cuda:0"))
q = pool.concat(indices=q_idx, input_embeds=raw_qembedding)
 
p_idx = pool.query_fn(passage_emb)
raw_pembedding = model.model.embeddings(input_ids=p['input_ids'].to("cuda:0"))
p = pool.concat(indices=p_idx, input_embeds=raw_pembedding)
 
qattention_mask = torch.ones(batch_size, q.size(1))
pattention_mask = torch.ones(batch_size, p.size(1))
 
queries_emb = model.model(inputs_embeds=q,
attention_mask=qattention_mask.to("cuda:0")).last_hidden_state
passage_emb = model.model(inputs_embeds=p,
attention_mask=pattention_mask.to("cuda:0")).last_hidden_state
 
q_cls = queries_emb[:, pool.n*pool.length+1, :]
p_cls = passage_emb[:, pool.n*pool.length+1, :]
 
loss, ql, pl = calc_loss(q_cls, p_cls)
loss.backward()
 
optimizer.step()
optimizer.zero_grad(set_to_none=True)
 
if count % 10 == 0:
print("Model Loss:", round(loss.item(),4), 
"| QL:", round(ql.item(),4), "| PL:", round(pl.item(),4), 
"| Took:", round(time.time() - lap), "secondsn")
 
lap = time.time()
 
if count % 40 == 0 and count > 0:
print("model saved")
torch.save(model.state_dict(), model_PATH)
torch.save(pool.state_dict(), pool_PATH)
 
if count == 4600: return
 
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")

訓練完成後,後續的推理過程需要將輸入與檢索到的提示結合。例如這個例子得到了效能—93%的新資料提示池,而完全(舊 新)訓練為—94%。這與原論文中提到的表現類似。但是需要說明的一點是結果可能會因任務而不同,你應該嘗試實驗來知道什麼是最好的。

要使此方法成為值得考慮的方法,它必須能夠在舊資料上保留舊模型> 80%的效能,同時提示也應該幫助模型在新資料上獲得良好的效能。

這種方法的缺點是需要使用提示池,這會增加額外的時間。這也不是一個永久的解決方案,但目前來說是可行的,或許以後還會有新的方法出現。

Data Distillation

你可能聽說過知識蒸餾一詞,這是一種使用來自教師模型的權重來指導和訓練較小規模模型的技術。資料蒸餾(Data Distillation)的工作原理也類似,它是使用來自真實資料的權重來訓練較小的資料子集。因為資料集的關鍵訊號被提煉並濃縮為較小的資料集,我們對新資料的訓練只需要提供一些提煉的資料以保持舊的效能。

在此範例中,我將資料蒸餾應用於密集檢索(文字)任務。目前看沒有其他人在這個領域使用這種方法,所以結果可能不是最好的,但如果你在文字分類上使用這種方法應該會得到不錯的結果。

本質上,文本資料蒸餾的想法源自於 Li 的一篇題為 Data Distillation for Text Classification 的論文,該論文的靈感來自 Wang 的 Dataset Distillation,他對圖像資料進行了蒸餾。 Li 用下圖描述了文本資料蒸餾的任務:

持續學習常用六種方法總結:使ML模型適應新資料的同時保持舊資料的效能

根據論文,首先將一批蒸餾資料輸入到模型以更新其權重。然後使用真實數據評估更新後的模型,並將訊號反向傳播到蒸餾數據集。該論文在 8 個公共基準資料集上報告了良好的分類結果(> 80% 準確率)。

依照提出的想法,我做了一些小的改動,使用了一批蒸餾數據和多個真實數據。以下是為密集檢索訓練創建蒸餾數據的代碼:

class DistilledData(nn.Module):
def __init__(self, num_labels, M, q_len=64, hidden_size=768):
super().__init__()
self.num_samples = M
self.q_len = q_len
self.num_labels = num_labels
self.data = nn.Parameter(torch.rand(num_labels, M, q_len, hidden_size), requires_grad=True) # i.e. shape: 1000, 4, 64, 768
 
# init using model embedding, xavier, or load from state dict
def init_weights(self, model, path=None):
if model:
self.data.requires_grad = False
print("Init weights using model embedding")
raw_embedding = model.model.get_input_embeddings()
soft_embeds = raw_embedding.weight[:, :].clone().detach()
nums = soft_embeds.size(0)
for i1 in range(self.num_labels):
for i2 in range(self.num_samples):
for i3 in range(self.q_len):
random_idx = random.randint(0, nums-1)
self.data[i1, i2, i3, :] = soft_embeds[random_idx, :]
print(self.data.shape)
self.data.requires_grad = True
 
if not path:
nn.init.xavier_normal_(self.data)
else:
distilled_data.load_state_dict(torch.load(path), strict=False)
 
# function to sample a passage and positive sample as in the article, i am doing dense retrieval
def get_sample(self, label):
q_idx = random.randint(0, self.num_samples-1)
sampled_dist_q = self.data[label, q_idx, :, :]
 
p_idx = random.randint(0, self.num_samples-1)
while q_idx == p_idx:
p_idx = random.randint(0, self.num_samples-1)
sampled_dist_p = self.data[label, p_idx, :, :]
 
return sampled_dist_q, sampled_dist_p, q_idx, p_idx

這是將信號提取到蒸餾數據上的代碼

def distll_train(chunk_size=32):
count, times = 0, 0
print("*********** Started Training *************")
start = time.time()
lap = time.time()
 
for epoch in range(40):
distilled_data.train()
 
for batch in iter(train_dataloader):
count += 1
# get real query, pos, label, distilled data query, distilled data pos, ... from batch
q, p, train_labels, dq, dp, q_indexes, p_indexes = batch
 
for idx in range(0, dq['input_ids'].size(0), chunk_size):
model.train()
 
with torch.enable_grad():
# train on distiled data first
x1 = dq['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
x2 = dp['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
q_emb = model(inputs_embeds=x1.to("cuda:0"),
attention_mask=dq['attention_mask'][idx:idx+chunk_size].to("cuda:0")).cpu()
p_emb = model(inputs_embeds=x2.to("cuda:0"),
attention_mask=dp['attention_mask'][idx:idx+chunk_size].to("cuda:0"))
loss = default_loss(q_emb.to("cuda:0"), p_emb)
del q_emb, p_emb
 
loss.backward(retain_graph=True, create_graph=False)
state_dict = model.state_dict()
 
# update model weights
with torch.no_grad():
for idx, param in enumerate(model.parameters()):
if param.requires_grad and not param.grad is None:
param.data -= (param.grad*3e-5)
 
# real data
model.eval()
q_embs = []
p_embs = []
for k in range(0, len(q['input_ids']), chunk_size):
with torch.no_grad():
q_emb = model(input_ids=q['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
p_emb = model(input_ids=p['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
q_embs.append(q_emb)
p_embs.append(p_emb)
q_embs = torch.cat(q_embs, 0)
p_embs = torch.cat(p_embs, 0)
r_loss = default_loss(q_embs.to("cuda:0"), p_embs.to("cuda:0"))
del q_embs, p_embs
 
# distill backward
if count % 2 == 0:
d_grad = torch.autograd.grad(inputs=[x1.to("cuda:0")],#, x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = q_indexes
else:
d_grad = torch.autograd.grad(inputs=[x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = p_indexes
loss.detach()
r_loss.detach()
 
grads = torch.zeros(distilled_data.data.shape) # lbl, 10, 100, 768
for i, k in enumerate(indexes):
grads[train_labels[i], k, :, :] = grads[train_labels[i], k, :, :].to("cuda:0") 
+ d_grad[0][i, :, :]
distilled_data.data.grad = grads
data_optimizer.step()
data_optimizer.zero_grad(set_to_none=True)
 
model.load_state_dict(state_dict)
model_optimizer.step()
model_optimizer.zero_grad(set_to_none=True)
 
if count % 10 == 0:
print("Count:", count ,"| Data:", round(loss.item(), 4), "| Model:", 
round(r_loss.item(),4), "| Time:", round(time.time() - lap, 4))
# print()
lap = time.time()
 
if count % 100 == 0:
torch.save(model.state_dict(), model_PATH)
torch.save(distilled_data.state_dict(), distill_PATH)
 
if loss < 0.1 and r_loss < 1:
times += 1
 
if times > 100:
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")
return
del loss, r_loss, grads, q, p, train_labels, dq, dp, x1, x2, state_dict
 
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")

這裡省略了數據加載等代碼,訓練完蒸餾的數據後,我們可以透過在其上訓練新模型來使用它,例如將其與新資料合併一起訓練。

根據我的實驗,一個在蒸餾資料上訓練的模型(每個標籤只包含4個樣本)獲得了66%的最佳性能,而一個完全在原始資料上訓練的模型也是得到了66%的最佳性能。而未經訓練的普通模型得到45%的表現。就像上面提到的這些數字對於密集檢索任務可能不太好,分類資料會好很多。

要使此方法成為在調整模型以適應新資料時值是一個有用的方法,需要能夠提取比原始資料小得多的資料集(即~ 1%)。經過提煉的數據也能夠給你一個略低於或等於主動學習方法的表現。

這個方法的優點是可以創建用於永久使用的蒸餾資料。缺點是提取的數據沒有可解釋性,並且需要額外的訓練時間。

Curriculum/Active training

Curriculum training是一種方法,訓練時向模型提供訓練樣本的難度逐漸變大。在對新資料進行訓練時,此方法需要人工的對任務進行標註,將任務分為簡單、中等或困難,然後對資料進行採樣。為了理解模型的簡單、中等或困難意味著什麼,我以這張圖片為例:

持續學習常用六種方法總結:使ML模型適應新資料的同時保持舊資料的效能

這是在分類任務中的混淆矩陣,困難樣本是假陽性(False Positive),是指模型預測為True的可能性很高,但實際上不是True的樣本。中等樣本是那些具有中到高的正確性可能性但低於預測閾值的True Negative。而簡單樣本則是那些可能性較低的True Positive/Negative。

Maximally Interfered Retrieval

這是 Rahaf 在題為“Online Continual Learning with Maximally Interfered Retrieval”的論文(1908.04742)中介紹的一種方法。主要想法是,對於正在訓練的每個新資料批次,如果針對較新資料更新模型權重,將需要識別在損失值方面受影響最大的舊樣本。保留由舊資料組成的有限大小的內存,並檢索最大干擾的樣本以及每個新資料批次以一起訓練。

這篇論文在持續學習領域是一篇成熟的論文,並且有很多引用,因此可能適用於您的案例。

Retrieval Augmentation

檢索增強(Retrieval Augmentation)是指透過從集合中檢索項目來擴充輸入、樣本等的技術。這是一個普遍的概念而不是一個特定的技術。我們到目前為止所討論的方法,大多數都在一定程度都是檢索相關的操作。 Izacard 的題為 Few-shot Learning with Retrieval Augmented Language Models 的論文使用更小的模型獲得了出色的少樣本 學習的性能。檢索增強也用於許多其他情況,例如單字產生或回答事實問題。

擴展模型在訓練時使用附加層是最常見也最簡單的方法,但是不一定有效,所以在這裡不進行詳細的討論,這裡的一個例子是Lewis 的Efficient Few-Shot Learning without Prompts。使用附加層通常是在新舊數據上獲得良好性能的最簡單但經過嘗試和測試的方法。主要想法是保持模型權重固定,並透過分類損失在新資料上訓練一層或幾層。

總結在本文中,我介紹了在新資料上訓練模型時可以使用的 6 種方法。像往常一樣應該進行實驗並決定哪種方法最適合,但是需要注意的是,除了我上面的方法外還有很多方法,例如數據蒸餾是計算機視覺中的一個活躍領域,你可以找到很多關於它的論文。最後說明的一點是:要使這些方法有價值,它們應該在舊數據和新數據上同時獲得良好的性能 。

以上是持續學習常用六種方法總結:使ML模型適應新資料的同時保持舊資料的效能的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述
本文轉載於:51CTO.COM。如有侵權,請聯絡admin@php.cn刪除
擁抱面部是否7B型號奧林匹克賽車擊敗克勞德3.7?擁抱面部是否7B型號奧林匹克賽車擊敗克勞德3.7?Apr 23, 2025 am 11:49 AM

擁抱Face的OlympicCoder-7B:強大的開源代碼推理模型 開發以代碼為中心的語言模型的競賽正在加劇,擁抱面孔與強大的競爭者一起參加了比賽:OlympicCoder-7B,一種產品

4個新的雙子座功能您可以錯過4個新的雙子座功能您可以錯過Apr 23, 2025 am 11:48 AM

你們當中有多少人希望AI可以做更多的事情,而不僅僅是回答問題?我知道我有,最近,我對它的變化感到驚訝。 AI聊天機器人不僅要聊天,還關心創建,研究

Camunda為經紀人AI編排編寫了新的分數Camunda為經紀人AI編排編寫了新的分數Apr 23, 2025 am 11:46 AM

隨著智能AI開始融入企業軟件平台和應用程序的各個層面(我們必須強調的是,既有強大的核心工具,也有一些不太可靠的模擬工具),我們需要一套新的基礎設施能力來管理這些智能體。 總部位於德國柏林的流程編排公司Camunda認為,它可以幫助智能AI發揮其應有的作用,並與新的數字工作場所中的準確業務目標和規則保持一致。該公司目前提供智能編排功能,旨在幫助組織建模、部署和管理AI智能體。 從實際的軟件工程角度來看,這意味著什麼? 確定性與非確定性流程的融合 該公司表示,關鍵在於允許用戶(通常是數據科學家、軟件

策劃的企業AI體驗是否有價值?策劃的企業AI體驗是否有價值?Apr 23, 2025 am 11:45 AM

參加Google Cloud Next '25,我渴望看到Google如何區分其AI產品。 有關代理空間(此處討論)和客戶體驗套件(此處討論)的最新公告很有希望,強調了商業價值

如何為抹布找到最佳的多語言嵌入模型?如何為抹布找到最佳的多語言嵌入模型?Apr 23, 2025 am 11:44 AM

為您的檢索增強發電(RAG)系統選擇最佳的多語言嵌入模型 在當今的相互聯繫的世界中,建立有效的多語言AI系統至關重要。 強大的多語言嵌入模型對於RE至關重要

麝香:奧斯汀的機器人需要每10,000英里進行干預麝香:奧斯汀的機器人需要每10,000英里進行干預Apr 23, 2025 am 11:42 AM

特斯拉的Austin Robotaxi發射:仔細觀察Musk的主張 埃隆·馬斯克(Elon Musk)最近宣布,特斯拉即將在德克薩斯州奧斯汀推出的Robotaxi發射,最初出於安全原因部署了一支小型10-20輛汽車,並有快速擴張的計劃。 h

AI震驚的樞軸:從工作工具到數字治療師和生活教練AI震驚的樞軸:從工作工具到數字治療師和生活教練Apr 23, 2025 am 11:41 AM

人工智能的應用方式可能出乎意料。最初,我們很多人可能認為它主要用於代勞創意和技術任務,例如編寫代碼和創作內容。 然而,哈佛商業評論最近報導的一項調查表明情況並非如此。大多數用戶尋求人工智能的並非是代勞工作,而是支持、組織,甚至是友誼! 報告稱,人工智能應用案例的首位是治療和陪伴。這表明其全天候可用性以及提供匿名、誠實建議和反饋的能力非常有價值。 另一方面,營銷任務(例如撰寫博客、創建社交媒體帖子或廣告文案)在流行用途列表中的排名要低得多。 這是為什麼呢?讓我們看看研究結果及其對我們人類如何繼續將

公司競爭AI代理的採用公司競爭AI代理的採用Apr 23, 2025 am 11:40 AM

AI代理商的興起正在改變業務格局。 與雲革命相比,預計AI代理的影響呈指數增長,有望徹底改變知識工作。 模擬人類決策的能力

See all articles

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱工具

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

MantisBT

MantisBT

Mantis是一個易於部署的基於Web的缺陷追蹤工具,用於幫助產品缺陷追蹤。它需要PHP、MySQL和一個Web伺服器。請查看我們的演示和託管服務。

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

VSCode Windows 64位元 下載

VSCode Windows 64位元 下載

微軟推出的免費、功能強大的一款IDE編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用