Rumah >pembangunan bahagian belakang >Tutorial Python >Rekod praktikal beberapa masalah dalam menyimpan dan memuatkan model pytorch
Artikel ini membawakan anda pengetahuan yang berkaitan tentang Python terutamanya memperkenalkan rekod praktikal tentang beberapa masalah dalam menyimpan dan memuatkan model pytorch. Saya harap ia dapat membantu semua orang. membantu.
[Cadangan berkaitan: Tutorial video Python3 ]
torch.save(model,path) torch.load(path)
torch.save(model.state_dict(),path) model_state_dic = torch.load(path) model.load_state_dic(model_state_dic)
Struktur model akan ditakrifkan apabila model disimpan. laluan fail direkodkan, dan apabila memuatkan, ia akan dihuraikan mengikut laluan dan parameter akan dimuatkan selepas laluan fail definisi model diubah suai, ralat akan dilaporkan apabila menggunakan torch.load(path).
Selepas menukar folder model kepada model, ralat akan dilaporkan apabila memuatkan semula.
import torch from model.TextRNN import TextRNN load_model = torch.load('experiment_model_save/textRNN.bin') print('load_model',load_model)
Dengan cara ini untuk menyimpan struktur dan parameter model lengkap, pastikan anda tidak menukar laluan fail definisi model .
Bermula dari 0 pada berbilang -mesin kad dengan berbilang kad grafik, model kini n>= Selepas menyimpan latihan kad grafik pada 1 dan memuatkan salinan pada mesin kad tunggal
import torch from model.TextRNN import TextRNN load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin') print('load_model',load_model)
akan ada masalah ketidakpadanan peranti cuda - segmen kod modul yang anda simpan adalah kecil Jenis komponen menggunakan cuda1, jadi apabila ia dibuka menggunakan torch.load(), ia akan mencari cuda1 secara lalai, dan kemudian memuatkan model. kepada peranti. Pada masa ini, anda boleh terus menggunakan map_location untuk menyelesaikan masalah dan memuatkan model ke CPU.
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
Apabila menggunakan berbilang GPU untuk melatih model pada masa yang sama. , sama ada struktur model dan parameter disimpan bersama atau Masalah akan berlaku jika anda menyimpan parameter model secara berasingan dan kemudian memuatkannya di bawah satu kad
a Simpan struktur model dan parameter bersama-sama dan kemudian memuatkan model
torch.distributed.init_process_group(backend='nccl')
Kaedah berbilang proses di atas digunakan semasa latihan, jadi anda mesti mengisytiharkannya semasa memuatkan, jika tidak ralat akan dilaporkan.
b. Menyimpan parameter model secara berasingan
model = Transformer(num_encoder_layers=6,num_decoder_layers=6) state_dict = torch.load('train_model/clip/experiment.pt') model.load_state_dict(state_dict)
juga akan menyebabkan masalah, tetapi masalahnya di sini ialah kunci kamus parameter berbeza daripada kunci yang ditakrifkan oleh model
Sebabnya di bawah latihan multi-GPU, apabila menggunakan latihan yang diedarkan, model tersebut akan dibungkus seperti berikut:
model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin') print(model) model.cuda(args.local_rank) 。。。。。。 model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True) print('model',model)
Struktur model sebelum pembungkusan:
Model berpakej
mempunyai DistributedDataParallel dan modul di lapisan luar, jadi berat model adalah dimuatkan dalam persekitaran kad tunggal Apabila kekunci berat tidak konsisten.
if gpu_count > 1: torch.save(model.module.state_dict(),save_path) else: torch.save(model.state_dict(),save_path) model = Transformer(num_encoder_layers=6,num_decoder_layers=6) state_dict = torch.load(save_path) model.load_state_dict(state_dict)
Ini adalah paradigma yang lebih baik, dan tidak akan ada ralat dalam memuatkan.
[Cadangan berkaitan: Tutorial video Python3]
Atas ialah kandungan terperinci Rekod praktikal beberapa masalah dalam menyimpan dan memuatkan model pytorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!