搜尋
首頁科技週邊人工智慧使用遷移學習技術進行深度學習模型的客製化訓練

使用遷移學習技術進行深度學習模型的客製化訓練

Apr 23, 2023 am 08:13 AM
機器學習數據集遷移學習

譯者| 朱先忠

審校| 孫淑娟

遷移學習是機器學習的一種類型,它是一種應用於已經訓練或預先訓練的神經網路的方法,而這些預先訓練的神經元網路是使用數百萬個資料點訓練出來的。

使用遷移學習技術進行深度學習模型的客製化訓練

該技術目前最著名的用法是用來訓練深度神經網絡,因為這種方法在使用較少的資料訓練深度神經網路時表現出良好的性能。實際上,這種技術在資料科學領域也是很有用的,因為大多數真實世界的資料通常沒有數百萬個資料點來訓練出穩固的深度學習模型。

目前,已經存在許多使用數百萬個資料點訓練出來的模型,並且這些模型可以用於以最大精度來訓練複雜的深度學習神經網路。

在本教程中,您將學習如何使用遷移學習技術來訓練一個深度神經網路的完整過程。

使用Keras程式實現遷移學習

在建立或訓練深度神經網路之前,您必須搞清楚有哪些選擇方案可用於遷移學習,以及必須使用哪一個方案來為專案訓練複雜的深度神經網路。

Keras應用程式是一種高級的深度學習模型,它提供了可用於預測、特徵提取和微調的預訓練權重。 Keras庫中內建提供了許多現成可用的模型,其中一些流行的模型包括:

  • Xception
  • VGG16 and VGG19
  • ResNet Series
  • MobileNet

【補充】Keras應用程式提供了一組深度學習模型,它們可與預先訓練的權重一起使用。有關這些模型的更具體的內容,請參考Keras官網內容

在本文中,您將學習MobileNet模型在遷移學習中的應用。

訓練一個深度學習模型

在本節中,您將學習如何在短短的幾個步驟內為圖像識別建立一個自訂深度學習模型,而無需編寫任何一系列卷積神經網路(CNN),您只需對預訓練的模型加以微調,即可使得您的模型在訓練資料集上進行訓練。

在本文中,我們建立的深度學習模型將能夠辨識手勢語言數字的圖像。接下來,讓我們開始著手建立這個自訂深度學習模型。

取得資料集

要開始建立一個深度學習模型的過程,您首先需要準備好數據,您可以透過造訪一個名為Kaggle的網站,從數百萬個資料集中輕鬆選擇合適的資料集。當然,也存在不少其他網站為建立深度學習或機器學習模型提供可用的資料集。

但本文將使用的資料集取自Kaggle網站提供的美國手語數字資料集

資料預處理

在下載資料集並將其儲存到本機儲存後,現在是時候對資料集執行一些預處理了,例如準備資料、將資料拆分為train目錄、valid目錄和test目錄、定義它們的路徑以及為訓練目的創建批次處理,等等。

準備資料

下載資料集時,它包含從0到9資料的目錄,其中有三個子資料夾分別對應輸入影像、輸出影像以及一個名稱為CSV的資料夾。

接著,從每個目錄中刪除輸出影像和CSV資料夾,將輸入影像資料夾下的內容移至主目錄下,然後刪除輸入影像資料夾。

資料集的每個主目錄現在都擁有500張影像,您可以選擇保留所有影像。但出於演示目的,本文中每個目錄中只使用其中的200幅圖像。

最終,資料集的結構將如下圖所示:

使用遷移學習技術進行深度學習模型的客製化訓練

資料集的資料夾結構

#分割資料集

現在,讓我們從將資料集拆分為train、valid和test三個子目錄開始。

  • train目錄將包含訓練數據,這些數據將作為我們輸入模型的輸入數據,用於學習模式和不規則性。
  • valid目錄將包含將輸入到模型中的驗證數據,並且將是模型所看到的第一個未看到的數據,這將有助於獲得最大的準確性。
  • test目錄將包含用於測試模型的測試資料。

首先,我們來導入將在程式碼中進一步使用的函式庫。

# 导入需要的库
import os
import shutil
import random

以下是產生所需目錄並將資料移至特定目錄的程式碼。

#创建三个子目录:train、valid和test,并把数据组织到其下
os.chdir('D:SACHINJupyterHand Sign LanguageHand_Sign_Language_DL_ProjectAmerican-Sign-Language-Digits-Dataset')

