LLMs: Transfer Learning with TensorFlow, Keras, Hugging Face
Transfer learning is one of the most powerful techniques in deep learning, especially when working with Large Language Models (LLMs). These models, such as Flan-T5, are pre-trained on vast amounts of data, allowing them to generalize across many language tasks. Instead of training a model from scratch, we can fine-tune these pre-trained models for specific tasks, like question-answering.
In this guide, we will walk you through how to perform transfer learning on Flan-T5-large using TensorFlow and Hugging Face. We’ll fine-tune this model on the SQuAD (Stanford Question Answering Dataset), a popular dataset used to train models for answering questions based on a given context.
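Key points we'll cover include:

- Hugging Face and its Auto model classes
- The T5 architecture and its text-to-text framework
- Loading and preprocessing the SQuAD dataset
- Freezing layers and fine-tuning Flan-T5-large
- Testing the fine-tuned model on the validation set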
Hugging Face is a popular platform and library that simplifies working with powerful models in Natural Language Processing (NLP). The key components include:
With Hugging Face, you don't need to build models from scratch. It offers access to a wide variety of pre-trained models, including BERT, GPT-2, and T5, which significantly reduces the time and resources needed to develop NLP solutions. By leveraging these models, you can quickly fine-tune them for specific downstream tasks like question-answering, text classification, and summarization.
Hugging Face provides various model classes, but the Auto classes (AutoModel and its task-specific variants) are among the most flexible and widely used. They abstract away the work of choosing and loading the right model class: you don't need to know each model's specific class beforehand, because the Auto class reads the configuration stored with the checkpoint and instantiates the matching architecture. The TensorFlow versions carry a TF prefix, such as TFAutoModelForSeq2SeqLM.
For instance, AutoModelForSeq2SeqLM (TFAutoModelForSeq2SeqLM in TensorFlow) is used for sequence-to-sequence models like T5 or BART, which are typically used for tasks such as translation, summarization, and question-answering. Because the Auto classes are model-agnostic, you can swap one checkpoint for another and keep the same code.
Here’s how it works in practice:
```python
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer

# Load the pre-trained Flan-T5-large model and tokenizer
model_name = "google/flan-t5-large"
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)  # Load model
tokenizer = AutoTokenizer.from_pretrained(model_name)  # Load tokenizer
```
TFAutoModelForSeq2SeqLM reads the configuration of the checkpoint named by model_name (here google/flan-t5-large) and instantiates the matching architecture, in this case TFT5ForConditionalGeneration. This flexibility makes development smoother and faster because you don't have to specify each model's architecture by hand.
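To make the dispatch idea concrete, here is a toy imitation in plain Python. A real checkpoint ships a config.json whose model_type field (for Flan-T5 it is "t5") tells the Auto class which concrete class to build. Every class name, the registry, and the config table below are hypothetical stand-ins written for illustration; they are not the transformers internals, and the real library also downloads and loads the weights.

```python
# Hypothetical stand-ins for concrete architectures.
class ToyT5ForConditionalGeneration:
    def __init__(self, config):
        self.config = config


class ToyBartForConditionalGeneration:
    def __init__(self, config):
        self.config = config


# Hypothetical registry: config model_type -> class implementing it.
_SEQ2SEQ_REGISTRY = {
    "t5": ToyT5ForConditionalGeneration,
    "bart": ToyBartForConditionalGeneration,
}

# Hypothetical stand-in for the config.json stored with each checkpoint.
_TOY_CONFIGS = {
    "google/flan-t5-large": {"model_type": "t5"},
    "facebook/bart-large": {"model_type": "bart"},
}


class ToyAutoModelForSeq2SeqLM:
    @classmethod
    def from_pretrained(cls, name):
        config = _TOY_CONFIGS[name]                          # 1. read the checkpoint's config
        model_cls = _SEQ2SEQ_REGISTRY[config["model_type"]]  # 2. look up the architecture
        return model_cls(config)                             # 3. build it (real code also loads weights)


model = ToyAutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
print(type(model).__name__)  # ToyT5ForConditionalGeneration
```

Swapping the checkpoint name for "facebook/bart-large" returns the BART stand-in without any other code change, which is the model-agnostic property described above.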
To understand how T5 works, let's first break down its architecture. T5 stands for Text-to-Text Transfer Transformer, and it was introduced by Google in 2019. The key idea behind T5 is that every NLP task can be cast as a text-to-text problem, whether it's translation, summarization, or even question-answering.
Key Components of T5:
Here’s an example of how T5 might be applied to a question-answering task:
Input: "question: What is T5? context: T5 is a text-to-text transfer transformer developed by Google."
Output: "T5 is a text-to-text transfer transformer."
The beauty of T5’s text-to-text framework is its flexibility. You can use the same model architecture for various tasks simply by rephrasing the input. This makes T5 highly versatile and adaptable for a range of NLP tasks.
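A small helper shows how different tasks collapse into a single input-string format. The "translate English to German:" and "summarize:" prefixes follow the convention from the original T5 paper, and the "question: ... context: ..." layout is the one this guide uses; the function name to_text_to_text and its task keys are hypothetical.

```python
def to_text_to_text(task, **fields):
    """Cast several NLP tasks into one text input, T5-style (hypothetical helper)."""
    if task == "translate_en_de":
        return "translate English to German: " + fields["text"]
    if task == "summarize":
        return "summarize: " + fields["text"]
    if task == "qa":
        # Same question/context layout used for SQuAD in this guide
        return "question: " + fields["question"] + " context: " + fields["context"]
    raise ValueError(f"unknown task: {task}")


print(to_text_to_text(
    "qa",
    question="What is T5?",
    context="T5 is a text-to-text transfer transformer developed by Google.",
))
# question: What is T5? context: T5 is a text-to-text transfer transformer developed by Google.
```

The model, loss, and decoding loop stay the same for every task; only the string fed to the tokenizer changes.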
T5 was pre-trained on a massive dataset known as C4 (Colossal Clean Crawled Corpus), which gives it a solid grasp of the structure of language, and Flan-T5 was further fine-tuned on a large collection of tasks phrased as instructions. Through transfer learning, we can fine-tune this pre-trained model to specialize in a specific task, such as question-answering with the SQuAD dataset. In this guide we reuse the pre-trained weights and update only the final layer, which shrinks the number of trainable parameters and with it the backward-pass compute and optimizer memory, although every training step still runs the full forward pass.
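How large is "the final layer"? In Flan-T5 it is the lm_head, a bias-free dense projection from the hidden size to the vocabulary, and Flan-T5 does not tie it to the input embeddings. The arithmetic below assumes d_model = 1024 and vocab_size = 32128 from the google/flan-t5-large configuration; check these against the config of the checkpoint you actually load.

```python
# Assumed values from the google/flan-t5-large config.json
D_MODEL = 1024      # hidden size of encoder/decoder
VOCAB_SIZE = 32128  # tokenizer vocabulary size (padded)

# lm_head: Dense(VOCAB_SIZE, use_bias=False) applied to D_MODEL-dim vectors,
# so its kernel has D_MODEL * VOCAB_SIZE weights and no bias.
lm_head_params = D_MODEL * VOCAB_SIZE
print(f"{lm_head_params:,}")  # 32,899,072
```

That is roughly 33 million trainable weights, a small fraction of the full model, whose encoder and decoder stacks stay frozen.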
Now that we have the model, we need data to fine-tune it. We'll use the SQuAD dataset, a collection of question-answer pairs based on passages of text.
```python
from datasets import load_dataset

# Load the SQuAD dataset
squad = load_dataset("squad")
train_data = squad["train"]
valid_data = squad["validation"]
```
The SQuAD dataset is widely used for training models in question-answering tasks. Each data point in the dataset consists of a context (a passage of text), a question, and the corresponding answer, which is a span of text found within the context.
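In the Hugging Face squad dataset, each example has id, title, context, question, and answers fields, where answers holds two parallel lists: text and answer_start (character offsets into context). The record below is a made-up example in that layout, not taken from SQuAD, and answer_spans is a hypothetical helper that recovers each answer from its offset.

```python
# Made-up record in the SQuAD v1.1 layout (illustration only).
record = {
    "id": "toy-0001",
    "title": "T5",
    "context": "T5 is a text-to-text transfer transformer developed by Google.",
    "question": "Who developed T5?",
    "answers": {"text": ["Google"], "answer_start": [55]},
}


def answer_spans(record):
    """Slice each reference answer out of the context using its character offset."""
    context = record["context"]
    return [
        context[start:start + len(text)]
        for text, start in zip(record["answers"]["text"], record["answers"]["answer_start"])
    ]


print(answer_spans(record))  # ['Google']
```

Because the answer is always a span of the context, an extractive model could predict offsets; T5 instead generates the answer text directly.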
Before feeding the data into the model, we need to tokenize it. Tokenization converts raw text into sequences of integer token IDs that the model can process. For T5, each input combines the question and the context into a single string, prefixed with "question:" and "context:".
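As a rough picture of what the tokenizer returns, here is a toy word-level tokenizer. build_vocab and encode are hypothetical helpers, and real T5 tokenization uses a SentencePiece subword model; the sketch only reuses T5's special-token IDs (pad = 0, end-of-sequence = 1, unknown = 2) and mimics truncation=True plus padding="max_length": cut the sequence to max_length while keeping the end-of-sequence token, then pad and mark padded positions with 0 in the attention mask.

```python
PAD_ID, EOS_ID, UNK_ID = 0, 1, 2  # the special-token IDs T5's tokenizer uses


def build_vocab(texts):
    """Hypothetical word-level vocabulary: each new whitespace word gets the next ID."""
    vocab = {"<pad>": PAD_ID, "</s>": EOS_ID, "<unk>": UNK_ID}
    for text in texts:
        for word in text.split():
            vocab.setdefault(word, len(vocab))
    return vocab


def encode(text, vocab, max_length):
    """Return (input_ids, attention_mask), truncated and padded to max_length."""
    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    # Truncate the words so the EOS token still fits, then append EOS
    ids = [vocab.get(w, UNK_ID) for w in text.split()][: max_length - 1] + [EOS_ID]
    mask = [1] * len(ids)
    pad = max_length - len(ids)
    return ids + [PAD_ID] * pad, mask + [0] * pad


vocab = build_vocab(["question: What is T5? context: T5 is a model."])
print(encode("question: What is T5?", vocab, max_length=8))
# ([3, 4, 5, 6, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0])
```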
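```python
# Preprocessing function to tokenize inputs and outputs
def preprocess_function(examples):
    # With batched=True, each column arrives as a list with one entry per example.
    # Combine the question and context into a single string
    inputs = ["question: " + q + " context: " + c
              for q, c in zip(examples["question"], examples["context"])]
    # Return plain lists (no return_tensors) so datasets can store the results
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    # Tokenize the answer (label). examples["answers"] is a list of
    # {"text": [...], "answer_start": [...]} dicts; use each example's first answer.
    targets = [answer["text"][0] for answer in examples["answers"]]
    labels = tokenizer(targets, max_length=64, truncation=True, padding="max_length")

    # Replace padding token IDs with -100 so the loss ignores those positions
    model_inputs["labels"] = [
        [t if t != tokenizer.pad_token_id else -100 for t in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs

# Preprocess the dataset
train_data = train_data.map(preprocess_function, batched=True)
valid_data = valid_data.map(preprocess_function, batched=True)
```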
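This function tokenizes both the question-context pairs (the model inputs) and the answers (the target labels). Because map is called with batched=True, the function receives many examples at a time, and every example ends up as fixed-length sequences of token IDs: 512 tokens for the input and 64 for the label.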
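The labels are padded to 64 tokens, and those padding positions should not count toward the loss. A common convention in Transformers is to set them to -100, a label value its loss functions skip. The pure-Python sketch below shows both halves of the idea: replacing padding IDs with -100, then averaging the cross-entropy only over the remaining positions. mask_padding and masked_cross_entropy are hypothetical helpers for illustration, not library functions.

```python
import math

PAD_ID = 0           # T5's padding token ID
IGNORE_INDEX = -100  # label value skipped by the loss


def mask_padding(label_ids, pad_id=PAD_ID):
    """Replace padding positions in a label sequence with IGNORE_INDEX."""
    return [IGNORE_INDEX if t == pad_id else t for t in label_ids]


def masked_cross_entropy(logits, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX.

    logits: one list of unnormalized scores (length = vocab size) per position.
    labels: one target token ID per position.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # padded position: contributes nothing
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))  # stable log-sum-exp
        total += log_z - scores[label]
        count += 1
    return total / count if count else 0.0


labels = mask_padding([5, 1, 0, 0])
print(labels)  # [5, 1, -100, -100]
uniform_logits = [[0.0] * 6 for _ in range(4)]
print(masked_cross_entropy(uniform_logits, labels))  # log(6), about 1.7918
```

Changing the scores at the two padded positions leaves the loss unchanged, which is the point: the model is trained only on real answer tokens.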
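Here's where we perform transfer learning. To make fine-tuning cheaper, we freeze the pre-trained embedding, encoder, and decoder layers and leave only the final layer trainable. The frozen layers keep the general language knowledge learned during pre-training, while the final layer adapts to answering questions. Freezing reduces the number of trainable parameters, and therefore the gradient computation and optimizer-state memory; the forward pass still runs through the full network.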
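```python
from tensorflow.keras.optimizers import Adam

# Freeze all layers by default (shared embedding, encoder, decoder)
for layer in model.layers:
    layer.trainable = False

# Unfreeze only the final layer; for Flan-T5 this is the untied lm_head projection
model.layers[-1].trainable = True

# Compile without an explicit loss: Hugging Face TF models then use their
# built-in seq2seq loss, which needs the labels passed in as a model input
optimizer = Adam(learning_rate=3e-5)
model.compile(optimizer=optimizer)

# Convert the datasets.Dataset objects into batched tf.data pipelines
tf_train = model.prepare_tf_dataset(train_data, batch_size=8, shuffle=True)
tf_valid = model.prepare_tf_dataset(valid_data, batch_size=8, shuffle=False)

# Fine-tune the model on the SQuAD dataset
model.fit(tf_train, epochs=3, validation_data=tf_valid)
```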
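Explanation:

- The loop over model.layers sets trainable = False on every top-level layer: the shared embedding, the encoder, and the decoder.
- model.layers[-1].trainable = True re-enables the last layer, so it is the only part whose weights are updated.
- Adam with a learning rate of 3e-5 takes small steps, a conservative choice that is common when fine-tuning pre-trained transformers.
- Training runs for 3 epochs with a batch size of 8, and the validation split is evaluated at the end of each epoch.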
Once the model is fine-tuned, it’s important to test how well it performs on the validation set.
```python
# Select a sample from the validation set
sample = valid_data[0]

# Tokenize the input text
input_text = "question: " + sample["question"] + " context: " + sample["context"]
input_ids = tokenizer(input_text, return_tensors="tf").input_ids

# Generate the output (the model's answer); allow up to the 64-token label length
output = model.generate(input_ids, max_new_tokens=64)
answer = tokenizer.decode(output[0], skip_special_tokens=True)

print(f"Question: {sample['question']}")
print(f"Answer: {answer}")
```
This code takes a sample question-context pair, tokenizes it, and uses the fine-tuned model to generate an answer. The tokenizer decodes the output back into human-readable text.
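Reading one generated answer is a sanity check rather than an evaluation. SQuAD is usually scored with exact match (EM) and token-level F1, taking the best score over all reference answers. The sketch below follows the answer normalization of the official SQuAD v1.1 evaluation script (lowercase, strip punctuation, drop the articles a/an/the, collapse whitespace); the function names are hypothetical.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, remove punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, references):
    """1.0 if the normalized prediction equals any normalized reference, else 0.0."""
    pred = normalize_answer(prediction)
    return max(float(pred == normalize_answer(r)) for r in references)


def token_f1(prediction, reference):
    """Harmonic mean of token precision and recall after normalization."""
    pred = normalize_answer(prediction).split()
    ref = normalize_answer(reference).split()
    num_same = sum((Counter(pred) & Counter(ref)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred)
    recall = num_same / len(ref)
    return 2 * precision * recall / (precision + recall)


def best_f1(prediction, references):
    return max(token_f1(prediction, r) for r in references)


print(exact_match("Google.", ["Google", "Google Inc."]))  # 1.0
print(best_f1("developed by Google", ["Google"]))         # 0.5
# With the fine-tuned model, e.g.: exact_match(answer, sample["answers"]["text"])
```

Averaging these two scores over the whole validation split gives a number you can compare across fine-tuning runs.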
Although we’ve covered the basics of fine-tuning, there are several ways you can further improve the performance of your model:
In this guide, we walked through the entire process of fine-tuning a pre-trained LLM (Flan-T5-large) using TensorFlow and Hugging Face. By freezing the encoder and decoder layers and fine-tuning only the final layer, we cut the number of trainable parameters while still adapting the model to our specific task of question-answering on the SQuAD dataset.
T5’s text-to-text framework makes it highly flexible and adaptable to various NLP tasks, and Hugging Face’s AutoModel abstraction simplifies the process of working with these models. By understanding the architecture and principles behind models like T5, you can apply these techniques to a variety of other NLP tasks, making transfer learning a powerful tool in your machine learning toolkit.