ホームページ >テクノロジー周辺機器 >AI >継続学習の 6 つの一般的な方法の概要: 古いデータのパフォーマンスを維持しながら、ML モデルを新しいデータに適応させる

継続学習の 6 つの一般的な方法の概要: 古いデータのパフォーマンスを維持しながら、ML モデルを新しいデータに適応させる

PHPz
PHPz転載
2023-04-11 23:25:061484ブラウズ

継続学習とは、以前のタスクから得た知識を忘れることなく、多数のタスクを順番に学習するモデルを指します。教師あり学習では、機械学習モデルが特定のデータセットまたはデータ分布に対して最適な関数になるようにトレーニングされるため、これは重要な概念です。実際の環境では、データが静的であることはほとんどなく、変化する可能性があります。一般的な ML モデルは、目に見えないデータに直面するとパフォーマンスが低下する可能性があります。この現象は壊滅的忘却と呼ばれます。

継続学習の 6 つの一般的な方法の概要: 古いデータのパフォーマンスを維持しながら、ML モデルを新しいデータに適応させる

#このタイプの問題を解決する一般的な方法は、古いデータと新しいデータの両方を含む新しい、より大きなデータセットでモデル全体を再トレーニングすることです。しかし、このアプローチにはコストがかかることがよくあります。この問題を調査している ML 研究分野があり、この分野の研究に基づいて、この記事では、古いデータのパフォーマンスを維持しながらモデルを新しいデータに適応させ、データの劣化を回避する 6 つの方法について説明します。データセット全体でトレーニングする必要があります (古い新しい) 再トレーニング。

プロンプト

プロンプト このアイデアは、GPT 3 のヒント (短い単語のシーケンス) がモデルをより適切に推論し、回答するのに役立つという考えから生まれました。したがって、この記事では Prompt をプロンプトと訳します。ヒント チューニングとは、学習可能な小さなヒントを使用し、それらを実際の入力とともにモデルへの入力としてフィードすることを指します。これにより、モデルの重みを再トレーニングすることなく、新しいデータに関するヒントを提供する小さなモデルのみをトレーニングすることができます。

具体的には、Wang の記事「継続的学習のためのプロンプトの学習」から引用した、テキストベースの集中的な検索にプロンプ​​トを使用する例を選択しました。

論文の著者は、次の図を使用して自分たちのアイデアを説明しています。

継続学習の 6 つの一般的な方法の概要: 古いデータのパフォーマンスを維持しながら、ML モデルを新しいデータに適応させる

実際のエンコードされたテキスト入力は、プロンプトから最小一致ペアを識別するために使用されます。プールの鍵。これらの識別されたキューは、モデルにフィードされる前に、まずエンコードされていないテキスト埋め込みに追加されます。この目的は、古いモデルを変更せずに、新しいタスクを表すようにこれらのキューをトレーニングすることです。ここでのキューは非常に小さく、おそらくプロンプトあたりわずか 20 トークンです。

class PromptPool(nn.Module):
def __init__(self, M = 100, hidden_size = 768, length = 20, N=5):
super().__init__()
self.pool = nn.Parameter(torch.rand(M, length, hidden_size), requires_grad=True).float()
self.keys = nn.Parameter(torch.rand(M, hidden_size), requires_grad=True).float()
 
self.length = length
self.hidden = hidden_size
self.n = N
 
nn.init.xavier_normal_(self.pool)
nn.init.xavier_normal_(self.keys)
 
def init_weights(self, embedding):
pass
 
# function to select from pool based on index
def concat(self, indices, input_embeds):
subset = self.pool[indices, :] # 2, 2, 20, 768
 
subset = subset.to("cuda:0").reshape(indices.size(0),
self.n*self.length,
self.hidden) # 2, 40, 768
 
return torch.cat((subset, input_embeds), 1)
 
# x is cls output
def query_fn(self, x):
 
# encode input x to same dim as key using cosine
x = x / x.norm(dim=1)[:, None]
k = self.keys / self.keys.norm(dim=1)[:, None]
 
scores = torch.mm(x, k.transpose(0,1).to("cuda:0"))
 
# get argmin
subsets = torch.topk(scores, self.n, 1, False).indices # k smallest
 
return subsets
 
 pool = PromptPool()

次に、トレーニング済みの古いデータ モデルを使用して新しいデータをトレーニングします。ここでは、プロンプト部分の重みのみをトレーニングします。

def train():
count = 0
print("*********** Started Training *************")
 
start = time.time()
for epoch in range(40):
model.eval()
pool.train()
 
optimizer.zero_grad(set_to_none=True)
lap = time.time()
 
for batch in iter(train_dataloader):
count += 1
q, p, train_labels = batch
 