#如果目录不存在则创建相应的子目录
if os.path.isdir('train/0/') is False:
os.mkdir('train')
os.mkdir('valid')
os.mkdir('test')

for i in range(0, 10):
#把0-9子目录移动到train子目录下
shutil.move(f'{i}', 'train')
os.mkdir(f'valid/{i}')
os.mkdir(f'test/{i}')

#从valid子目录下取90个样本图像
valid_samples = random.sample(os.listdir(f'train/{i}'), 90)
for j in valid_samples:
#把样本图像从子目录train移动到valid子目录
shutil.move(f'train/{i}/{j}', f'valid/{i}')

#从test子目录下取90个样本图像
test_samples = random.sample(os.listdir(f'train/{i}'), 10)
for k in test_samples:
#把样本图像从子目录train移动到test子目录
shutil.move(f'train/{i}/{k}', f'test/{i}')

os.chdir('../..')

在上面的程式碼中,我們首先更改了資料集在本機儲存中對應的目錄,然後檢查是否已經存在train/0目錄;如果沒有,我們將分別建立train、valid和test子目錄。

然後,我們建立子目錄0到9,並將所有資料移到train目錄中,同時建立了valid和test這兩個子目錄下各自的子目錄0至9。

然後,我們在train目錄內的子目錄0到9上進行迭代,並從每個子目錄中隨機獲取90個圖像數據,並將它們移動到valid目錄內的相應子目錄。

對於測試目錄test也是如此。

【補充】在Python中執行高級檔案操作的shutil模組(手動將檔案或資料夾從一個目錄複製或移動到另一個目錄可能是一件非常痛苦的事情。有關詳細技巧,請參考文章https://medium.com/@geekpython/perform-high-level-file-operations-in-python-shutil-module-dfd71b149d32)。

定義到各目錄的路徑

建立所需的目錄後,現在需要定義train、valid和test這三個子目錄的路徑。

#为三个子目录train、valid和test分别指定路径
train_path = 'D:/SACHIN/Jupyter/Hand Sign Language/Hand_Sign_Language_DL_Project/American-Sign-Language-Digits-Dataset/train'
valid_path = 'D:/SACHIN/Jupyter/Hand Sign Language/Hand_Sign_Language_DL_Project/American-Sign-Language-Digits-Dataset/valid'
test_path = 'D:/SACHIN/Jupyter/Hand Sign Language/Hand_Sign_Language_DL_Project/American-Sign-Language-Digits-Dataset/test'

進行預處理

前訓練的深度學習模型需要一些預處理的數據,這些數據非常適合訓練。因此,資料需要採用預訓練模型所需的格式。

在應用任何預處理之前,讓我們導入TensorFlow及其實用程序,這些實用程式將在程式碼中進一步使用。

#导入TensorFlow及其实用程序
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model

#创建训练、校验和测试图像的批次,并使用Mobilenet的预处理模型进行预处理
train_batches = ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet.preprocess_input).flow_from_directory(
directory=train_path, target_size=(224,224), batch_size=10, shuffle=True)
valid_batches = ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet.preprocess_input).flow_from_directory(
directory=valid_path, target_size=(224,224), batch_size=10, shuffle=True)
test_batches = ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet.preprocess_input).flow_from_directory(
directory=test_path, target_size=(224,224), batch_size=10, shuffle=False)

我們使用了ImageDatagenerator,它採用了一個參數preprocessing_function,在該函數參數中,我們對MobileNet模型提供的映像進行了預處理。

接下來,呼叫flow_from_directory函數,其中我們提供了要訓練圖像的目錄和維度的路徑,因為MobileNet模型是為具有224x224維度的圖像訓練使用的。

再接下來,定義了批次大小-定義一次迭代中可以處理多少個影像,然後我們對影像處理順序進行隨機打亂。在此,我們沒有針對測試資料的影像進行隨機亂序,因為測試資料不會用於訓練。

在Jupyter筆記本或Google Colab中執行上述程式碼片段後,您將看到以下結果。

使用遷移學習技術進行深度學習模型的客製化訓練

上述程式碼的輸出結果

ImageDataGenerator的一般應用場景是用於增廣數據,以下是使用Keras框架中ImageDataGenerator執行資料增廣的指南指南

建立模型

在將訓練和驗證資料擬合到模型中之前,深度學習模型MobileNet需要透過新增輸出層、刪除不必要的層以及使某些層不可訓練,從而獲得更好的準確性來進行微調。

以下程式碼將從Keras下載MobileNet模型並將其儲存在mobile變數中。您需要在第一次執行以下程式碼片段時連接到網際網路。

mobile = tf.keras.applications.mobilenet.MobileNet()

如果您运行以下代码,那么您将看到模型的摘要信息,在其中你可以看到一系列神经网络层的输出信息。

mobile.summary()

现在,我们将在模型中添加以10为单位的全连接输出层(也称“稠密层”)——因为从0到9将有10个输出。此外,我们从MobileNet模型中删除了最后六个层。

# 删除最后6层并添加一个输出层
x = mobile.layers[-6].output
output = Dense(units=10, activation='softmax')(x)

然后,我们将所有输入和输出层添加到模型中。

model = Model(inputs=mobile.input, outputs=output)

现在,我们将最后23层设置成不可训练的——其实这是一个相对随意的数字。一般来说,这一具体数字是通过多次试验和错误获得的。该代码的唯一目的是通过使某些层不可训练来提高精度。

#我们不会训练最后23层——这里的23是一个相对随意的数字
for layer in mobile.layers[:-23]:
layer.trainable=False

如果您看到了微调模型的摘要输出,那么您将注意到与前面看到的原始摘要相比,不可训练参数和层的数量存在一些差异。

model.summary()

接下来,我们要编译名为Adam的优化器,选择学习率为0.0001,以及损失函数,还有衡量模型的准确性的度量参数。

model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

现在是准备好模型并根据训练和验证数据来开始训练的时候了。在下面的代码中,我们提供了训练和验证数据以及训练的总体轮回数。详细信息只是为了显示准确性进度,在这里您可以指定一个数字参数值为0、1或者2。

# 运行共10个轮回(epochs)
model.fit(x=train_batches, validation_data=valid_batches, epochs=10, verbose=2)

如果您运行上面的代码片断,那么您将看到训练数据丢失和准确性的轮回的每一步的输出内容。对于验证数据,您也能够看到这样的输出结果。

使用遷移學習技術進行深度學習模型的客製化訓練

显示有精度值的训练轮回步数

存储模型

该模型现在已准备就绪,准确度得分为99%。现在请记住一件事:这个模型可能存在过度拟合,因此有可能对于给定数据集图像以外的图像表现不佳。

#检查模型是否存在;否则,保存模型
if os.path.isfile("D:/SACHIN/Models/Hand-Sign-Digit-Language/digit_model.h5") is False:
model.save("D:/SACHIN/Models/Hand-Sign-Digit-Language/digit_model.h5")

上面的代码将检查是否已经有模型的副本。如果没有,则通过调用save函数在指定的路径中保存模型。

测试模型

至此,模型已经经过训练,可以用于识别图像了。本节将介绍加载模型和编写准备图像、预测结果以及显示和打印预测结果的函数。

在编写任何代码之前,需要导入一些将在代码中进一步使用的必要的库。

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

加载定制的模型

对图像的预测将使用上面使用迁移学习技术创建的模型进行。因此,我们首先需要加载该模型,以供后面使用。

my_model = load_model("D:/SACHIN/Models/Hand-Sign-Digit-Language/digit_model.h5")

在此,我们通过使用load_model函数,实现从指定路径加载模型,并将其存储在my_model变量中,以便在后面代码中进一步使用。

准备输入图像

在向模型提供任何用于预测或识别的图像之前,我们需要提供模型所需的格式。

def preprocess_img(img_path):
open_img = image.load_img(img_path, target_size=(224, 224))
img_arr = image.img_to_array(open_img)/255.0
img_reshape = img_arr.reshape(1, 224,224,3)
return img_reshape

首先,我们要定义一个获取图像路径的函数preprocess_img,然后使用image实用程序中的load_img函数加载该图像,并将目标大小设置为224x224。然后将该图像转换成一个数组,并将该数组除以255.0,这样就将图像的像素值转换为0和1,然后将图像数组重新调整为形状(224,224,3),最后返回转换形状后的图像。

编写预测函数

def predict_result(predict):
pred = my_model.predict(predict)
return np.argmax(pred[0], axis=-1)

这里,我们定义了一个函数predict_result,它接受predict参数,此参数基本上是一个预处理的图像。然后,我们调用模型的predict函数来预测结果。最后,从预测结果中返回最大值。

显示与预测图像

首先,我们将创建一个函数,它负责获取图像的路径,然后显示图像和预测结果。

#显示和预测图像的函数
def display_and_predict(img_path_input):
display_img = Image.open(img_path_input)
plt.imshow(display_img)
plt.show()
img = preprocess_img(img_path_input)
pred = predict_result(img)
print("Prediction: ", pred)

上面这个函数display_and_predict首先获取图像的路径并使用PIL库中的Image.open函数打开该图像,然后使用matplotlib库来显示图像,然后将图像传递给preprep_img函数以便输出预测结果,最后使用predict_result函数获得结果并最终打印。

img_input = input("Enter the path of an image: ")
display_and_predict(img_input)

如果您运行上面的程序片断并输入数据集中图像的路径,那么您将得到所期望的输出。

使用遷移學習技術進行深度學習模型的客製化訓練

预测结果示意图

请注意,到目前为止该模型是使用迁移学习技术成功创建的,而无需编写任何一系列神经网络层相关代码。

现在,这个模型可以用于开发能够进行图像识别的Web应用程序了。文章的最后所附链接处提供了如何将该模型应用到Flask应用程序中的完整实现源码。

结论

本文中我们介绍了使用预先训练的模型或迁移学习技术来制作一个定制的深度学习模型的过程。

到目前为止,您已经了解了创建一个完整的深度学习模型所涉及的每一步。归纳起来看,所使用的总体步骤包括:

  • 准备数据集
  • 预处理数据
  • 创建模型
  • 保存自定义模型
  • 测试自定义模型

最后,您可以从​​GitHub​​上获取本文示例项目完整的源代码。

译者介绍

朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。

原文标题:Trained A Custom Deep Learning Model Using A Transfer Learning Technique​,作者:Sachin Pal​

以上是使用遷移學習技術進行深度學習模型的客製化訓練的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述
本文轉載於:51CTO.COM。如有侵權,請聯絡admin@php.cn刪除
易於理解的解釋如何在Chatgpt中建立兩步身份驗證!易於理解的解釋如何在Chatgpt中建立兩步身份驗證!May 12, 2025 pm 05:37 PM

CHATGPT SECURICE增強:兩階段身份驗證(2FA)配置指南 需要兩因素身份驗證(2FA)作為在線平台的安全措施。本文將以易於理解的方式解釋2FA設置過程及其在CHATGPT中的重要性。這是為那些想要安全使用chatgpt的人提供的指南。 單擊此處獲取OpenAI最新的AI代理OpenAi Deep Research⬇️ [chatgpt]什麼是Openai深入研究?關於如何使用它和費用結構的詳盡解釋! 目錄 chatg

[針對企業] Chatgpt培訓|對8種免費培訓選項,補貼和示例進行了詳盡的介紹![針對企業] Chatgpt培訓|對8種免費培訓選項,補貼和示例進行了詳盡的介紹!May 12, 2025 pm 05:35 PM

生成的AI的使用吸引了人們的關注,這是提高業務效率和創造新業務的關鍵。特別是,由於其多功能性和準確性,許多公司都採用了Openai的Chatgpt。但是,可以有效利用chatgpt的人員短缺是實施它的主要挑戰。 在本文中,我們將解釋“ ChatGpt培訓”的必要性和有效性,以確保在公司中成功使用Chatgpt。我們將介紹廣泛的主題,從ChatGpt的基礎到業務使用,特定的培訓計劃以及如何選擇它們。 CHATGPT培訓提高員工技能

關於如何使用Chatgpt簡化您的Twitter操作的詳盡解釋!關於如何使用Chatgpt簡化您的Twitter操作的詳盡解釋!May 12, 2025 pm 05:34 PM

社交媒體運營的提高效率和質量至關重要。特別是在實時重要的平台上,例如Twitter,需要連續交付及時和引人入勝的內容。 在本文中,我們將解釋如何使用具有先進自然語言處理能力的AI的Chatgpt操作Twitter。通過使用CHATGPT,您不僅可以提高實時響應功能並提高內容創建的效率,而且還可以製定符合趨勢的營銷策略。 此外,使用預防措施

[對於Mac]說明如何開始以及如何使用ChatGpt桌面應用程序![對於Mac]說明如何開始以及如何使用ChatGpt桌面應用程序!May 12, 2025 pm 05:33 PM

CHATGPT MAC桌面應用程序詳細指南:從安裝到音頻功能 最後,Chatgpt的Mac桌面應用程序現已可用!在本文中,我們將徹底解釋從安裝方法到有用的功能和將來的更新信息的所有內容。使用桌面應用程序獨有的功能,例如快捷鍵,圖像識別和語音模式,以極大地提高您的業務效率! 安裝桌面應用的ChatGpt Mac版本 從瀏覽器訪問:首先,在瀏覽器中訪問chatgpt。

chatgpt的角色限制是什麼?解釋如何避免它和模型上限chatgpt的角色限制是什麼?解釋如何避免它和模型上限May 12, 2025 pm 05:32 PM

當使用chatgpt時,您是否曾經有過這樣的經驗,例如“輸出在中途停止”或“即使我指定了字符的數量,它也無法正確輸出”?該模型非常開創性,不僅允許自然對話,而且還允許創建電子郵件,摘要論文,甚至允許產生諸如小說之類的創意句子。但是,ChatGpt的弱點之一是,如果文本太長,輸入和輸出將無法正常工作。 Openai的最新AI代理“ Openai Deep Research”

什麼是Chatgpt的語音輸入和語音對話功能?解釋如何設置以及如何使用它什麼是Chatgpt的語音輸入和語音對話功能?解釋如何設置以及如何使用它May 12, 2025 pm 05:27 PM

Chatgpt是Openai開發的創新AI聊天機器人。它不僅具有文本輸入,而且還具有語音輸入和語音對話功能,從而可以進行更自然的交流。 在本文中,我們將解釋如何設置和使用Chatgpt的語音輸入和語音對話功能。即使您不能脫身,Chatp Plans也通過與您交談來做出回應並回應音頻,這在繁忙的商業情況和英語對話練習等各種情況下都帶來了很大的好處。 關於如何設置智能手機應用程序和PC的詳細說明以及如何使用。

易於理解的解釋如何使用Chatgpt進行求職和尋找工作!易於理解的解釋如何使用Chatgpt進行求職和尋找工作!May 12, 2025 pm 05:26 PM

成功的快捷方式!使用chatgpt有效的工作變更策略 在當今加劇的工作變更市場中,有效的信息收集和徹底的準備是成功的關鍵。 諸如Chatgpt之類的高級語言模型是求職者的強大武器。在本文中,我們將解釋如何有效利用Chatgpt來提高您的工作企業效率,從自我分析到申請文件和麵試準備。節省時間和學習技術,以充分展示您的優勢,並幫助您成功搜索工作。 目錄 使用chatgpt的狩獵工作示例 自我分析的效率:聊天

易於理解的解釋如何使用ChatGpt創建和輸出思維地圖!易於理解的解釋如何使用ChatGpt創建和輸出思維地圖!May 12, 2025 pm 05:22 PM

思維地圖是組織信息並提出想法的有用工具,但是創建它們可能需要時間。使用Chatgpt可以大大簡化此過程。 本文將詳細說明如何使用chatgpt輕鬆創建思維地圖。此外,通過創建的實際示例,我們將介紹如何在各種主題上使用思維圖。 了解如何使用Chatgpt有效地組織和可視化您的想法和信息。 Openai的最新AI代理OpenA

See all articles

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱門文章

熱工具

Atom編輯器mac版下載

Atom編輯器mac版下載

最受歡迎的的開源編輯器

SublimeText3 英文版

SublimeText3 英文版

推薦:為Win版本,支援程式碼提示!

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

EditPlus 中文破解版

EditPlus 中文破解版

體積小,語法高亮,不支援程式碼提示功能

DVWA

DVWA

Damn Vulnerable Web App (DVWA) 是一個PHP/MySQL的Web應用程序,非常容易受到攻擊。它的主要目標是成為安全專業人員在合法環境中測試自己的技能和工具的輔助工具,幫助Web開發人員更好地理解保護網路應用程式的過程,並幫助教師/學生在課堂環境中教授/學習Web應用程式安全性。 DVWA的目標是透過簡單直接的介面練習一些最常見的Web漏洞,難度各不相同。請注意,該軟體中