이 글은 Python에 대한 관련 지식을 소개합니다. 주로 Pytorch 모델을 저장하고 로드할 때 발생하는 몇 가지 문제에 대한 실제 기록을 소개합니다. 모두에게 도움이 되기를 바랍니다.
【관련 권장사항: Python3 동영상 튜토리얼】
torch.save(model,path) torch.load(path)
torch.save(model.state_dict(),path) model_state_dic = torch.load(path) model.load_state_dic(model_state_dic)
모델 저장 시 모델 구조 정의 파일의 경로가 기록되며, 로드 시 경로에 따라 파싱된 후 모델 정의 파일의 경로가 수정되면 오류가 보고됩니다. torch.load(경로)를 사용할 때.
모델 폴더를 모델로 변경한 후 다시 로드할 때 오류가 발생합니다.
import torch from model.TextRNN import TextRNN load_model = torch.load('experiment_model_save/textRNN.bin') print('load_model',load_model)
이러한 방식으로 전체 모델 구조와 매개변수를 저장하려면 모델 정의 파일 경로를 변경하지 마세요.
여러 그래픽 카드가 있는 다중 카드 기계에서는 0부터 시작합니다. 이제 모델은 그래픽 카드를 저장한 후 n>=1로 훈련됩니다. 복사본이 단일 카드 머신
import torch from model.TextRNN import TextRNN load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin') print('load_model',load_model)
에 로드되면 cuda 장치 불일치 문제(모델 코드 세그먼트 위젯 유형)가 발생합니다. 저장한 것은 cuda1이므로 torch.load()로 열면 기본적으로 cuda1을 찾은 다음 모델을 장치에 로드합니다. 이때 map_location을 직접 사용하여 문제를 해결하고 모델을 CPU에 로드할 수 있습니다.
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
여러 GPU로 모델을 동시에 훈련한 후 모델 구조와 매개변수를 함께 저장하는지 아니면 모델을 저장하는지. 매개변수를 별도로 저장한 다음 단일 카드에
a를 로드할 때 문제가 발생합니다. 모델 구조와 매개변수를 함께 저장한 다음
torch.distributed.init_process_group(backend='nccl')
모델 훈련을 로드할 때 위의 다중 프로세스 방법을 사용합니다. 이므로 로드할 때도 선언해야 합니다. 그렇지 않으면 오류가 보고됩니다.
b. 모델 매개변수를 별도로 저장하는 것도 문제가 발생하지만 여기서 문제는 매개변수 사전의 키가 모델에서 정의한 키와 다르다는 것입니다
이유는 다중 GPU에서 발생하기 때문입니다. training, distributed training이 사용됩니다. 모델은 언젠가 패키징될 예정이며, 코드는 다음과 같습니다.
model = Transformer(num_encoder_layers=6,num_decoder_layers=6) state_dict = torch.load('train_model/clip/experiment.pt') model.load_state_dict(state_dict)
패키징 전 모델 구조:
패키지된 모델
더 많은 DistributedDataParallel 및 모듈이 있습니다. 모델 가중치를 로드할 때 가중치 키가 일치하지 않습니다.
3. 모델을 저장하고 로드하는 올바른 방법
model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin') print(model) model.cuda(args.local_rank) 。。。。。。 model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True) print('model',model)
【관련 추천:
Python3 비디오 튜토리얼위 내용은 Pytorch 모델을 저장하고 로드할 때 발생하는 몇 가지 문제에 대한 실제 기록의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!