queries_emb = model(input_ids=q['input_ids'].to("cuda:0"),
attention_mask=q['attention_mask'].to("cuda:0"))
passage_emb = model(input_ids=p['input_ids'].to("cuda:0"),
attention_mask=p['attention_mask'].to("cuda:0"))
 
# pool
q_idx = pool.query_fn(queries_emb)
raw_qembedding = model.model.embeddings(input_ids=q['input_ids'].to("cuda:0"))
q = pool.concat(indices=q_idx, input_embeds=raw_qembedding)
 
p_idx = pool.query_fn(passage_emb)
raw_pembedding = model.model.embeddings(input_ids=p['input_ids'].to("cuda:0"))
p = pool.concat(indices=p_idx, input_embeds=raw_pembedding)
 
qattention_mask = torch.ones(batch_size, q.size(1))
pattention_mask = torch.ones(batch_size, p.size(1))
 
queries_emb = model.model(inputs_embeds=q,
attention_mask=qattention_mask.to("cuda:0")).last_hidden_state
passage_emb = model.model(inputs_embeds=p,
attention_mask=pattention_mask.to("cuda:0")).last_hidden_state
 
q_cls = queries_emb[:, pool.n*pool.length+1, :]
p_cls = passage_emb[:, pool.n*pool.length+1, :]
 
loss, ql, pl = calc_loss(q_cls, p_cls)
loss.backward()
 
optimizer.step()
optimizer.zero_grad(set_to_none=True)
 
if count % 10 == 0:
print("Model Loss:", round(loss.item(),4), 
"| QL:", round(ql.item(),4), "| PL:", round(pl.item(),4), 
"| Took:", round(time.time() - lap), "secondsn")
 
lap = time.time()
 
if count % 40 == 0 and count > 0:
print("model saved")
torch.save(model.state_dict(), model_PATH)
torch.save(pool.state_dict(), pool_PATH)
 
if count == 4600: return
 
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")

トレーニングが完了したら、後続の推論プロセスで入力と取得したヒントを組み合わせる必要があります。たとえば、この例のパフォーマンスは、新しいデータ ヒント プールでは -93%、完全な (新旧) トレーニングでは -94% でした。これは元の論文で言及されているパフォーマンスと同様です。ただし、注意点としては、結果はタスクによって異なる可能性があるため、何が最も効果的かを知るために実験を試みる必要があるということです。

この方法を検討する価値があるためには、古いデータで古いモデルのパフォーマンスの 80% を超えるパフォーマンスを維持できなければなりません。また、ヒントは、モデルが新しいデータで良好なパフォーマンスを達成するのにも役立つはずです。

この方法の欠点は、ヒント プールの使用が必要であり、余分な時間がかかることです。これは恒久的な解決策ではありませんが、現時点では実行可能であり、将来的には新しい方法が登場する可能性があります。

データ蒸留

知識蒸留という用語を聞いたことがあるかもしれません。これは、教師モデルの重みを使用して、より小規模なモデルをガイドおよびトレーニングする手法です。データ蒸留も同様に機能し、実際のデータからの重みを使用してデータのより小さなサブセットをトレーニングします。データ セットの主要な信号は洗練され、より小さなデータ セットに凝縮されるため、新しいデータでのトレーニングには、古いパフォーマンスを維持するためにいくつかの洗練されたデータを提供するだけで済みます。

この例では、データ蒸留を高密度検索 (テキスト) タスクに適用します。現在、この分野でこの方法を使用している人は他にいないため、結果は最良ではない可能性がありますが、テキスト分類にこの方法を使用すると、良い結果が得られるはずです。

本質的に、テキスト データ蒸留のアイデアは、Li による「テキスト分類のためのデータ蒸留」という論文から生まれました。この論文は、Wang が画像データを蒸留したデータセット蒸留にインスピレーションを受けています。 Li は、テキスト データの蒸留のタスクを次の図で説明しています。

継続学習の 6 つの一般的な方法の概要: 古いデータのパフォーマンスを維持しながら、ML モデルを新しいデータに適応させる

論文によると、最初に蒸留されたデータのバッチがモデルに供給され、その重みが更新されます。次に、更新されたモデルが実際のデータを使用して評価され、信号が蒸留されたデータセットに逆伝播されます。この論文では、8 つの公開ベンチマーク データセットに関する良好な分類結果 (精度 >80%) が報告されています。

提案されたアイデアに従って、いくつかの小さな変更を加え、抽出されたデータのバッチと複数の実データを使用しました。以下は、集中検索トレーニング用の抽出データを作成するコードです。

