Heim >Backend-Entwicklung >Python-Tutorial >Praktische Aufzeichnungen einiger Probleme beim Speichern und Laden von Pytorch-Modellen

Praktische Aufzeichnungen einiger Probleme beim Speichern und Laden von Pytorch-Modellen

WBOY
WBOYnach vorne
2022-11-03 17:33:322707Durchsuche

Dieser Artikel vermittelt Ihnen relevantes Wissen über Python. Er stellt hauptsächlich praktische Aufzeichnungen zu einigen Problemen beim Speichern und Laden von Pytorch-Modellen vor. Ich hoffe, dass er für alle hilfreich ist.

【Verwandte Empfehlungen: Python3-Video-Tutorial

1. So speichern und laden Sie Modelle in Torch

1. Speichern und laden Sie Modellparameter und Modellstrukturen

rrree

2 Laden des Modells – Diese Methode ist sicherer, aber etwas aufwändiger. 2. Probleme beim Speichern und Laden von Modellen im Brenner

Modell Beim Speichern wird der Pfad zur Modellstrukturdefinitionsdatei aufgezeichnet, dieser wird entsprechend dem Pfad analysiert und dann mit Parametern geladen. Wenn der Pfad zur Modelldefinitionsdatei geändert wird, wird ein Fehler gemeldet bei Verwendung von Torch.load(path).

Nach dem Ändern des Modellordners in „Modelle“ wird beim erneuten Laden ein Fehler gemeldet.

torch.save(model,path)
torch.load(path)

Damit Sie die vollständige Modellstruktur und die Parameter speichern, achten Sie darauf, den Pfad der Modelldefinitionsdatei nicht zu ändern.

2. Nach dem Speichern des Einzelkarten-Trainingsmodells auf einem Computer mit mehreren Karten wird beim Laden auf einem Computer mit einer Karte ein Fehler gemeldet.

Beginnend bei 0 auf einem Computer mit mehreren Karten. Jetzt wird das Modell nach dem Speichern der Grafikkarte auf n>=1 trainiert

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)

, es wird ein Problem mit der Nichtübereinstimmung des Cuda-Geräts geben – dem Modellcode-Segment-Widget-Typ Sie haben cuda1 gespeichert. Wenn Sie es also mit Torch.load() öffnen, wird standardmäßig cuda1 gesucht und dann das Modell auf das Gerät geladen. Zu diesem Zeitpunkt können Sie map_location direkt verwenden, um das Problem zu lösen und das Modell auf die CPU zu laden.

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)

3. Probleme, die auftreten, wenn Multi-Card-Trainingsmodelle die Modellstruktur und -parameter speichern und dann laden.

Nach dem Training des Modells mit mehreren GPUs gleichzeitig, unabhängig davon, ob die Modellstruktur und -parameter zusammen oder das Modell gespeichert werden Die Parameter werden separat gespeichert und dann unter einer einzigen Karte. Beim Laden von

a treten Probleme auf. Speichern Sie die Modellstruktur und die Parameter zusammen und verwenden Sie dann beim Laden des

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)

Modelltrainings die oben beschriebene Mehrprozessmethode, also müssen Sie Deklarieren Sie es auch beim Laden, andernfalls wird ein Fehler gemeldet.

b. Das separate Speichern von Modellparametern

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
führt ebenfalls zu Problemen, aber das Problem besteht darin, dass sich der Schlüssel des Parameterwörterbuchs von dem vom Modell definierten Schlüssel unterscheidet. Der Grund dafür ist, dass unter Multi-GPU Training, verteiltes Training wird verwendet. Das Modell wird irgendwann gepackt, und der Code lautet wie folgt:

torch.distributed.init_process_group(backend='nccl')

Die Modellstruktur vor dem Packen:

Das gepackte Modell

Es gibt mehr DistributedDataParallel und Module in der äußeren Schicht, sodass es zu einer Einzelkartenumgebung kommt. Beim Laden von Modellgewichten sind die Gewichtsschlüssel inkonsistent.

3. Der richtige Weg, das Modell zu speichern und zu laden

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)
Dies ist ein besseres Paradigma und es treten keine Fehler beim Laden auf.

【Verwandte Empfehlungen:

Python3-Video-Tutorial

Das obige ist der detaillierte Inhalt vonPraktische Aufzeichnungen einiger Probleme beim Speichern und Laden von Pytorch-Modellen. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:jb51.net. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen