Maison >développement back-end >Tutoriel Python >Exemple de code de réglage fin de Huggingface BART : ensemble de données WMT16 pour entraîner de nouvelles balises à traduire

Exemple de code de réglage fin de Huggingface BART : ensemble de données WMT16 pour entraîner de nouvelles balises à traduire

王林
王林avant
2023-04-10 14:41:061349parcourir

Si vous souhaitez tester une nouvelle architecture sur des tâches de traduction, telles que la formation d'un nouveau marqueur sur un ensemble de données personnalisé, cela sera fastidieux à gérer, donc dans cet article, je présenterai les étapes de pré-traitement pour l'ajout de nouveaux marqueurs, et présenter comment affiner le modèle.

Étant donné que Huggingface Hub propose de nombreux modèles pré-entraînés, il est facile de trouver des tagueurs pré-entraînés. Mais il peut être un peu délicat d'ajouter un marqueur. Commençons par charger et prétraiter l'ensemble de données.

Chargement de l'ensemble de données

Nous utilisons l'ensemble de données WMT16 et son sous-ensemble roumain-anglais. La fonction load_dataset() téléchargera et chargera tout ensemble de données disponible depuis Huggingface.

import datasets
 
 dataset = datasets.load_dataset("stas/wmt16-en-ro-pre-processed", cache_dir="./wmt16-en_ro")

Exemple de code de réglage fin de Huggingface BART : ensemble de données WMT16 pour entraîner de nouvelles balises à traduire

Le contenu de l'ensemble de données est visible dans la figure 1 ci-dessus. Nous devons « l'aplatir » pour pouvoir mieux accéder aux données et les enregistrer sur le disque dur.

def flatten(batch):
 batch['en'] = batch['translation']['en']
 batch['ro'] = batch['translation']['ro']
 
 return batch
 
 # Map the 'flatten' function
 train = dataset['train'].map( flatten )
 test = dataset['test'].map( flatten )
 validation = dataset['validation'].map( flatten )
 
 # Save to disk
 train.save_to_disk("./dataset/train")
 test.save_to_disk("./dataset/test")
 validation.save_to_disk("./dataset/validation")

Comme vous pouvez le voir sur la figure 2 ci-dessous, la dimension « traduction » a été supprimée de l'ensemble de données.

Exemple de code de réglage fin de Huggingface BART : ensemble de données WMT16 pour entraîner de nouvelles balises à traduire

Tagger

Tagger fournit tout le travail nécessaire pour former un tokenizer. Il se compose de quatre éléments de base : (mais tous les quatre ne sont pas nécessaires)

Modèles : Comment le tokenizer décomposera chaque mot. Par exemple, étant donné le mot « jouer » : i) le modèle BPE le décompose en deux jetons « jouer » + « ing », ii) WordLevel le traite comme un seul jeton.

Normalisateurs : Quelques transformations qui doivent s'opérer sur le texte. Il existe des filtres pour modifier l'Unicode, les lettres minuscules ou supprimer du contenu.

Pre-Tokenizers : Fonctions qui offrent une plus grande flexibilité pour manipuler le texte. Par exemple, comment travailler avec les chiffres. Le nombre 100 doit-il être considéré comme « 100 » ou « 1 », « 0 », « 0 » ?

Post-processeurs : les spécificités du post-traitement dépendent du choix du modèle pré-entraîné. Par exemple, ajoutez des jetons [BOS] (début de phrase) ou [EOS] (fin de phrase) à l'entrée BERT.

Le code ci-dessous utilise le modèle BPE, des normaliseurs minuscules et des pré-tokeniseurs vierges. Initialisez ensuite l'objet entraîneur avec les valeurs par défaut, notamment

1. La taille du vocabulaire est de 50265 pour être cohérente avec le tagger anglais de BART

2, telles que et ,

3. Quantités, qui est une liste prédéfinie pour chaque processus de lancement de modèle.

from tokenizers import normalizers, pre_tokenizers, Tokenizer, models, trainers
 
 # Build a tokenizer
 bpe_tokenizer = Tokenizer(models.BPE())
 bpe_tokenizer.normalizer = normalizers.Lowercase()
 bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
 
 trainer = trainers.BpeTrainer(
 vocab_size=50265,
 special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
 initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
 )