class DistilledData(nn.Module):
def __init__(self, num_labels, M, q_len=64, hidden_size=768):
super().__init__()
self.num_samples = M
self.q_len = q_len
self.num_labels = num_labels
self.data = nn.Parameter(torch.rand(num_labels, M, q_len, hidden_size), requires_grad=True) # i.e. shape: 1000, 4, 64, 768
 
# init using model embedding, xavier, or load from state dict
def init_weights(self, model, path=None):
if model:
self.data.requires_grad = False
print("Init weights using model embedding")
raw_embedding = model.model.get_input_embeddings()
soft_embeds = raw_embedding.weight[:, :].clone().detach()
nums = soft_embeds.size(0)
for i1 in range(self.num_labels):
for i2 in range(self.num_samples):
for i3 in range(self.q_len):
random_idx = random.randint(0, nums-1)
self.data[i1, i2, i3, :] = soft_embeds[random_idx, :]
print(self.data.shape)
self.data.requires_grad = True
 
if not path:
nn.init.xavier_normal_(self.data)
else:
distilled_data.load_state_dict(torch.load(path), strict=False)
 
# function to sample a passage and positive sample as in the article, i am doing dense retrieval
def get_sample(self, label):
q_idx = random.randint(0, self.num_samples-1)
sampled_dist_q = self.data[label, q_idx, :, :]
 
p_idx = random.randint(0, self.num_samples-1)
while q_idx == p_idx:
p_idx = random.randint(0, self.num_samples-1)
sampled_dist_p = self.data[label, p_idx, :, :]
 
return sampled_dist_q, sampled_dist_p, q_idx, p_idx

これは、抽出データに信号を抽出するコードです。

def distll_train(chunk_size=32):
count, times = 0, 0
print("*********** Started Training *************")
start = time.time()
lap = time.time()
 
for epoch in range(40):
distilled_data.train()
 
for batch in iter(train_dataloader):
count += 1
# get real query, pos, label, distilled data query, distilled data pos, ... from batch
q, p, train_labels, dq, dp, q_indexes, p_indexes = batch
 
for idx in range(0, dq['input_ids'].size(0), chunk_size):
model.train()
 
