Rumah >pembangunan bahagian belakang >Tutorial Python >Contoh kod penalaan halus Huggingface BART: Set data WMT16 untuk melatih teg baharu untuk terjemahan
Jika anda ingin menguji seni bina baharu pada tugas terjemahan, seperti melatih teg baharu pada set data tersuai, ia akan menyusahkan untuk dikendalikan, jadi dalam artikel ini, saya akan memperkenalkan pra-pemprosesan menambah teg baharu. Memproses langkah dan memperkenalkan cara memperhalusi model.
Oleh kerana Huggingface Hub mempunyai banyak model terlatih, mudah untuk mencari penanda terlatih. Tetapi mungkin agak sukar untuk menambah penanda Mari kita perkenalkan sepenuhnya cara melaksanakannya Mula-mula, muatkan dan praproses set data.
Kami menggunakan set data WMT16 dan subset Romania-Inggerisnya. Fungsi load_dataset() akan memuat turun dan memuatkan mana-mana set data yang tersedia daripada Huggingface.
import datasets dataset = datasets.load_dataset("stas/wmt16-en-ro-pre-processed", cache_dir="./wmt16-en_ro")
Kandungan set data boleh dilihat dalam Rajah 1 di atas. Kita perlu "meratakannya" supaya kita boleh mengakses data dengan lebih baik dan menyimpannya ke cakera keras.
def flatten(batch): batch['en'] = batch['translation']['en'] batch['ro'] = batch['translation']['ro'] return batch # Map the 'flatten' function train = dataset['train'].map( flatten ) test = dataset['test'].map( flatten ) validation = dataset['validation'].map( flatten ) # Save to disk train.save_to_disk("./dataset/train") test.save_to_disk("./dataset/test") validation.save_to_disk("./dataset/validation")
Seperti yang anda lihat dalam Rajah 2 di bawah, dimensi "terjemahan" telah dipadamkan daripada set data.
Tagger menyediakan semua kerja yang diperlukan untuk melatih tokenizer. Ia terdiri daripada empat komponen asas: (tetapi bukan keempat-empatnya diperlukan)
Model: Bagaimana tokenizer akan memecahkan setiap perkataan. Sebagai contoh, diberi perkataan "bermain": i) Model BPE menguraikannya kepada dua token "bermain" + "ing", ii) WordLevel menganggapnya sebagai satu token.
Normalizers: Beberapa transformasi yang perlu berlaku pada teks. Terdapat penapis untuk menukar Unicode, huruf kecil atau mengalih keluar kandungan.
Pra-Tokenizer: Fungsi yang memberikan fleksibiliti yang lebih besar untuk mengendalikan teks. Sebagai contoh, bagaimana untuk bekerja dengan nombor. Sekiranya nombor 100 dianggap "100" atau "1", "0", "0"?
Pos-Pemproses: Spesifik pasca pemprosesan bergantung pada pilihan pra -model terlatih. Sebagai contoh, tambah token [BOS] (permulaan ayat) atau [EOS] (akhir ayat) pada input BERT.
Kod di bawah menggunakan model BPE, Penormal huruf kecil dan Pra-Tokenizer kosong. Kemudian mulakan objek pelatih dengan nilai lalai, terutamanya termasuk
1 Gunakan 50265 untuk saiz perbendaharaan kata agar selaras dengan penanda Bahasa Inggeris BART
2 dan ,
from tokenizers import normalizers, pre_tokenizers, Tokenizer, models, trainers # Build a tokenizer bpe_tokenizer = Tokenizer(models.BPE()) bpe_tokenizer.normalizer = normalizers.Lowercase() bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace() trainer = trainers.BpeTrainer( vocab_size=50265, special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"], initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), )BART Spinner
def batch_iterator(): batch_length = 1000 for i in range(0, len(train), batch_length): yield train[i : i + batch_length]["ro"] bpe_tokenizer.train_from_iterator( batch_iterator(), length=len(train), trainer=trainer ) bpe_tokenizer.save("./ro_tokenizer.json")
from transformers import AutoTokenizer, PreTrainedTokenizerFast en_tokenizer = AutoTokenizer.from_pretrained( "facebook/bart-base" ); ro_tokenizer = PreTrainedTokenizerFast.from_pretrained( "./ro_tokenizer.json" ); ro_tokenizer.pad_token = en_tokenizer.pad_token def tokenize_dataset(sample): input = en_tokenizer(sample['en'], padding='max_length', max_length=120, truncation=True) label = ro_tokenizer(sample['ro'], padding='max_length', max_length=120, truncation=True) input["decoder_input_ids"] = label["input_ids"] input["decoder_attention_mask"] = label["attention_mask"] input["labels"] = label["input_ids"] return input train_tokenized = train.map(tokenize_dataset, batched=True) test_tokenized = test.map(tokenize_dataset, batched=True) validation_tokenized = validation.map(tokenize_dataset, batched=True)
Berikut ialah proses latihan:
Proses ini juga sangat mudah Muatkan model asas bart (baris 4), tetapkan parameter latihan (baris 6), dan gunakan objek Jurulatih untuk mengikat segala-galanya (baris 22), dan memulakan proses (baris 29). Hiperparameter di atas adalah untuk tujuan ujian, jadi jika anda ingin mendapatkan hasil yang terbaik, anda perlu menetapkan hiperparameter Kami boleh menjalankan menggunakan parameter ini.from transformers import BartForConditionalGeneration from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer model = BartForConditionalGeneration.from_pretrained("facebook/bart-base" ) training_args = Seq2SeqTrainingArguments( output_dir="./", evaluation_strategy="steps", per_device_train_batch_size=2, per_device_eval_batch_size=2, predict_with_generate=True, logging_steps=2,# set to 1000 for full training save_steps=64,# set to 500 for full training eval_steps=64,# set to 8000 for full training warmup_steps=1,# set to 2000 for full training max_steps=128, # delete for full training overwrite_output_dir=True, save_total_limit=3, fp16=False, # True if GPU ) trainer = Seq2SeqTrainer( model=model, args=training_args, train_dataset=train_tokenized, eval_dataset=validation_tokenized, ) trainer.train()
Inferens
Ringkasan
Kod artikel ini: https://github.com/AlaFalaki/tutorial_notebooks/blob/main/translation/hf_bart_translation.ipynb
Atas ialah kandungan terperinci Contoh kod penalaan halus Huggingface BART: Set data WMT16 untuk melatih teg baharu untuk terjemahan. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!