Z tego, co do tej pory zebrałem, istnieje kilka różnych sposobów zrzucenia wykresu TensorFlow do pliku, a następnie załadowania go do innego programu, ale nie udało mi się znaleźć jasnych przykładów / informacji na temat ich działania. To, co już wiem, to:
- Zapisz zmienne modelu w pliku punktu kontrolnego (.ckpt) za pomocą a
tf.train.Saver()
i przywróć je później ( źródło ) - Zapisz model do pliku .pb i załaduj go z powrotem za pomocą
tf.train.write_graph()
itf.import_graph_def()
( źródło ) - Załaduj model z pliku .pb, przetrenuj go i wrzuć do nowego pliku .pb za pomocą Bazel ( źródło )
- Zablokuj wykres, aby zapisać wykres i wagi razem ( źródło )
- Służy
as_graph_def()
do zapisywania modelu, a dla wag / zmiennych mapowania ich na stałe ( źródło )
Nie udało mi się jednak wyjaśnić kilku pytań dotyczących tych różnych metod:
- Jeśli chodzi o pliki punktów kontrolnych, czy zapisują one tylko wyuczone wagi modelu? Czy pliki punktów kontrolnych mogą zostać załadowane do nowego programu i użyte do uruchomienia modelu, czy mogą po prostu służyć jako sposób na zapisanie wag w modelu w określonym czasie / etapie?
- W związku z tym
tf.train.write_graph()
, czy wagi / zmienne również są zapisane? - Jeśli chodzi o Bazel, czy może on zapisywać do / ładować z plików .pb tylko w celu ponownego przeszkolenia? Czy istnieje proste polecenie Bazel, które służy tylko do zrzucenia wykresu do pliku .pb?
- Jeśli chodzi o zamrażanie, czy zamrożony wykres można załadować za pomocą
tf.import_graph_def()
? - Wersja demonstracyjna systemu Android dla TensorFlow ładuje się w modelu Inception Google z pliku .pb. Gdybym chciał zastąpić własny plik .pb, jak bym to zrobił? Czy musiałbym zmienić kod / metody natywne?
- Jaka jest właściwie różnica między wszystkimi tymi metodami? Albo szerzej, jaka jest różnica między
as_graph_def()
/.ckpt/.pb?
Krótko mówiąc, szukam metody zapisywania zarówno wykresu (jak w przypadku różnych operacji itp.), Jak i jego wag / zmiennych do pliku, którego można następnie użyć do załadowania wykresu i wag do innego programu do użytku (niekoniecznie kontynuowanie / przekwalifikowanie).
Dokumentacja na ten temat nie jest prosta, więc wszelkie odpowiedzi / informacje byłyby bardzo mile widziane.
python
tensorflow
protocol-buffers
Technicolor
źródło
źródło
Odpowiedzi:
Istnieje wiele sposobów podejścia do problemu zapisywania modelu w TensorFlow, co może być nieco zagmatwane. Biorąc po kolei każde z pytań podrzędnych:
Pliki punktów kontrolnych (np produkowane przez wywołanie
saver.save()
natf.train.Saver
obiekcie) zawierają tylko ciężary oraz wszelkie inne zmienne zdefiniowane w tym samym programie. Aby użyć ich w innym programie, należy odtworzyć powiązaną strukturę wykresu (np. Uruchamiając kod w celu jego ponownego zbudowania lub wywołująctf.import_graph_def()
), która powie TensorFlow, co ma zrobić z tymi wagami. Zwróć uwagę, że wywołaniesaver.save()
również tworzy plik zawierający aMetaGraphDef
, który zawiera wykres i szczegóły dotyczące powiązania wag z punktu kontrolnego z tym wykresem. Więcej informacji znajdziesz w samouczku .tf.train.write_graph()
zapisuje tylko strukturę grafu; nie ciężary.Bazel nie jest związany z czytaniem ani pisaniem wykresów TensorFlow. (Być może źle zrozumiałem twoje pytanie: możesz to wyjaśnić w komentarzu.)
Zamrożony wykres można załadować za pomocą
tf.import_graph_def()
. W takim przypadku wagi są (zazwyczaj) osadzone na wykresie, więc nie ma potrzeby ładowania osobnego punktu kontrolnego.Główną zmianą byłoby zaktualizowanie nazw tensorów, które są wprowadzane do modelu, oraz nazw tensorów, które są pobierane z modelu. W wersji demonstracyjnej TensorFlow dla systemu Android odpowiadałoby to ciągom znaków
inputName
i,outputName
które są przekazywane doTensorFlowClassifier.initializeTensorFlow()
.Jest
GraphDef
to struktura programu, która zazwyczaj nie zmienia się w trakcie procesu szkolenia. Punkt kontrolny to migawka stanu procesu szkolenia, który zwykle zmienia się na każdym etapie procesu szkolenia. W rezultacie TensorFlow używa różnych formatów przechowywania tych typów danych, a niskopoziomowy interfejs API zapewnia różne sposoby ich zapisywania i ładowania. Biblioteki wyższego poziomu, takie jakMetaGraphDef
biblioteki, Keras i skflow, opierają się na tych mechanizmach, aby zapewnić wygodniejsze sposoby zapisywania i przywracania całego modelu.źródło
tf.train.write_graph()
a następnie go wykonać?GraphDef
zapisanych przeztf.train.write_graph()
, musisz również zapamiętać nazwy tensorów, które chcesz zasilić i pobrać podczas wykonywania wykresu (pozycja 5 powyżej).Możesz wypróbować następujący kod:
źródło