with torch.enable_grad():
# train on distiled data first
x1 = dq['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
x2 = dp['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
q_emb = model(inputs_embeds=x1.to("cuda:0"),
attention_mask=dq['attention_mask'][idx:idx+chunk_size].to("cuda:0")).cpu()
p_emb = model(inputs_embeds=x2.to("cuda:0"),
attention_mask=dp['attention_mask'][idx:idx+chunk_size].to("cuda:0"))
loss = default_loss(q_emb.to("cuda:0"), p_emb)
del q_emb, p_emb
 
loss.backward(retain_graph=True, create_graph=False)
state_dict = model.state_dict()
 
# update model weights
with torch.no_grad():
for idx, param in enumerate(model.parameters()):
if param.requires_grad and not param.grad is None:
param.data -= (param.grad*3e-5)
 
# real data
model.eval()
q_embs = []
p_embs = []
for k in range(0, len(q['input_ids']), chunk_size):
with torch.no_grad():
q_emb = model(input_ids=q['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
p_emb = model(input_ids=p['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
q_embs.append(q_emb)
p_embs.append(p_emb)
q_embs = torch.cat(q_embs, 0)
p_embs = torch.cat(p_embs, 0)
r_loss = default_loss(q_embs.to("cuda:0"), p_embs.to("cuda:0"))
del q_embs, p_embs
 
# distill backward
if count % 2 == 0:
d_grad = torch.autograd.grad(inputs=[x1.to("cuda:0")],#, x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = q_indexes
else:
d_grad = torch.autograd.grad(inputs=[x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = p_indexes
loss.detach()
r_loss.detach()
 
grads = torch.zeros(distilled_data.data.shape) # lbl, 10, 100, 768
for i, k in enumerate(indexes):
grads[train_labels[i], k, :, :] = grads[train_labels[i], k, :, :].to("cuda:0") 
+ d_grad[0][i, :, :]
distilled_data.data.grad = grads
data_optimizer.step()
data_optimizer.zero_grad(set_to_none=True)
 
model.load_state_dict(state_dict)
model_optimizer.step()
model_optimizer.zero_grad(set_to_none=True)
 
if count % 10 == 0:
print("Count:", count ,"| Data:", round(loss.item(), 4), "| Model:", 
round(r_loss.item(),4), "| Time:", round(time.time() - lap, 4))
# print()
lap = time.time()
 
if count % 100 == 0:
torch.save(model.state_dict(), model_PATH)
torch.save(distilled_data.state_dict(), distill_PATH)
 
if loss < 0.1 and r_loss < 1:
times += 1
 
if times > 100:
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")
return
del loss, r_loss, grads, q, p, train_labels, dq, dp, x1, x2, state_dict
 
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")

ここでは、データのロードなどのコードは省略しています。そして、抽出されたデータがトレーニングされます。最後に、その上で新しいモデルをトレーニングすることによって、たとえば新しいデータと一緒にトレーニングすることによって、それを使用できます。

私の実験によると、蒸留データ (ラベルあたり 4 つのサンプルのみを含む) でトレーニングされたモデルは 66% の最高のパフォーマンスを達成しましたが、完全に生データでトレーニングされたモデルも 66% の最高のパフォーマンスを達成しました。トレーニングされていない通常のモデルは 45% のパフォーマンスを達成しました。上で述べたように、これらの数値は集中的な検索タスクには適さない可能性がありますが、カテゴリ データでははるかに優れています。

新しいデータに合わせてモデルを調整するときにこの方法が役立つためには、元のデータよりもはるかに小さいデータセット (つまり、~1%) を抽出できる必要があります。洗練されたデータでは、アクティブ ラーニング手法と比較してわずかに低いか同等のパフォーマンスが得られる場合もあります。

この方法の利点は、永続的に使用できる抽出データを作成できることです。欠点は、抽出されたデータが解釈できないため、追加のトレーニング時間が必要になることです。

カリキュラム/アクティブ トレーニング

カリキュラム トレーニングは、トレーニング中にモデルにトレーニング サンプルを提供することがますます困難になる方法です。新しいデータでトレーニングする場合、この方法ではタスクに手動でラベルを付け、タスクを簡単、中程度、または困難に分類してからデータをサンプリングする必要があります。モデルが簡単、中程度、または難しいという意味を理解するために、次の図を例として取り上げます。

継続学習の 6 つの一般的な方法の概要: 古いデータのパフォーマンスを維持しながら、ML モデルを新しいデータに適応させる

これは、分類タスクの混同行列です。ハード サンプルは偽陽性 (False Positive) とは、モデルが True である可能性が非常に高いと予測しているものの、実際には True ではないサンプルを指します。中サンプルとは、正しい確率が中から高であるものの、予測しきい値を下回る真陰性のサンプルです。単純なサンプルとは、真陽性/陰性の可能性が低いサンプルです。

最大干渉検索

これは、「最大干渉検索によるオンライン継続学習」というタイトルの論文 (1908.04742) で Rahaf によって紹介された方法です。主な考え方は、トレーニングされるデータの新しいバッチごとに、新しいデータのモデルの重みを更新する場合、損失値の点で最も影響を受ける古いサンプルを特定する必要があるということです。古いデータで構成される限られたサイズのメモリが保持され、最も問題となるサンプルが新しいデータ バッチごとに取得され、一緒にトレーニングされます。

この論文は継続学習の分野で定評のある論文であり、多くの引用があるため、あなたのケースにも当てはまるかもしれません。

検索拡張

検索拡張 (検索拡張) は、コレクションから項目を取得することによって入力やサンプルなどを拡張する技術を指します。これは、特定のテクノロジーではなく一般的な概念です。これまで説明してきたメソッドのほとんどは、ある程度の取得関連の操作です。 Izacard の論文「検索拡張言語モデルによる少数ショット学習」では、小規模なモデルを使用して少数ショット学習で優れたパフォーマンスを実現しています。検索の強化は、単語の生成や事実に関する質問への回答など、他の多くの状況でも使用されます。

モデルを拡張する最も一般的かつ簡単な方法は、トレーニング中に追加のレイヤーを使用することですが、必ずしも効果的であるとは限らないため、ここでは詳しく説明しません。ここでの例は、Lewis の Efficient Few-Shot です。プロンプトのない学習。多くの場合、追加のレイヤーを使用するのが、古いデータと新しいデータで優れたパフォーマンスを得る最も簡単ですが、十分に試行された方法です。主なアイデアは、モデルの重みを固定し、分類損失を伴う新しいデータで 1 つまたは複数のレイヤーをトレーニングすることです。

まとめ この記事では、新しいデータでモデルをトレーニングするときに使用できる 6 つの方法を紹介しました。いつものように、どの方法が最も効果的かを実験して決定する必要がありますが、上で説明した方法以外にも多くの方法があることに注意することが重要です。たとえば、データ蒸留はコンピューター ビジョンの活発な分野であり、それについては論文でたくさん見つけることができます。 。最後の注意: これらのメソッドが価値があるためには、古いデータと新しいデータの両方で良好なパフォーマンスを達成する必要があります。

以上が継続学習の 6 つの一般的な方法の概要: 古いデータのパフォーマンスを維持しながら、ML モデルを新しいデータに適応させるの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事は51cto.comで複製されています。侵害がある場合は、admin@php.cn までご連絡ください。