Po wytrenowaniu modelu w Tensorflow:
- Jak zapisać wyszkolony model?
- Jak później przywrócić ten zapisany model?
python
tensorflow
machine-learning
model
mathetes
źródło
źródło
Odpowiedzi:
Dokumenty
wyczerpujący i przydatny samouczek -> https://www.tensorflow.org/guide/saved_model
Szczegółowy przewodnik Keras dotyczący zapisywania modeli -> https://www.tensorflow.org/guide/keras/save_and_serialize
Z dokumentów:
Zapisać
Przywracać
Przepływ Tensor 2
To wciąż wersja beta, więc odradzam na razie. Jeśli nadal chcesz iść tą drogą, tutaj jest
tf.saved_model
instrukcją użytkowaniaPrzepływ Tensor <2
simple_save
Wiele dobrych odpowiedzi, dla kompletności dodam moje 2 centy: simple_save . Również przykład samodzielnego kodu przy użyciu
tf.data.Dataset
API.Python 3; Przepływ Tensor 1.14
Przywracanie:
Przykład samodzielny
Oryginalny post na blogu
Poniższy kod generuje losowe dane na potrzeby demonstracji.
Dataset
a następnie jegoIterator
. Otrzymujemy wygenerowany tensor iteratora, tzwinput_tensor
który posłuży jako dane wejściowe do naszego modelu.input_tensor
: dwukierunkowego RNN opartego na GRU, a następnie gęstego klasyfikatora. Bo czemu nie.softmax_cross_entropy_with_logits
zoptymalizowanaAdam
. Po 2 epokach (po 2 partie każda) zapisujemy „wytrenowany” model za pomocątf.saved_model.simple_save
. Jeśli uruchomisz kod w obecnej postaci, model zostanie zapisany w folderze o nazwiesimple/
w obecnej w bieżącym katalogu roboczym.tf.saved_model.loader.load
. Chwytamy symbole zastępcze i logi za pomocągraph.get_tensor_by_name
iIterator
inicjujemy operację za pomocągraph.get_operation_by_name
.Kod:
Spowoduje to wydrukowanie:
źródło
tf.contrib.layers
?[n.name for n in graph2.as_graph_def().node]
. Jak głosi dokumentacja, proste zapisywanie ma na celu uproszczenie interakcji z obsługą tensorflow, o to właśnie chodzi w argumentach; inne zmienne są jednak nadal przywracane, w przeciwnym razie wnioskowanie nie nastąpiłoby. Po prostu chwyć swoje zmienne zainteresowania, tak jak w przykładzie. Sprawdź dokumentacjęglobal_step
argumentu, jeśli przestaniesz, a następnie spróbujesz ponownie rozpocząć trening, to pomyślisz, że jesteś krok do przodu. Przynajmniej spieszy twoje wizualizacje tensorboardówPoprawiam swoją odpowiedź, aby dodać więcej szczegółów dotyczących zapisywania i przywracania modeli.
W (i po) wersji Tensorflow 0.11 :
Zapisz model:
Przywróć model:
To i niektóre bardziej zaawansowane przypadki użycia zostały tutaj bardzo dobrze wyjaśnione.
Szybki kompletny samouczek do zapisywania i przywracania modeli Tensorflow
źródło
:0
do nazw?W (i późniejszej) wersji TensorFlow 0.11.0RC1 możesz zapisać i przywrócić swój model bezpośrednio, dzwoniąc
tf.train.export_meta_graph
itf.train.import_meta_graph
zgodnie z https://www.tensorflow.org/programmers_guide/meta_graph .Zapisz model
Przywróć model
źródło
<built-in function TF_Run> returned a result with an error set
tf.get_variable_scope().reuse_variables()
a następnievar = tf.get_variable("varname")
. Daje mi to błąd: „Błąd wartości: Zmienna nazwa zmiennej nie istnieje lub nie została utworzona za pomocą tf.get_variable ().” Dlaczego? Czy nie powinno to być możliwe?Dla wersji TensorFlow <0.11.0RC1:
Zapisane punkty kontrolne zawierają wartości dla
Variable
s w modelu, a nie sam model / wykres, co oznacza, że wykres powinien być taki sam podczas przywracania punktu kontrolnego.Oto przykład regresji liniowej, w której istnieje pętla treningowa, która zapisuje zmienne punkty kontrolne, oraz sekcja oceny, która przywróci zmienne zapisane w poprzednim przebiegu i obliczy prognozy. Oczywiście możesz także przywrócić zmienne i kontynuować trening, jeśli chcesz.
Oto dokumenty dotyczące
Variable
s, które obejmują zapisywanie i przywracanie. A oto dokumenty dlaSaver
.źródło
batch_x
musi być? Dwójkowy? Tablica Numpy?undefined
. Czy możesz mi powiedzieć, która z definicji FLAGS dla tego kodu. @RyanSepassiMoje środowisko: Python 3.6, Tensorflow 1.3.0
Chociaż istnieje wiele rozwiązań, większość z nich jest oparta
tf.train.Saver
. Kiedy załadować.ckpt
zapisany przezSaver
musimy albo przedefiniować sieć tensorflow lub użyć trochę dziwne i ciężko pamiętał nazwę, na przykład'placehold_0:0'
,'dense/Adam/Weight:0'
. Tutaj polecam skorzystać ztf.saved_model
jednego najprostszego przykładu podanego poniżej, aby dowiedzieć się więcej na temat obsługi modelu TensorFlow :Zapisz model:
Załaduj model:
źródło
Model składa się z dwóch części, definicji modelu, zapisanej przez
Supervisor
jakgraph.pbtxt
w katalogu modelu oraz wartości liczbowych tensorów, zapisanych w plikach punktów kontrolnych, takich jakmodel.ckpt-1003418
.Definicję modelu można przywrócić za pomocą
tf.import_graph_def
, a wagi przywrócić za pomocąSaver
.Jednakże
Saver
wykorzystuje specjalną kolekcję listę zmiennych, które jest dołączone do modelu Graph gospodarstwa, a ta kolekcja nie jest inicjowany za pomocą import_graph_def, więc nie można korzystać z dwóch razem w tej chwili (jest na naszej mapie drogowej do poprawki). Na razie musisz użyć podejścia Ryana Sepassi - ręcznie skonstruuj wykres z identycznymi nazwami węzłów i użyj,Saver
aby załadować do niego wagi.(Alternatywnie możesz zhakować go, używając
import_graph_def
, tworząc ręcznie zmienne i używająctf.add_to_collection(tf.GraphKeys.VARIABLES, variable)
dla każdej zmiennej, a następnie używającSaver
)źródło
Możesz także wybrać ten łatwiejszy sposób.
Krok 1: zainicjuj wszystkie zmienne
Krok 2: zapisz sesję w modelu
Saver
i zapisz jąKrok 3: przywróć model
Krok 4: sprawdź swoją zmienną
Korzystając z innej instancji Pythona, użyj
źródło
W większości przypadków
tf.train.Saver
najlepszym rozwiązaniem jest zapisywanie i przywracanie z dysku za pomocą :Możesz także zapisać / przywrócić samą strukturę wykresu (szczegóły w dokumentacji MetaGraph ). Domyślnie
Saver
zapisuje strukturę wykresu w.meta
pliku. Możesz zadzwonić,import_meta_graph()
aby go przywrócić. Przywraca strukturę wykresu i zwraca wartośćSaver
, której można użyć do przywrócenia stanu modelu:Są jednak przypadki, w których potrzebujesz czegoś znacznie szybciej. Na przykład, jeśli wdrożysz wczesne zatrzymywanie, chcesz zapisywać punkty kontrolne za każdym razem, gdy model poprawia się podczas treningu (mierzony na podstawie zestawu sprawdzania poprawności), a następnie, jeśli nie ma postępu przez pewien czas, chcesz przywrócić najlepszy model. Jeśli zapiszesz model na dysku za każdym razem, gdy poprawi się, ogromnie spowolni trening. Sztuką jest zapisanie stanów zmiennych w pamięci , a następnie przywrócenie ich później:
Szybkie wyjaśnienie: po utworzeniu zmiennej
X
TensorFlow automatycznie tworzy operację przypisania wX/Assign
celu ustawienia wartości początkowej zmiennej. Zamiast tworzyć symbole zastępcze i dodatkowe operacje przypisania (co spowodowałoby bałagan na wykresie), po prostu używamy tych istniejących operacji przypisania. Pierwsze wejście każdej operacji przypisania jest odwołaniem do zmiennej, którą ma zainicjować, a drugie wejście (assign_op.inputs[1]
) jest wartością początkową. Aby więc ustawić dowolną wartość (zamiast wartości początkowej), musimy użyć afeed_dict
i zastąpić wartość początkową. Tak, TensorFlow pozwala podać wartość dla dowolnej operacji, nie tylko dla symboli zastępczych, więc to działa dobrze.źródło
Jak powiedział Jarosław, możesz zhakować przywracanie z graph_def i punktu kontrolnego, importując wykres, ręcznie tworząc zmienne, a następnie używając wygaszacza.
Zaimplementowałem to na własny użytek, więc pomyślałem, że podzielę się tutaj kodem.
Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(Jest to oczywiście włamanie i nie ma gwarancji, że zapisane w ten sposób modele pozostaną czytelne w przyszłych wersjach TensorFlow.)
źródło
Jeśli jest to model zapisany wewnętrznie, po prostu określasz restauratora dla wszystkich zmiennych jako
i użyj go do przywrócenia zmiennych w bieżącej sesji:
W przypadku modelu zewnętrznego musisz określić odwzorowanie jego nazw zmiennych na nazwy zmiennych. Możesz wyświetlić nazwy zmiennych modelu za pomocą polecenia
Skrypt inspect_checkpoint.py można znaleźć w folderze „./tensorflow/python/tools” źródła Tensorflow.
Aby określić mapowanie, możesz użyć mojego Tensorflow-Worklab , który zawiera zestaw klas i skryptów do trenowania i przekwalifikowywania różnych modeli. Zawiera przykład przekwalifikowania modeli ResNet, który znajduje się tutaj
źródło
all_variables()
jest teraz przestarzałeOto moje proste rozwiązanie dwóch podstawowych przypadków różniących się od tego, czy chcesz załadować wykres z pliku, czy skompilować go w czasie wykonywania.
Ta odpowiedź dotyczy Tensorflow 0.12+ (w tym 1.0).
Przebudowa wykresu w kodzie
Oszczędność
Ładowanie
Ładowanie również wykresu z pliku
Korzystając z tej techniki, upewnij się, że wszystkie twoje warstwy / zmienne mają jawnie ustawione unikalne nazwy.W przeciwnym razie Tensorflow sprawi, że nazwy będą unikalne, a zatem będą różne od nazw przechowywanych w pliku. W poprzedniej technice nie stanowi to problemu, ponieważ nazwy są „zniekształcane” w ten sam sposób zarówno podczas ładowania, jak i zapisywania.
Oszczędność
Ładowanie
źródło
global_step
zmienna i średnie kroczące normalizacji partii są zmiennymi, których nie da się wyćwiczyć, ale zdecydowanie warto je zapisać. Ponadto należy wyraźniej odróżnić budowę wykresu od uruchomienia sesji, na przykładSaver(...).save()
będzie tworzyć nowe węzły za każdym razem, gdy go uruchomisz. Prawdopodobnie nie to, czego chcesz. I jest więcej ...: /Możesz także sprawdzić przykłady w TensorFlow / skflow , który oferuje metody
save
irestore
metody, które pomogą ci łatwo zarządzać swoimi modelami. Ma parametry, które możesz kontrolować, jak często chcesz tworzyć kopię zapasową modelu.źródło
Jeśli używasz tf.train.MonitoredTrainingSession jako sesji domyślnej, nie musisz dodawać dodatkowego kodu, aby zapisywać / przywracać rzeczy. Po prostu przekaż nazwę kontrolną punktu kontrolnego konstruktorowi MonitoredTrainingSession, użyje haków sesji do ich obsługi.
źródło
Wszystkie odpowiedzi tutaj są świetne, ale chcę dodać dwie rzeczy.
Po pierwsze, aby rozwinąć odpowiedź na temat @ user7505159, „./” może być ważne, aby dodać na początku przywracanej nazwy pliku.
Na przykład możesz zapisać wykres bez nazwy „./” w nazwie pliku w następujący sposób:
Ale w celu przywrócenia wykresu może być konieczne dodanie „./” do nazwy pliku:
Nie zawsze będziesz potrzebować „./”, ale może to powodować problemy w zależności od środowiska i wersji TensorFlow.
Warto również wspomnieć, że
sess.run(tf.global_variables_initializer())
może to być ważne przed przywróceniem sesji.Jeśli pojawia się błąd dotyczący niezainicjowanych zmiennych podczas próby przywrócenia zapisanej sesji, upewnij się, że podałeś ją
sess.run(tf.global_variables_initializer())
przedsaver.restore(sess, save_file)
wierszem. Może zaoszczędzić ci bólu głowy.źródło
Jak opisano w numerze 6255 :
zamiast
źródło
Według nowej wersji Tensorflow
tf.train.Checkpoint
preferowanym sposobem zapisywania i przywracania modelu jest:Oto przykład:
Więcej informacji i przykład tutaj.
źródło
W przypadku tensorflow 2.0 jest to tak proste
Przywrócić:
źródło
tf.keras Zapisywanie modelu za pomocą
TF2.0
Widzę świetne odpowiedzi na temat zapisywania modeli za pomocą TF1.x. Chcę podać kilka dodatkowych wskazówek w zapisywaniu
tensorflow.keras
modeli, co jest nieco skomplikowane, ponieważ istnieje wiele sposobów zapisywania modelu.Podaję przykład zapisywania
tensorflow.keras
modelu wmodel_path
folderze w bieżącym katalogu. Działa to dobrze z najnowszym tensorflow (TF2.0). Zaktualizuję ten opis, jeśli w najbliższej przyszłości nastąpi jakakolwiek zmiana.Zapisywanie i ładowanie całego modelu
Zapisywanie i ładowanie modelu Tylko masy
Jeśli chcesz zapisać tylko masy modelu, a następnie załaduj wagi, aby przywrócić model, to
Zapisywanie i przywracanie za pomocą oddzwaniania punktu kontrolnego keras
zapisywanie modelu z niestandardowymi danymi
Zapisywanie modelu Keras z niestandardowymi operacjami
Kiedy mamy niestandardowe operacje, jak w poniższym przypadku (
tf.tile
), musimy utworzyć funkcję i owinąć ją warstwą Lambda. W przeciwnym razie model nie zostanie zapisany.Myślę, że omówiłem kilka z wielu sposobów zapisywania modelu tf.keras. Istnieje jednak wiele innych sposobów. Skomentuj poniżej, jeśli widzisz, że Twój przypadek użycia nie jest uwzględniony powyżej. Dzięki!
źródło
Aby zapisać model, użyj tf.train.Saver, pamiętaj, że jeśli chcesz zmniejszyć rozmiar modelu, musisz podać listę var_list. Val_list może być tf.trainable_variables lub tf.global_variables.
źródło
Możesz zapisać zmienne w sieci za pomocą
Aby przywrócić sieć do ponownego użycia później lub w innym skrypcie, użyj:
Ważne punkty:
sess
muszą być takie same między pierwszym a późniejszym przebiegiem (spójna struktura).saver.restore
potrzebuje ścieżki do folderu zapisanych plików, a nie pojedynczej ścieżki pliku.źródło
Gdziekolwiek chcesz zapisać model,
Upewnij się, że wszystkie
tf.Variable
mają nazwy, ponieważ możesz je później przywrócić, używając ich nazw. I gdzie chcesz przewidzieć,Upewnij się, że wygaszacz działa w odpowiedniej sesji. Pamiętaj, że jeśli użyjesz, zostanie użyty
tf.train.latest_checkpoint('./')
tylko najnowszy punkt kontrolny.źródło
Jestem w wersji:
Prosty sposób to
Zapisać:
Przywracać:
źródło
Dla tensorflow-2.0
to jest bardzo proste.
ZAPISAĆ
PRZYWRACAĆ
źródło
Po odpowiedzi @Vishnuvardhan Janapati, oto kolejny sposób na zapisanie i ponowne załadowanie modelu z niestandardową warstwą / metryką / utratą w TensorFlow 2.0.0
W ten sposób, kiedy już wykonywane takie kody, a zapisany model z
tf.keras.models.save_model
lubmodel.save
lubModelCheckpoint
oddzwanianie, można ponownie załadować model bez konieczności precyzyjnego niestandardowych obiektów, tak proste, jakźródło
W nowej wersji tensorflow 2.0 proces zapisywania / ładowania modelu jest znacznie łatwiejszy. Ze względu na implementację API Keras, API wysokiego poziomu dla TensorFlow.
Aby zapisać model: sprawdź dokumentację w celach informacyjnych: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model
Aby załadować model:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model
źródło
Oto prosty przykład za pomocą Tensorflow 2.0 SavedModel formatu (który jest zalecany format Według docs ) za pomocą prostego zestawu danych MNIST klasyfikatora, wykorzystując Keras API funkcjonalny bez zbyt fantazyjne dzieje:
Co to jest
serving_default
?Jest to nazwa def podpisu wybranego znacznika (w tym przypadku
serve
wybrano domyślny znacznik). Również tutaj wyjaśnia, jak znaleźć-tych tagów i podpisów z wykorzystaniem modelusaved_model_cli
.Zastrzeżenia
To tylko podstawowy przykład, jeśli chcesz go uruchomić, ale w żadnym wypadku nie jest to kompletna odpowiedź - być może uda mi się go zaktualizować w przyszłości. Chciałem tylko podać prosty przykład z wykorzystaniem
SavedModel
TF 2.0, ponieważ nigdzie nie widziałem takiego, nawet takiego prostego.Odpowiedź @ Toma to przykład SavedModel, ale nie będzie działać na Tensorflow 2.0, ponieważ niestety są pewne przełomowe zmiany.
@ Odpowiedź Vishnuvardhan Janapati mówi TF 2.0, ale nie dotyczy formatu SavedModel.
źródło