Rumah >pembangunan bahagian belakang >Tutorial Python >Rekod praktikal beberapa masalah dalam menyimpan dan memuatkan model pytorch

Rekod praktikal beberapa masalah dalam menyimpan dan memuatkan model pytorch

WBOY
WBOYke hadapan
2022-11-03 17:33:322706semak imbas

Artikel ini membawakan anda pengetahuan yang berkaitan tentang Python terutamanya memperkenalkan rekod praktikal tentang beberapa masalah dalam menyimpan dan memuatkan model pytorch. Saya harap ia dapat membantu semua orang. membantu.

[Cadangan berkaitan: Tutorial video Python3 ]

1. Cara menyimpan dan memuatkan model dalam obor

1. Simpan dan muatkan parameter model dan struktur model

torch.save(model,path)
torch.load(path)

2. Hanya simpan dan muatkan parameter model - kaedah ini lebih selamat, tetapi lebih menyusahkan

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)

2. Masalah dalam menyimpan dan memuatkan model dalam obor

1 Masalah dalam memuatkan struktur dan parameter model selepas menyimpannya dalam model kad tunggal

Struktur model akan ditakrifkan apabila model disimpan. laluan fail direkodkan, dan apabila memuatkan, ia akan dihuraikan mengikut laluan dan parameter akan dimuatkan selepas laluan fail definisi model diubah suai, ralat akan dilaporkan apabila menggunakan torch.load(path).

Selepas menukar folder model kepada model, ralat akan dilaporkan apabila memuatkan semula.

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)

Dengan cara ini untuk menyimpan struktur dan parameter model lengkap, pastikan anda tidak menukar laluan fail definisi model .

2. Selepas menyimpan model latihan kad tunggal pada mesin berbilang kad, ralat akan dilaporkan semasa memuatkannya pada mesin kad tunggal

Bermula dari 0 pada berbilang -mesin kad dengan berbilang kad grafik, model kini n>= Selepas menyimpan latihan kad grafik pada 1 dan memuatkan salinan pada mesin kad tunggal

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)

akan ada masalah ketidakpadanan peranti cuda - segmen kod modul yang anda simpan adalah kecil Jenis komponen menggunakan cuda1, jadi apabila ia dibuka menggunakan torch.load(), ia akan mencari cuda1 secara lalai, dan kemudian memuatkan model. kepada peranti. Pada masa ini, anda boleh terus menggunakan map_location untuk menyelesaikan masalah dan memuatkan model ke CPU.

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))

3 Masalah yang berlaku selepas menyimpan struktur model dan parameter model latihan berbilang GPU dan kemudian memuatkannya

Apabila menggunakan berbilang GPU untuk melatih model pada masa yang sama. , sama ada struktur model dan parameter disimpan bersama atau Masalah akan berlaku jika anda menyimpan parameter model secara berasingan dan kemudian memuatkannya di bawah satu kad

a Simpan struktur model dan parameter bersama-sama dan kemudian memuatkan model

torch.distributed.init_process_group(backend='nccl')

Kaedah berbilang proses di atas digunakan semasa latihan, jadi anda mesti mengisytiharkannya semasa memuatkan, jika tidak ralat akan dilaporkan.

b. Menyimpan parameter model secara berasingan

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)

juga akan menyebabkan masalah, tetapi masalahnya di sini ialah kunci kamus parameter berbeza daripada kunci yang ditakrifkan oleh model

Sebabnya di bawah latihan multi-GPU, apabila menggunakan latihan yang diedarkan, model tersebut akan dibungkus seperti berikut:

model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)

Struktur model sebelum pembungkusan:

Model berpakej

mempunyai DistributedDataParallel dan modul di lapisan luar, jadi berat model adalah dimuatkan dalam persekitaran kad tunggal Apabila kekunci berat tidak konsisten.

3. Kaedah menyimpan dan memuatkan model yang betul

    if gpu_count > 1:
        torch.save(model.module.state_dict(),save_path)
    else:
        torch.save(model.state_dict(),save_path)
    model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)

Ini adalah paradigma yang lebih baik, dan tidak akan ada ralat dalam memuatkan.

[Cadangan berkaitan: Tutorial video Python3]

Atas ialah kandungan terperinci Rekod praktikal beberapa masalah dalam menyimpan dan memuatkan model pytorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Artikel ini dikembalikan pada:jb51.net. Jika ada pelanggaran, sila hubungi admin@php.cn Padam