首頁 >科技週邊 >人工智慧 >在少樣本學習中,以SetFit進行文字分類

在少樣本學習中,以SetFit進行文字分類

王林
王林轉載
2023-11-28 11:14:281606瀏覽

译者 | 陈峻

审校 | 重楼

在本文中,我将向您介绍“少样本(Few-shot)学习”的相关概念,并重点讨论被广泛应用于文本分类的SetFit方法。

在少樣本學習中,以SetFit進行文字分類

传统的机器学习(ML)

在监督(Supervised)机器学习中,大量数据集被用于模型训练,以便磨练模型能够做出精确预测的能力。在完成训练过程之后,我们便可以利用测试数据,来获得模型的预测结果。然而,这种传统的监督学习方法存在着一个显著缺点:它需要大量无差错的训练数据集。但是并非所有领域都能够提供此类无差错数据集。因此,“少样本学习”的概念应运而生。

在深入研究Sentence Transformer fine-tuningSetFit)之前,我们有必要简要地回顾一下自然语言处理(Natural Language ProcessingNLP)的一个重要方面,也就是:“少样本学习”。

少样本学习

少样本学习是指:使用有限的训练数据集,来训练模型。模型可以从这些被称为支持集的小集合中获取知识。此类学习旨在教会少样本模型,辨别出训练数据中的相同与相异之处。例如,我们并非要指示模型将所给图像分类为猫或狗,而是指示它掌握各种动物之间的共性和区别。可见,这种方法侧重于理解输入数据中的相似点和不同点。因此,它通常也被称为元学习(meta-learning)、或是从学习到学习(learning-to-learn)。

值得一提的是,少样本学习的支持集,也被称为k向(k-way)n样本(n-shot)学习。其中“k”代表支持集里的类别数。例如,在二分类(binary classification)中,k 等于 2。而“n”表示支持集中每个类别的可用样本数。例如,如果正分类有10个数据点,而负分类也有10个数据点,那么 n就等于10。总之,这个支持集可以被描述为双向10样本学习。

既然我们已经对少样本学习有了基本的了解,下面让我们通过使用SetFit进行快速学习,并在实际应用中对电商数据集进行文本分类。

SetFit架构

Hugging Face和英特尔实验室的团队联合开发的SetFit,是一款用于少样本照片分类的开源工具。你可以在项目库链接--https://github.com/huggingface/setfit?ref=hackernoon.com中,找到关于SetFit的全面信息。

就输出而言,SetFit仅用到了客户评论(Customer Reviews,CR)情感分析数据集里、每个类别的八个标注示例。其结果就能够与由三千个示例组成的完整训练集上,经调优的RoBERTa Large的结果相同。值得强调的是,就体积而言,经微优的RoBERTa模型比SetFit模型大三倍。下图展示的是SetFit架构:

在少樣本學習中,以SetFit進行文字分類

图片来源:https://www.php.cn/link/2456b9cd2668fa69e3c7ecd6f51866bf

SetFit实现快速学习

SetFit的训练速度非常快,效率也极高。与GPT-3T-FEW等大模型相比,其性能极具竞争力。请参见下图:

在少樣本學習中,以SetFit進行文字分類SetFit與T-Few 3B模型的比較

#如下圖所示,##SetFit在少樣本學習的表現優於RoBERTa

在少樣本學習中,以SetFit進行文字分類

SetFit與RoBERT的比較,圖片來源:##https://www .php.cn/link/3ff4cea152080fd7d692a8286a587a67

##。 ,我們將用到由四個不同類別組成的獨特電商資料集,它們分別是:書籍、服裝與配件、電子產品、以及家居用品。此資料集的主要目的是將電商網站的產品描述歸類到指定的標籤下。

為了方便採用少樣本的訓練方法,我們將從四個類別中各選擇八個樣本,從而得到總共32 個訓練樣本。而其餘樣本則將留作測試之用。簡言之,我們在此使用的支持集是48

樣本學習。下圖展示的是自訂電商資料集的範例:

在少樣本學習中,以SetFit進行文字分類

#自訂電商資料集樣本 #我們採用名為「all-mpnet-base-v2」的Sentence Transformers預訓練模型,將文字資料轉換為各種向量嵌入。此模型可以為輸入文本,產生維度為768

的向量嵌入。

如下指令所示,我們將透過在conda環境(是一個開源的軟體包管理系統和環境管理系統)中安裝所需的軟體包,來開始SetFit

的實作。

!pip3 install SetFit !pip3 install sklearn !pip3 install transformers !pip3 install sentence-transformers

安裝完軟體包後,我們便可以透過以下程式碼載入資料集了。

from datasets import load_datasetdataset = load_dataset('csv', data_files={"train": 'E_Commerce_Dataset_Train.csv',"test": 'E_Commerce_Dataset_Test.csv'})

我們來參考下圖,看看訓練樣本和測試樣本數。

在少樣本學習中,以SetFit進行文字分類

訓練與測試資料#我們使用##sklearn軟體包中的LabelEncoder,將文字標籤轉換為編碼標籤。

from sklearn.preprocessing import LabelEncoder le = LabelEncoder()
透過LabelEncoder,我們將對訓練和測試資料集進行編碼,並將編碼後的標籤新增到資料集的「標籤」列。請參閱如下程式碼:

Encoded_Product = le.fit_transform(dataset["train"]['Label']) dataset["train"] = dataset["train"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["train"].features)Encoded_Product = le.fit_transform(dataset["test"]['Label']) dataset["test"] = dataset["test"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["test"].features)
下面,我們將初始化##SetFit模型和句子轉換器(sentence- transformers)模型。

from setfit import SetFitModel, SetFitTrainer from sentence_transformers.losses import CosineSimilarityLossmodel_id = "sentence-transformers/all-mpnet-base-v2" model = SetFitModel.from_pretrained(model_id)trainer = SetFitTrainer(  model=model, train_dataset=dataset["train"], eval_dataset=dataset["test"], loss_class=CosineSimilarityLoss, metric="accuracy", batch_size=64, num_iteratinotallow=20, num_epochs=2, column_mapping={"Text": "text", "Label": "label"})
初始化完成兩個模型後,我們現在可以呼叫訓練程式了。

trainer.train()
在完成了2個訓練輪數(epoch)後,我們將在 eval_dataset上,對訓練好的模型進行評估。

trainer.evaluate()
經過測試,我們的訓練模型的最高準確率為87.5%#。雖然87.5%的準確率並不算高,但畢竟我們的模型只用了32個樣本來訓練。也就是說,考慮到資料集規模的有限性,在測試資料集上取得87.5%的準確率,實際上是相當可觀的。

此外,SetFit也能夠將訓練好的模型,儲存到本地記憶體中,以便後續從磁碟加載,用於將來的預測。

trainer.model._save_pretrained(save_directory="SetFit_ECommerce_Output/")model=SetFitModel.from_pretrained("SetFit_ECommerce_Output/", local_files_notallow=True)
如下程式碼展示了根據新的資料進行的預測結果:

input = ["Campus Sutra Men's Sports Jersey T-Shirt Cool-Gear: Our Proprietary Moisture Management technology. Helps to absorb and evaporate sweat quickly. Keeps you Cool & Dry. Ultra-Fresh: Fabrics treated with Ultra-Fresh Antimicrobial Technology. Ultra-Fresh is a trademark of (TRA) Inc, Ontario, Canada. Keeps you odour free."]output = model(input)
可見,其預測輸出為1,而標籤的LabelEncoded值為「服裝與配件」。由於傳統的AI模型需要大量的訓練資源(包括時間和資料),才能有穩定水準的輸出。而我們的模型與之相比,既準確又有效率。

######

至此,相信您已經基本掌握了「少樣本學習」的概念,以及如何使用SetFit來進行文字分類等應用。當然,為了獲得更深刻的理解,我強烈建議您選擇一個實際場景,建立一個資料集,編寫對應的程式碼,並將該過程延展到零樣本學習、以及單樣本學習。

譯者介紹

#陳峻(Julian Chen)是51CTO社群的編輯,他在IT專案實施方面有十多年的經驗,擅長管理內部和外部資源和風險,並專注於傳播網路和資訊安全的知識和經驗

原文標題:Mastering Few-Shot Learning with SetFit for Text Classification,作者:Shyam Ganesh S)


以上是在少樣本學習中,以SetFit進行文字分類的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:51cto.com。如有侵權,請聯絡admin@php.cn刪除