Heim >Technologie-Peripheriegeräte >KI >Zusammenfassung von sechs gängigen Methoden des kontinuierlichen Lernens: Anpassung von ML-Modellen an neue Daten bei gleichzeitiger Beibehaltung der Leistung alter Daten

Zusammenfassung von sechs gängigen Methoden des kontinuierlichen Lernens: Anpassung von ML-Modellen an neue Daten bei gleichzeitiger Beibehaltung der Leistung alter Daten

PHPz
PHPznach vorne
2023-04-11 23:25:061469Durchsuche

Kontinuierliches Lernen bezieht sich auf ein Modell, das eine große Anzahl von Aufgaben nacheinander lernt, ohne das aus früheren Aufgaben gewonnene Wissen zu vergessen. Dies ist ein wichtiges Konzept, da beim überwachten Lernen Modelle für maschinelles Lernen darauf trainiert werden, die beste Funktion für einen bestimmten Datensatz oder eine bestimmte Datenverteilung zu bieten. In realen Umgebungen sind Daten selten statisch und können sich ändern. Bei typischen ML-Modellen kann es zu Leistungseinbußen kommen, wenn sie mit unsichtbaren Daten konfrontiert werden. Dieses Phänomen nennt man katastrophales Vergessen.

Zusammenfassung von sechs gängigen Methoden des kontinuierlichen Lernens: Anpassung von ML-Modellen an neue Daten bei gleichzeitiger Beibehaltung der Leistung alter Daten

Eine übliche Methode zur Lösung dieser Art von Problemen besteht darin, das gesamte Modell anhand eines neuen größeren Datensatzes neu zu trainieren, der alte und neue Daten enthält. Doch dieser Ansatz ist oft kostspielig. Es gibt also einen Bereich der ML-Forschung, der sich mit diesem Problem befasst. Basierend auf der Forschung in diesem Bereich werden in diesem Artikel sechs Methoden erörtert, mit denen sich das Modell an neue Daten anpassen kann, während die alte Leistung erhalten bleibt und die Notwendigkeit einer Durchführung vermieden wird gesamter Datensatz (alt + neu) neu trainiert werden.

Prompt

Prompt Die Idee basiert auf der Idee, dass Hinweise (kurze Wortfolgen) in GPT 3 dazu beitragen können, dass Modelle besser schlussfolgern und antworten. Daher wird „Prompt“ in diesem Artikel mit „Prompt“ übersetzt. Unter Hint-Tuning versteht man die Verwendung kleiner lernbarer Hinweise und deren Eingabe als Eingabe in das Modell zusammen mit realen Eingaben. Dadurch können wir nur ein kleines Modell trainieren, das Hinweise auf neue Daten liefert, ohne die Modellgewichte neu zu trainieren.

Konkret habe ich das Beispiel der Verwendung von Eingabeaufforderungen für textbasiertes Intensivabrufen ausgewählt, das aus Wangs Artikel „Learning to Prompt for Continuous Learning“ übernommen wurde.

Die Autoren des Artikels beschreiben ihre Idee anhand des folgenden Diagramms:

Zusammenfassung von sechs gängigen Methoden des kontinuierlichen Lernens: Anpassung von ML-Modellen an neue Daten bei gleichzeitiger Beibehaltung der Leistung alter Daten

Die tatsächliche codierte Texteingabe wird als Schlüssel verwendet, um das kleinste passende Paar aus dem Hinweispool zu identifizieren. Diese identifizierten Hinweise werden zunächst zu unverschlüsselten Texteinbettungen hinzugefügt, bevor sie in das Modell eingespeist werden. Der Zweck besteht darin, diese Eingabeaufforderungen so zu trainieren, dass sie neue Aufgaben darstellen, während das alte Modell unverändert bleibt. Die Eingabeaufforderungen sind hier sehr klein, möglicherweise nur 20 Token pro Eingabeaufforderung.

class PromptPool(nn.Module):
def __init__(self, M = 100, hidden_size = 768, length = 20, N=5):
super().__init__()
self.pool = nn.Parameter(torch.rand(M, length, hidden_size), requires_grad=True).float()
self.keys = nn.Parameter(torch.rand(M, hidden_size), requires_grad=True).float()
 
self.length = length
self.hidden = hidden_size
self.n = N
 
nn.init.xavier_normal_(self.pool)
nn.init.xavier_normal_(self.keys)
 
def init_weights(self, embedding):
pass
 
# function to select from pool based on index
def concat(self, indices, input_embeds):
subset = self.pool[indices, :] # 2, 2, 20, 768
 
subset = subset.to("cuda:0").reshape(indices.size(0),
self.n*self.length,
self.hidden) # 2, 40, 768
 
return torch.cat((subset, input_embeds), 1)
 
# x is cls output
def query_fn(self, x):
 
# encode input x to same dim as key using cosine
x = x / x.norm(dim=1)[:, None]
k = self.keys / self.keys.norm(dim=1)[:, None]
 
scores = torch.mm(x, k.transpose(0,1).to("cuda:0"))
 
# get argmin
subsets = torch.topk(scores, self.n, 1, False).indices # k smallest
 
return subsets
 
 pool = PromptPool()

Dann verwenden wir das trainierte alte Datenmodell, um neue Daten zu trainieren. Hier trainieren wir nur das Gewicht des Prompt-Teils.

def train():
count = 0
print("*********** Started Training *************")
 
start = time.time()
for epoch in range(40):
model.eval()
pool.train()
 
optimizer.zero_grad(set_to_none=True)
lap = time.time()
 
for batch in iter(train_dataloader):
count += 1
q, p, train_labels = batch
 
queries_emb = model(input_ids=q['input_ids'].to("cuda:0"),
attention_mask=q['attention_mask'].to("cuda:0"))
passage_emb = model(input_ids=p['input_ids'].to("cuda:0"),
attention_mask=p['attention_mask'].to("cuda:0"))
 
# pool
q_idx = pool.query_fn(queries_emb)
raw_qembedding = model.model.embeddings(input_ids=q['input_ids'].to("cuda:0"))
q = pool.concat(indices=q_idx, input_embeds=raw_qembedding)
 
p_idx = pool.query_fn(passage_emb)
raw_pembedding = model.model.embeddings(input_ids=p['input_ids'].to("cuda:0"))
p = pool.concat(indices=p_idx, input_embeds=raw_pembedding)
 
qattention_mask = torch.ones(batch_size, q.size(1))
pattention_mask = torch.ones(batch_size, p.size(1))
 
queries_emb = model.model(inputs_embeds=q,
attention_mask=qattention_mask.to("cuda:0")).last_hidden_state
passage_emb = model.model(inputs_embeds=p,
attention_mask=pattention_mask.to("cuda:0")).last_hidden_state
 
q_cls = queries_emb[:, pool.n*pool.length+1, :]
p_cls = passage_emb[:, pool.n*pool.length+1, :]
 
loss, ql, pl = calc_loss(q_cls, p_cls)
loss.backward()
 
optimizer.step()
optimizer.zero_grad(set_to_none=True)
 
if count % 10 == 0:
print("Model Loss:", round(loss.item(),4), 
"| QL:", round(ql.item(),4), "| PL:", round(pl.item(),4), 
"| Took:", round(time.time() - lap), "secondsn")
 
lap = time.time()
 
if count % 40 == 0 and count > 0:
print("model saved")
torch.save(model.state_dict(), model_PATH)
torch.save(pool.state_dict(), pool_PATH)
 
if count == 4600: return
 
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")

Nachdem das Training abgeschlossen ist, muss der anschließende Inferenzprozess die Eingabe mit den abgerufenen Hinweisen kombinieren. In diesem Beispiel wurde beispielsweise eine Leistung von -93 % für den neuen Datenhinweispool und -94 % für das vollständige (alte + neue) Training erzielt. Dies ähnelt der im Originalpapier erwähnten Leistung. Die Einschränkung besteht jedoch darin, dass die Ergebnisse je nach Aufgabe variieren können und Sie Experimente durchführen sollten, um herauszufinden, was am besten funktioniert.

Damit diese Methode eine Überlegung wert ist, muss sie in der Lage sein, mehr als 80 % der Leistung des alten Modells bei alten Daten beizubehalten, während die Hinweise dem Modell auch dabei helfen sollten, bei neuen Daten eine gute Leistung zu erzielen.

Der Nachteil dieser Methode besteht darin, dass sie die Verwendung eines Prompt-Pools erfordert, was zusätzliche Zeit kostet. Dies ist keine dauerhafte Lösung, aber vorerst machbar, und vielleicht werden in Zukunft neue Methoden auftauchen.

Datendestillation

Vielleicht haben Sie schon einmal von dem Begriff Wissensdestillation gehört, bei dem es sich um eine Technik handelt, bei der Gewichte aus einem Lehrermodell verwendet werden, um kleinere Modelle anzuleiten und zu trainieren. Die Datendestillation funktioniert ähnlich und verwendet Gewichtungen aus realen Daten, um kleinere Teilmengen der Daten zu trainieren. Da die Schlüsselsignale des Datensatzes verfeinert und zu kleineren Datensätzen zusammengefasst werden, muss unser Training für neue Daten nur mit einigen verfeinerten Daten versehen werden, um die alte Leistung aufrechtzuerhalten.

In diesem Beispiel wende ich die Datendestillation auf eine dichte Abrufaufgabe (Text) an. Derzeit verwendet niemand sonst diese Methode in diesem Bereich, daher sind die Ergebnisse möglicherweise nicht die besten. Wenn Sie diese Methode jedoch zur Textklassifizierung verwenden, sollten Sie gute Ergebnisse erzielen.

Im Wesentlichen stammt die Idee der Textdatendestillation aus einem Artikel von Li mit dem Titel „Data Distillation for Text Classification“, der von Wangs Dataset Distillation inspiriert wurde, bei der er Bilddaten destillierte. Li beschreibt die Aufgabe der Textdatendestillation mit dem folgenden Diagramm:

Zusammenfassung von sechs gängigen Methoden des kontinuierlichen Lernens: Anpassung von ML-Modellen an neue Daten bei gleichzeitiger Beibehaltung der Leistung alter Daten

Dem Papier zufolge wird zunächst eine Charge destillierter Daten in das Modell eingespeist, um dessen Gewichte zu aktualisieren. Das aktualisierte Modell wird dann anhand realer Daten ausgewertet und das Signal wird an den destillierten Datensatz zurückpropagiert. Das Papier berichtet über gute Klassifizierungsergebnisse (>80 % Genauigkeit) für 8 öffentliche Benchmark-Datensätze.

Den vorgeschlagenen Ideen folgend, habe ich einige kleinere Änderungen vorgenommen und eine Reihe destillierter Daten und mehrere reale Daten verwendet. Das Folgende ist der Code zum Erstellen destillierter Daten für ein intensives Abruftraining:

class DistilledData(nn.Module):
def __init__(self, num_labels, M, q_len=64, hidden_size=768):
super().__init__()
self.num_samples = M
self.q_len = q_len
self.num_labels = num_labels
self.data = nn.Parameter(torch.rand(num_labels, M, q_len, hidden_size), requires_grad=True) # i.e. shape: 1000, 4, 64, 768
 
# init using model embedding, xavier, or load from state dict
def init_weights(self, model, path=None):
if model:
self.data.requires_grad = False
print("Init weights using model embedding")
raw_embedding = model.model.get_input_embeddings()
soft_embeds = raw_embedding.weight[:, :].clone().detach()
nums = soft_embeds.size(0)
for i1 in range(self.num_labels):
for i2 in range(self.num_samples):
for i3 in range(self.q_len):
random_idx = random.randint(0, nums-1)
self.data[i1, i2, i3, :] = soft_embeds[random_idx, :]
print(self.data.shape)
self.data.requires_grad = True
 
if not path:
nn.init.xavier_normal_(self.data)
else:
distilled_data.load_state_dict(torch.load(path), strict=False)
 
# function to sample a passage and positive sample as in the article, i am doing dense retrieval
def get_sample(self, label):
q_idx = random.randint(0, self.num_samples-1)
sampled_dist_q = self.data[label, q_idx, :, :]
 
p_idx = random.randint(0, self.num_samples-1)
while q_idx == p_idx:
p_idx = random.randint(0, self.num_samples-1)
sampled_dist_p = self.data[label, p_idx, :, :]
 
return sampled_dist_q, sampled_dist_p, q_idx, p_idx

Dies ist der Code zum Extrahieren des Signals auf die destillierten Daten

def distll_train(chunk_size=32):
count, times = 0, 0
print("*********** Started Training *************")
start = time.time()
lap = time.time()
 
for epoch in range(40):
distilled_data.train()
 
for batch in iter(train_dataloader):
count += 1
# get real query, pos, label, distilled data query, distilled data pos, ... from batch
q, p, train_labels, dq, dp, q_indexes, p_indexes = batch
 
for idx in range(0, dq['input_ids'].size(0), chunk_size):
model.train()
 
with torch.enable_grad():
# train on distiled data first
x1 = dq['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
x2 = dp['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
q_emb = model(inputs_embeds=x1.to("cuda:0"),
attention_mask=dq['attention_mask'][idx:idx+chunk_size].to("cuda:0")).cpu()
p_emb = model(inputs_embeds=x2.to("cuda:0"),
attention_mask=dp['attention_mask'][idx:idx+chunk_size].to("cuda:0"))
loss = default_loss(q_emb.to("cuda:0"), p_emb)
del q_emb, p_emb
 
loss.backward(retain_graph=True, create_graph=False)
state_dict = model.state_dict()
 
# update model weights
with torch.no_grad():
for idx, param in enumerate(model.parameters()):
if param.requires_grad and not param.grad is None:
param.data -= (param.grad*3e-5)
 
# real data
model.eval()
q_embs = []
p_embs = []
for k in range(0, len(q['input_ids']), chunk_size):
with torch.no_grad():
q_emb = model(input_ids=q['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
p_emb = model(input_ids=p['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
q_embs.append(q_emb)
p_embs.append(p_emb)
q_embs = torch.cat(q_embs, 0)
p_embs = torch.cat(p_embs, 0)
r_loss = default_loss(q_embs.to("cuda:0"), p_embs.to("cuda:0"))
del q_embs, p_embs
 
# distill backward
if count % 2 == 0:
d_grad = torch.autograd.grad(inputs=[x1.to("cuda:0")],#, x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = q_indexes
else:
d_grad = torch.autograd.grad(inputs=[x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = p_indexes
loss.detach()
r_loss.detach()
 
grads = torch.zeros(distilled_data.data.shape) # lbl, 10, 100, 768
for i, k in enumerate(indexes):
grads[train_labels[i], k, :, :] = grads[train_labels[i], k, :, :].to("cuda:0") 
+ d_grad[0][i, :, :]
distilled_data.data.grad = grads
data_optimizer.step()
data_optimizer.zero_grad(set_to_none=True)
 
model.load_state_dict(state_dict)
model_optimizer.step()
model_optimizer.zero_grad(set_to_none=True)
 
if count % 10 == 0:
print("Count:", count ,"| Data:", round(loss.item(), 4), "| Model:", 
round(r_loss.item(),4), "| Time:", round(time.time() - lap, 4))
# print()
lap = time.time()
 
if count % 100 == 0:
torch.save(model.state_dict(), model_PATH)
torch.save(distilled_data.state_dict(), distill_PATH)
 
if loss < 0.1 and r_loss < 1:
times += 1
 
if times > 100:
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")
return
del loss, r_loss, grads, q, p, train_labels, dq, dp, x1, x2, state_dict
 
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")

Der Code wie das Laden von Daten wird hier weggelassen. Nach dem Training der destillierten Daten können wir Weitergeben Trainieren Sie ein neues Modell, um es zu verwenden, indem Sie es beispielsweise mit neuen Daten kombinieren.

Meinen Experimenten zufolge erzielte ein Modell, das auf destillierten Daten trainiert wurde (mit nur 4 Proben pro Etikett), die beste Leistung von 66 %, während ein Modell, das vollständig auf den Originaldaten trainiert wurde, ebenfalls die beste Leistung von 66 % erzielte. Das untrainierte Normalmodell erreichte eine Leistung von 45 %. Wie oben erwähnt, sind diese Zahlen möglicherweise nicht für intensive Abrufaufgaben geeignet, für kategoriale Daten sind sie jedoch viel besser.

Damit diese Methode bei der Anpassung des Modells an neue Daten nützlich ist, muss man in der Lage sein, einen viel kleineren Datensatz als die Originaldaten zu extrahieren (d. h. ~1 %). Verfeinerte Daten können Ihnen auch eine Leistung bescheren, die etwas geringer oder gleich der Leistung aktiver Lernmethoden ist.

Der Vorteil dieser Methode besteht darin, dass damit destillierte Daten für den dauerhaften Gebrauch erstellt werden können. Der Nachteil besteht darin, dass die extrahierten Daten nicht interpretierbar sind und zusätzliche Trainingszeit erfordern.

Curriculum/Aktives Training

Curriculum-Training ist eine Methode, bei der es nach und nach schwieriger wird, dem Modell während des Trainings Trainingsbeispiele bereitzustellen. Beim Training mit neuen Daten erfordert diese Methode die manuelle Kennzeichnung von Aufgaben, die Klassifizierung von Aufgaben in „leicht“, „mittel“ oder „schwierig“ und die anschließende Stichprobenerhebung der Daten. Um zu verstehen, was es für ein Modell bedeutet, einfach, mittel oder schwer zu sein, nehme ich dieses Bild als Beispiel:

Zusammenfassung von sechs gängigen Methoden des kontinuierlichen Lernens: Anpassung von ML-Modellen an neue Daten bei gleichzeitiger Beibehaltung der Leistung alter Daten

Dies ist die Verwirrungsmatrix in der Klassifizierungsaufgabe, die sich auf das Modell bezieht Vorhersagen Die Wahrscheinlichkeit, wahr zu sein, ist hoch, aber es ist nicht wirklich eine Stichprobe, die wahr ist. Mittlere Stichproben sind solche, bei denen die Wahrscheinlichkeit, dass sie richtig sind, mittel bis hoch ist, die jedoch unterhalb des Vorhersageschwellenwerts richtig negativ sind. Einfache Proben sind Proben mit einer geringeren Wahrscheinlichkeit, dass sie richtig positiv/negativ sind.

Maximally Interfered Retrieval

Dies ist eine Methode, die Rahaf in einem Artikel (1908.04742) mit dem Titel „Online Continual Learning with Maximally Interfered Retrieval“ eingeführt hat. Der Grundgedanke besteht darin, dass Sie für jeden neuen zu trainierenden Datenstapel, wenn Sie die Modellgewichte für neuere Daten aktualisieren, die älteren Stichproben identifizieren müssen, die im Hinblick auf die Verlustwerte am stärksten betroffen sind. Ein begrenzter Speicher bestehend aus alten Daten bleibt erhalten und die störendsten Proben werden zusammen mit jedem neuen Datenstapel abgerufen, um gemeinsam zu trainieren.

Dieses Papier ist ein etabliertes Papier im Bereich des kontinuierlichen Lernens und enthält viele Zitate, sodass es möglicherweise auf Ihren Fall anwendbar ist.

Retrieval Augmentation

Retrieval Augmentation bezieht sich auf die Technik der Erweiterung von Eingaben, Beispielen usw. durch das Abrufen von Elementen aus einer Sammlung. Dabei handelt es sich eher um ein allgemeines Konzept als um eine spezifische Technologie. Bei den meisten der bisher besprochenen Methoden handelt es sich bis zu einem gewissen Grad um abrufbezogene Vorgänge. Izacards Artikel mit dem Titel „Few-shot Learning with Retrieval Augmented Language Models“ verwendet kleinere Modelle, um eine hervorragende Leistung beim Few-shot-Lernen zu erzielen. Retrieval Enhancement wird auch in vielen anderen Situationen eingesetzt, beispielsweise bei der Wortgenerierung oder der Beantwortung von Faktenfragen.

Die Erweiterung des Modells um die Verwendung zusätzlicher Schichten während des Trainings ist die gebräuchlichste und einfachste Methode, sie ist jedoch nicht unbedingt effektiv und wird daher hier nicht im Detail besprochen. Ein Beispiel ist Lewis' Efficient Few-Shot Learning without Prompts. Die Verwendung zusätzlicher Ebenen ist oft die einfachste, aber bewährte Methode, um eine gute Leistung bei alten und neuen Daten zu erzielen. Die Hauptidee besteht darin, die Modellgewichte festzuhalten und eine oder mehrere Schichten auf neue Daten mit Klassifizierungsverlust zu trainieren.

Zusammenfassung In diesem Artikel habe ich 6 Methoden vorgestellt, die Sie beim Training eines Modells anhand neuer Daten verwenden können. Wie immer sollte man experimentieren und entscheiden, welche Methode am besten funktioniert, aber es ist wichtig zu beachten, dass es neben den oben genannten noch viele andere Methoden gibt, zum Beispiel ist die Datendestillation ein aktiver Bereich in der Computer Vision und man kann viel darüber in Papierform finden . Ein letzter Hinweis: Damit diese Methoden wertvoll sind, sollten sie sowohl bei alten als auch bei neuen Daten eine gute Leistung erzielen.

Das obige ist der detaillierte Inhalt vonZusammenfassung von sechs gängigen Methoden des kontinuierlichen Lernens: Anpassung von ML-Modellen an neue Daten bei gleichzeitiger Beibehaltung der Leistung alter Daten. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:51cto.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen