Heim >Technologie-Peripheriegeräte >KI >Einführung in fünf Stichprobenmethoden bei Aufgaben zur Generierung natürlicher Sprache und bei der Implementierung von Pytorch-Code

Einführung in fünf Stichprobenmethoden bei Aufgaben zur Generierung natürlicher Sprache und bei der Implementierung von Pytorch-Code

WBOY
WBOYnach vorne
2024-02-20 08:50:031089Durchsuche

Bei Aufgaben zur Generierung natürlicher Sprache ist die Stichprobenmethode eine Technik, um eine Textausgabe aus einem generativen Modell zu erhalten. In diesem Artikel werden fünf gängige Methoden erläutert und mit PyTorch implementiert.

1. Greedy Decoding

Bei der Greedy Decoding sagt das generative Modell die Wörter der Ausgabesequenz basierend auf der Eingabesequenz Zeit für Zeit voraus. In jedem Zeitschritt berechnet das Modell die bedingte Wahrscheinlichkeitsverteilung jedes Wortes und wählt dann das Wort mit der höchsten bedingten Wahrscheinlichkeit als Ausgabe des aktuellen Zeitschritts aus. Dieses Wort wird zur Eingabe für den nächsten Zeitschritt und der Generierungsprozess wird fortgesetzt, bis eine Abschlussbedingung erfüllt ist, beispielsweise eine Sequenz mit einer bestimmten Länge oder eine spezielle Endmarkierung. Das Merkmal von Greedy Decoding besteht darin, dass jedes Mal das Wort mit der höchsten aktuellen bedingten Wahrscheinlichkeit als Ausgabe ausgewählt wird, ohne die globale optimale Lösung zu berücksichtigen. Diese Methode ist einfach und effizient, kann jedoch zu generierten Sequenzen führen, die weniger genau oder vielfältig sind. Greedy Decoding eignet sich für einige einfache Sequenzgenerierungsaufgaben, bei komplexen Aufgaben müssen jedoch möglicherweise komplexere Decodierungsstrategien verwendet werden, um die Qualität der Generierung zu verbessern.

Obwohl diese Methode schneller berechnet wird, da sich die gierige Dekodierung nur auf die lokal optimale Lösung konzentriert, kann es dazu führen, dass der generierte Text nicht diversifiziert oder ungenau ist und die globale optimale Lösung nicht erhalten werden kann.

Obwohl die gierige Dekodierung ihre Grenzen hat, wird sie bei vielen Aufgaben zur Sequenzgenerierung immer noch häufig verwendet, insbesondere wenn eine schnelle Ausführung erforderlich ist oder die Aufgabe relativ einfach ist.

 def greedy_decoding(input_ids, max_tokens=300): with torch.inference_mode(): for _ in range(max_tokens): outputs = model(input_ids) next_token_logits = outputs.logits[:, -1, :] next_token = torch.argmax(next_token_logits, dim=-1) if next_token == tokenizer.eos_token_id: break input_ids = torch.cat([input_ids, rearrange(next_token, 'c -> 1 c')], dim=-1) generated_text = tokenizer.decode(input_ids[0]) return generated_text

2. Beam Search

Beam Search ist eine Erweiterung der Greedy-Dekodierung, die das lokale Optimalproblem der Greedy-Dekodierung überwindet, indem bei jedem Zeitschritt mehrere Kandidatensequenzen beibehalten werden.

Beam-Suche ist eine Methode zum Generieren von Text, die in jedem Zeitschritt die Kandidatenwörter mit der höchsten Wahrscheinlichkeit behält und dann auf der Grundlage dieser Kandidatenwörter im nächsten Zeitschritt bis zum Ende der Generierung weiter erweitert wird. Diese Methode kann die Vielfalt des generierten Texts verbessern, indem mehrere mögliche Wortpfade berücksichtigt werden.

Bei der Strahlsuche generiert das Modell mehrere Kandidatensequenzen gleichzeitig, anstatt nur eine beste Sequenz auszuwählen. Es sagt mögliche Wörter im nächsten Zeitschritt basierend auf der aktuell generierten Teilsequenz und den verborgenen Zuständen voraus und berechnet die bedingte Wahrscheinlichkeitsverteilung jedes Wortes. Diese Methode zur parallelen Generierung mehrerer Kandidatensequenzen trägt zur Verbesserung der Sucheffizienz bei, sodass das Modell schneller die Sequenz mit der höchsten Gesamtwahrscheinlichkeit finden kann.

Einführung in fünf Stichprobenmethoden bei Aufgaben zur Generierung natürlicher Sprache und bei der Implementierung von Pytorch-Code

Bei jedem Schritt werden nur die beiden wahrscheinlichsten Pfade beibehalten und die verbleibenden Pfade werden entsprechend der Einstellung von Strahl = 2 verworfen. Dieser Prozess wird fortgesetzt, bis eine Stoppbedingung erfüllt ist. Dies kann die Generierung eines Ende-der-Sequenz-Tokens oder das Erreichen der vom Modell festgelegten maximalen Sequenzlänge sein. Die endgültige Ausgabe ist die Sequenz mit der höchsten Gesamtwahrscheinlichkeit im letzten Satz von Pfaden.

 from einops import rearrange import torch.nn.functional as F  def beam_search(input_ids, max_tokens=100, beam_size=2): beam_scores = torch.zeros(beam_size).to(device) beam_sequences = input_ids.clone() active_beams = torch.ones(beam_size, dtype=torch.bool) for step in range(max_tokens): outputs = model(beam_sequences) logits = outputs.logits[:, -1, :] probs = F.softmax(logits, dim=-1) top_scores, top_indices = torch.topk(probs.flatten(), k=beam_size, sorted=False) beam_indices = top_indices // probs.shape[-1] token_indices = top_indices % probs.shape[-1] beam_sequences = torch.cat([ beam_sequences[beam_indices], token_indices.unsqueeze(-1)], dim=-1) beam_scores = top_scores active_beams = ~(token_indices == tokenizer.eos_token_id) if not active_beams.any(): print("no active beams") break best_beam = beam_scores.argmax() best_sequence = beam_sequences[best_beam] generated_text = tokenizer.decode(best_sequence) return generated_text

3. Temperatur-Sampling

Temperatur-Parameter-Sampling (Temperatur-Sampling) wird häufig in wahrscheinlichkeitsbasierten generativen Modellen wie Sprachmodellen verwendet. Es steuert die Vielfalt des generierten Textes durch die Einführung eines Parameters namens „Temperatur“, um die Wahrscheinlichkeitsverteilung der Modellausgabe anzupassen.

Beim Temperaturparameter-Sampling berechnet das Modell, wenn es in jedem Zeitschritt Wörter generiert, die bedingte Wahrscheinlichkeitsverteilung der Wörter. Anschließend dividiert das Modell den Wahrscheinlichkeitswert jedes Wortes in dieser bedingten Wahrscheinlichkeitsverteilung durch den Temperaturparameter, normalisiert das Ergebnis und erhält eine neue normalisierte Wahrscheinlichkeitsverteilung. Höhere Temperaturwerte machen die Wahrscheinlichkeitsverteilung glatter und erhöhen so die Vielfalt des generierten Textes. Wörter mit geringer Wahrscheinlichkeit haben auch eine höhere Wahrscheinlichkeit, ausgewählt zu werden; ein niedrigerer Temperaturwert führt dazu, dass die Wahrscheinlichkeitsverteilung konzentrierter wird und die Wahrscheinlichkeit höher ist, dass Wörter mit hoher Wahrscheinlichkeit ausgewählt werden, sodass der generierte Text deterministischer ist. Schließlich führt das Modell eine zufällige Stichprobe gemäß dieser neuen normalisierten Wahrscheinlichkeitsverteilung durch und wählt die generierten Wörter aus.

 import torch import torch.nn.functional as F  def temperature_sampling(logits, temperature=1.0): logits = logits / temperature probabilities = F.softmax(logits, dim=-1) sampled_token = torch.multinomial(probabilities, 1) return sampled_token.item()

4. Top-K-Stichprobe

Top-K-Stichprobe (wählen Sie in jedem Zeitschritt die besten K-Wörter mit bedingter Wahrscheinlichkeit aus und probieren Sie dann zufällig aus diesen K-Wörtern aus. Diese Methode kann eine bestimmte Qualität der Generierung aufrechterhalten kann auch die Vielfalt des Textes erhöhen, und die Vielfalt des generierten Textes kann durch Begrenzen der Anzahl der Kandidatenwörter gesteuert werden Es besteht immer noch ein gewisses Maß an Konkurrenz zwischen den Kandidatenwörtern. Der Parameter K steuert die Anzahl der Kandidatenwörter, die bei jedem Zeitschritt beibehalten werden. Ein kleinerer K-Wert führt zu einem gierigeren Verhalten, da nur wenige Wörter an der Zufallsstichprobe teilnehmen. und ein größerer K-Wert erhöht die Vielfalt des generierten Textes, erhöht aber auch den Rechenaufwand

 def top_k_sampling(input_ids, max_tokens=100, top_k=50, temperature=1.0):for _ in range(max_tokens): with torch.inference_mode(): outputs = model(input_ids) next_token_logits = outputs.logits[:, -1, :] top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k) top_k_probs = F.softmax(top_k_logits / temperature, dim=-1) next_token_index = torch.multinomial(top_k_probs, num_samples=1) next_token = top_k_indices.gather(-1, next_token_index) input_ids = torch.cat([input_ids, next_token], dim=-1) generated_text = tokenizer.decode(input_ids[0]) return generated_text
.

5、Top-P (Nucleus) Sampling:

Nucleus Sampling(核采样),也被称为Top-p Sampling旨在在保持生成文本质量的同时增加多样性。这种方法可以视作是Top-K Sampling的一种变体,它在每个时间步根据模型输出的概率分布选择概率累积超过给定阈值p的词语集合,然后在这个词语集合中进行随机采样。这种方法会动态调整候选词语的数量,以保持一定的文本多样性。

Einführung in fünf Stichprobenmethoden bei Aufgaben zur Generierung natürlicher Sprache und bei der Implementierung von Pytorch-Code

在Nucleus Sampling中,模型在每个时间步生成词语时,首先按照概率从高到低对词汇表中的所有词语进行排序,然后模型计算累积概率,并找到累积概率超过给定阈值p的最小词语子集,这个子集就是所谓的“核”(nucleus)。模型在这个核中进行随机采样,根据词语的概率分布来选择最终输出的词语。这样做可以保证所选词语的总概率超过了阈值p,同时也保持了一定的多样性。

参数p是Nucleus Sampling中的重要参数,它决定了所选词语的概率总和。p的值会被设置在(0,1]之间,表示词语总概率的一个下界。

Nucleus Sampling 能够保持一定的生成质量,因为它在一定程度上考虑了概率分布。通过选择概率总和超过给定阈值p的词语子集进行随机采样,Nucleus Sampling 能够增加生成文本的多样性。

 def top_p_sampling(input_ids, max_tokens=100, top_p=0.95): with torch.inference_mode(): for _ in range(max_tokens): outputs = model(input_ids) next_token_logits = outputs.logits[:, -1, :] sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) sorted_probabilities = F.softmax(sorted_logits, dim=-1)  cumulative_probs = torch.cumsum(sorted_probabilities, dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 0] = False  indices_to_remove = sorted_indices[sorted_indices_to_remove] next_token_logits.scatter_(-1, indices_to_remove[None, :], float('-inf')) probs = F.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) input_ids = torch.cat([input_ids, next_token], dim=-1) generated_text = tokenizer.decode(input_ids[0]) return generated_text

总结

自然语言生成任务中,采样方法是非常重要的。选择合适的采样方法可以在一定程度上影响生成文本的质量、多样性和效率。上面介绍的几种采样方法各有特点,适用于不同的应用场景和需求。

贪婪解码是一种简单直接的方法,适用于速度要求较高的情况,但可能导致生成文本缺乏多样性。束搜索通过保留多个候选序列来克服贪婪解码的局部最优问题,生成的文本质量更高,但计算开销较大。Top-K 采样和核采样可以控制生成文本的多样性,适用于需要平衡质量和多样性的场景。温度参数采样则可以根据温度参数灵活调节生成文本的多样性,适用于需要平衡多样性和质量的任务。

Das obige ist der detaillierte Inhalt vonEinführung in fünf Stichprobenmethoden bei Aufgaben zur Generierung natürlicher Sprache und bei der Implementierung von Pytorch-Code. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:51cto.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen