搜尋
首頁後端開發Python教學pytorch模型保存與載入中的一些問題實戰記錄

本篇文章為大家帶來了關於Python的相關知識,其中主要介紹了關於pytorch模型保存與加載中的一些問題實戰記錄,下面一起來看一下,希望對大家有幫助。

【相關推薦:Python3影片教學

一、torch中模型保存與載入的方式

1、模型參數和模型結構保存與載入

torch.save(model,path)
torch.load(path)

2、只儲存模型的參數和載入-這種方式比較安全,但是比較稍微麻煩一點點

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)

二、 torch中模型保存與載入出現的問題

1、單卡模型下儲存模型結構、參數後載入出現的問題

模型儲存的時候會把模型結構定義檔路徑記錄下來,載入的時候就會根據路徑解析它然後裝載參數;當把模型定義檔案路徑修改以後,使用torch.load(path)就會報錯。

#把model資料夾修改成models後,再載入就會報錯。

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)

這種保存完整模型結構與參數的方式,一定不要改變模型定義檔路徑

2、多卡機器單卡訓練模型保存後在單卡機器上載入會報錯

在多卡機器上有多張顯示卡0號開始,現在模型在n>= 1上的顯示卡訓練儲存後,拷貝在單卡機器上載入

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)

#會出現cuda device不符的問題-你儲存的模碼段小部件型是使用的cuda1,那麼採用torch.load()開啟的時候,會預設的去尋找cuda1,然後把模型載入到該裝置上。這時候可以直接使用map_location來解決,把模型載入到CPU上即可。

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))

3、多卡訓練模型保存模型結構和參數後加載出現的問題

當用多GPU同時訓練模型之後,不管是採用模型結構和參數一起保存還是單獨保存模型參數,然後在單卡下載入都會出現問題

a、模型結構和參數一起保然後在載入

torch.distributed.init_process_group(backend='nccl')

模型訓練的時候採用上述多進程的方式,所以你在載入的時候也要聲明,不然就會報錯。

b、單獨保存模型參數

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)

同樣會出現問題,不過這裡出現的問題是參數字典的key和模型定義的key不一樣

原因是多GPU訓練下,使用分散式訓練的時候會給模型一個包裝,程式碼如下:

model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)

包裝前的模型結構:

包裝後的模型

在外層多了DistributedDataParallel以及module,所以才會導致在單卡環境下載入模型權重的時候出現權重的keys不一致。

三、正確的保存模型和載入的方法

    if gpu_count > 1:
        torch.save(model.module.state_dict(),save_path)
    else:
        torch.save(model.state_dict(),save_path)
    model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)

這樣就是比較好的範式,載入不會出錯。

【相關推薦:Python3影片教學

以上是pytorch模型保存與載入中的一些問題實戰記錄的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述
本文轉載於:脚本之家。如有侵權,請聯絡admin@php.cn刪除
Python和時間:充分利用您的學習時間Python和時間:充分利用您的學習時間Apr 14, 2025 am 12:02 AM

要在有限的時間內最大化學習Python的效率,可以使用Python的datetime、time和schedule模塊。 1.datetime模塊用於記錄和規劃學習時間。 2.time模塊幫助設置學習和休息時間。 3.schedule模塊自動化安排每週學習任務。

Python:遊戲,Guis等Python:遊戲,Guis等Apr 13, 2025 am 12:14 AM

Python在遊戲和GUI開發中表現出色。 1)遊戲開發使用Pygame,提供繪圖、音頻等功能,適合創建2D遊戲。 2)GUI開發可選擇Tkinter或PyQt,Tkinter簡單易用,PyQt功能豐富,適合專業開發。

Python vs.C:申請和用例Python vs.C:申請和用例Apr 12, 2025 am 12:01 AM

Python适合数据科学、Web开发和自动化任务,而C 适用于系统编程、游戏开发和嵌入式系统。Python以简洁和强大的生态系统著称,C 则以高性能和底层控制能力闻名。

2小時的Python計劃:一種現實的方法2小時的Python計劃:一種現實的方法Apr 11, 2025 am 12:04 AM

2小時內可以學會Python的基本編程概念和技能。 1.學習變量和數據類型,2.掌握控制流(條件語句和循環),3.理解函數的定義和使用,4.通過簡單示例和代碼片段快速上手Python編程。

Python:探索其主要應用程序Python:探索其主要應用程序Apr 10, 2025 am 09:41 AM

Python在web開發、數據科學、機器學習、自動化和腳本編寫等領域有廣泛應用。 1)在web開發中,Django和Flask框架簡化了開發過程。 2)數據科學和機器學習領域,NumPy、Pandas、Scikit-learn和TensorFlow庫提供了強大支持。 3)自動化和腳本編寫方面,Python適用於自動化測試和系統管理等任務。

您可以在2小時內學到多少python?您可以在2小時內學到多少python?Apr 09, 2025 pm 04:33 PM

兩小時內可以學到Python的基礎知識。 1.學習變量和數據類型,2.掌握控制結構如if語句和循環,3.了解函數的定義和使用。這些將幫助你開始編寫簡單的Python程序。

如何在10小時內通過項目和問題驅動的方式教計算機小白編程基礎?如何在10小時內通過項目和問題驅動的方式教計算機小白編程基礎?Apr 02, 2025 am 07:18 AM

如何在10小時內教計算機小白編程基礎?如果你只有10個小時來教計算機小白一些編程知識,你會選擇教些什麼�...

如何在使用 Fiddler Everywhere 進行中間人讀取時避免被瀏覽器檢測到?如何在使用 Fiddler Everywhere 進行中間人讀取時避免被瀏覽器檢測到?Apr 02, 2025 am 07:15 AM

使用FiddlerEverywhere進行中間人讀取時如何避免被檢測到當你使用FiddlerEverywhere...

See all articles

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

AI Hentai Generator

AI Hentai Generator

免費產生 AI 無盡。

熱門文章

R.E.P.O.能量晶體解釋及其做什麼(黃色晶體)
3 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳圖形設置
3 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您聽不到任何人,如何修復音頻
4 週前By尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解鎖Myrise中的所有內容
1 個月前By尊渡假赌尊渡假赌尊渡假赌

熱工具

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

SAP NetWeaver Server Adapter for Eclipse

SAP NetWeaver Server Adapter for Eclipse

將Eclipse與SAP NetWeaver應用伺服器整合。

EditPlus 中文破解版

EditPlus 中文破解版

體積小,語法高亮,不支援程式碼提示功能

DVWA

DVWA

Damn Vulnerable Web App (DVWA) 是一個PHP/MySQL的Web應用程序,非常容易受到攻擊。它的主要目標是成為安全專業人員在合法環境中測試自己的技能和工具的輔助工具,幫助Web開發人員更好地理解保護網路應用程式的過程,並幫助教師/學生在課堂環境中教授/學習Web應用程式安全性。 DVWA的目標是透過簡單直接的介面練習一些最常見的Web漏洞,難度各不相同。請注意,該軟體中