La dernière étape de l'utilisation de Huggingface consiste à connecter le modèle Trainer et BPE et à transmettre l'ensemble de données. Selon la source des données, différentes fonctions de formation peuvent être utilisées. Nous utiliserons train_from_iterator().

def batch_iterator():
 batch_length = 1000
 for i in range(0, len(train), batch_length):
 yield train[i : i + batch_length]["ro"]
 
 bpe_tokenizer.train_from_iterator( batch_iterator(), length=len(train), trainer=trainer )
 
 bpe_tokenizer.save("./ro_tokenizer.json")

BART Spinner

Maintenant disponible avec le nouveau tagger.

from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
 en_tokenizer = AutoTokenizer.from_pretrained( "facebook/bart-base" );
 ro_tokenizer = PreTrainedTokenizerFast.from_pretrained( "./ro_tokenizer.json" );
 ro_tokenizer.pad_token = en_tokenizer.pad_token
 
 def tokenize_dataset(sample):
 input = en_tokenizer(sample['en'], padding='max_length', max_length=120, truncation=True)
 label = ro_tokenizer(sample['ro'], padding='max_length', max_length=120, truncation=True)
 
 input["decoder_input_ids"] = label["input_ids"]
 input["decoder_attention_mask"] = label["attention_mask"]
 input["labels"] = label["input_ids"]
 
 return input
 
 train_tokenized = train.map(tokenize_dataset, batched=True)
 test_tokenized = test.map(tokenize_dataset, batched=True)
 validation_tokenized = validation.map(tokenize_dataset, batched=True)

Ligne 5 du code ci-dessus, définir la balise padding pour le tagueur roumain est très nécessaire. Comme il sera utilisé à la ligne 9, le tokenizer utilise un remplissage pour que toutes les entrées aient la même taille.

Voici le processus de formation :

from transformers import BartForConditionalGeneration
 from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
 
 model = BartForConditionalGeneration.from_pretrained("facebook/bart-base" )
 
 training_args = Seq2SeqTrainingArguments(
 output_dir="./",
 evaluation_strategy="steps",
 per_device_train_batch_size=2,
 per_device_eval_batch_size=2,
 predict_with_generate=True,
 logging_steps=2,# set to 1000 for full training
 save_steps=64,# set to 500 for full training
 eval_steps=64,# set to 8000 for full training
 warmup_steps=1,# set to 2000 for full training
 max_steps=128, # delete for full training
 overwrite_output_dir=True,
 save_total_limit=3,
 fp16=False, # True if GPU
 )
 
 trainer = Seq2SeqTrainer(
 model=model,
 args=training_args,
 train_dataset=train_tokenized,
 eval_dataset=validation_tokenized,
 )
 
 trainer.train()

Le processus est également très simple. Chargez le modèle de base Bart (ligne 4), définissez les paramètres de formation (ligne 6), utilisez l'objet Trainer pour tout lier (ligne 22), et démarrez le processus (ligne 29). Les hyperparamètres ci-dessus sont destinés à des fins de test, donc si vous souhaitez obtenir les meilleurs résultats, vous devez définir les hyperparamètres. Nous pouvons exécuter en utilisant ces paramètres.

Inférence

Le processus d'inférence est également très simple. Il suffit de charger le modèle affiné et d'utiliser la méthode generate() pour convertir. Cependant, il est important d'utiliser des tokenizers appropriés pour la source (En) et la cible (RO). séquences.

Résumé

Bien que la tokenisation puisse sembler une opération de base lors de l'utilisation du traitement du langage naturel (NLP), il s'agit d'une étape critique à ne pas négliger. L'émergence de HuggingFace nous facilite son utilisation, ce qui nous permet d'oublier facilement les principes de base de la tokenisation et de nous appuyer uniquement sur des modèles pré-entraînés. Mais lorsque l’on souhaite former soi-même un nouveau modèle, comprendre le processus de tokenisation et son impact sur les tâches en aval est essentiel, il est donc nécessaire de se familiariser et de maîtriser cette opération de base.

Code pour cet article : https://github.com/AlaFalaki/tutorial_notebooks/blob/main/translation/hf_bart_translation.ipynb

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Cet article est reproduit dans:. en cas de violation, veuillez contacter admin@php.cn Supprimer