Szukałem alternatywnych sposobów zapisania wytrenowanego modelu w PyTorch. Jak dotąd znalazłem dwie alternatywy.
- torch.save (), aby zapisać model i torch.load (), aby załadować model.
- model.state_dict (), aby zapisać wytrenowany model i model.load_state_dict (), aby załadować zapisany model.
Natknąłem się na tę dyskusję, w której podejście 2 jest zalecane zamiast podejścia 1.
Moje pytanie brzmi: dlaczego preferowane jest drugie podejście? Czy to tylko dlatego, że moduły torch.nn mają te dwie funkcje i zachęcamy do ich używania?
python
serialization
deep-learning
pytorch
tensor
Wasi Ahmad
źródło
źródło
torch.save(model, f)
itorch.save(model.state_dict(), f)
. Zapisane pliki mają ten sam rozmiar. Teraz jestem zmieszany. Zauważyłem również, że używanie pickle do zapisywania model.state_dict () jest bardzo wolne. Myślę, że najlepszym sposobem jest użycie,torch.save(model.state_dict(), f)
ponieważ zajmujesz się tworzeniem modelu, a latarka obsługuje ładowanie ciężarów modelu, eliminując w ten sposób możliwe problemy. Źródłapickle
?Odpowiedzi:
Znalazłem tę stronę w ich repozytorium github, po prostu wkleję tutaj zawartość.
Zalecane podejście do zapisywania modelu
Istnieją dwa główne podejścia do serializacji i przywracania modelu.
Pierwsza (zalecana) zapisuje i wczytuje tylko parametry modelu:
Później:
Drugi zapisuje i wczytuje cały model:
Później:
Jednak w tym przypadku serializowane dane są powiązane z określonymi klasami i dokładną używaną strukturą katalogów, więc mogą ulec uszkodzeniu na różne sposoby, gdy są używane w innych projektach lub po poważnych refaktorach.
źródło
pickle
?To zależy od tego, co chcesz robić.
Przypadek 1: Zapisz model, aby użyć go samodzielnie do wnioskowania : Zapisujesz model, przywracasz go, a następnie zmieniasz model w tryb oceny. Dzieje się tak, ponieważ zwykle masz warstwy
BatchNorm
i,Dropout
które domyślnie są w trybie pociągu podczas budowy:Przypadek 2: Zapisz model, aby wznowić uczenie później : Jeśli chcesz dalej trenować model, który zamierzasz zapisać, musisz zapisać więcej niż tylko model. Musisz także zapisać stan optymalizatora, epoki, wynik itp. Zrobisz to w następujący sposób:
Aby wznowić szkolenie, wykonaj następujące czynności:,
state = torch.load(filepath)
a następnie, aby przywrócić stan każdego pojedynczego obiektu, coś takiego:Ponieważ wznawiasz szkolenie, NIE dzwoń
model.eval()
po przywróceniu stanów podczas ładowania.Przypadek # 3: Model, który ma być używany przez inną osobę bez dostępu do Twojego kodu : W Tensorflow możesz utworzyć
.pb
plik, który definiuje zarówno architekturę, jak i wagi modelu. Jest to bardzo przydatne, szczególnie podczas używaniaTensorflow serve
. Równoważny sposób na zrobienie tego w Pytorch to:W ten sposób nadal nie jest kuloodporny, a ponieważ pytorch wciąż przechodzi wiele zmian, nie polecałbym go.
źródło
torch.load
zwraca tylko OrderedDict. Jak uzyskać model, aby móc przewidywać?marynata Python implementuje protokoły binarne do serializacji i deserializacji obiektu Pythona.
Kiedy ty
import torch
(lub gdy używasz PyTorch) będzie toimport pickle
dla ciebie i nie musisz wywoływaćpickle.dump()
ipickle.load()
bezpośrednio, które są metodami zapisywania i ładowania obiektu.W rzeczywistości
torch.save()
itorch.load()
zawiniepickle.dump()
ipickle.load()
dla Ciebie.ZA
state_dict
Druga odpowiedź wspomniano zasługuje tylko kilka dodatkowych uwag.Co
state_dict
mamy w PyTorch? Właściwie są dwastate_dict
.Model PyTorch
torch.nn.Module
mamodel.parameters()
wywołanie, aby uzyskać parametry, których można się nauczyć (w i b). Te parametry, których można się nauczyć, raz ustawione losowo, będą aktualizowane w miarę upływu czasu. Parametry, których można się nauczyć, są pierwszymistate_dict
.Drugi
state_dict
to dyktowanie stanu optymalizatora. Przypominasz sobie, że optymalizator służy do poprawy parametrów, których można się nauczyć. Ale optymalizatorstate_dict
jest naprawiony. Nie ma się tam czego nauczyć.Ponieważ
state_dict
obiekty są słownikami Pythona, można je łatwo zapisywać, aktualizować, zmieniać i przywracać, dodając wiele modułowości do modeli i optymalizatorów PyTorch.Stwórzmy super prosty model, aby to wyjaśnić:
Ten kod wyświetli następujące informacje:
Zwróć uwagę, że jest to model minimalny. Możesz spróbować dodać stos sekwencyjny
Należy zauważyć, że tylko warstwy z parametrami, których można się nauczyć (warstwy splotowe, warstwy liniowe itp.) I zarejestrowane bufory (warstwy normalne partii) mają wpisy w modelu
state_dict
.Rzeczy, których nie można się nauczyć, należą do obiektu optymalizatora
state_dict
, który zawiera informacje o stanie optymalizatora, a także o zastosowanych hiperparametrach.Reszta historii jest taka sama; w fazie wnioskowania (jest to faza, w której używamy modelu po treningu) do prognozowania; przewidujemy na podstawie parametrów, których się nauczyliśmy. Tak więc do wnioskowania wystarczy zapisać parametry
model.state_dict()
.I użyć później model.load_state_dict (torch.load (filepath)) model.eval ()
Uwaga: nie zapomnij o ostatniej linii,
model.eval()
jest to kluczowe po załadowaniu modelu.Nie próbuj też oszczędzać
torch.save(model.parameters(), filepath)
. Tomodel.parameters()
tylko obiekt generatora.Z drugiej strony
torch.save(model, filepath)
zapisuje sam obiekt modelu, ale pamiętaj, że model nie ma optymalizatorastate_dict
. Sprawdź inną doskonałą odpowiedź autorstwa @Jadiel de Armas, aby zapisać dyktando stanu optymalizatora.źródło
Powszechną konwencją PyTorch jest zapisywanie modeli przy użyciu rozszerzenia pliku .pt lub .pth.
Zapisz / wczytaj cały model Zapisz:
Załaduj:
Klasa modelu musi być gdzieś zdefiniowana
źródło
Jeśli chcesz zapisać model i później wznowić trening:
Pojedynczy GPU: Zapisz:
Załaduj:
Wiele GPU: Zapisz
Załaduj:
źródło