Jak wyświetlić listę wszystkich używanych operacji w Tensorflow SavedModel?

10

Jeśli tensorflow.saved_model.savezapiszę mój model za pomocą funkcji w formacie SavedModel, jak mogę pobrać, które operacje Tensorflow zostaną użyte w tym modelu później. Ponieważ model można przywrócić, operacje te są przechowywane na wykresie, domyślam się, że w saved_model.pbpliku. Jeśli załaduję ten protobuf (a więc nie cały model), część protobufa w bibliotece zawiera je, ale nie jest to na razie udokumentowane i oznaczone jako funkcja eksperymentalna. Modele utworzone w Tensorflow 1.x nie będą miały tej części.

Jaki jest zatem szybki i niezawodny sposób na uzyskanie listy używanych operacji (takich jak MatchingFileslub WriteFile) z modelu w formacie SavedModel?

W tej chwili mogę zamrozić wszystko, podobnie jak tensorflowjs-converterrobi. Sprawdzają również obsługiwane operacje. To obecnie nie działa, gdy LSTM jest w modelu, patrz tutaj . Czy istnieje lepszy sposób, aby to zrobić, ponieważ operatorzy zdecydowanie tam są?

Przykładowy model:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Oczekiwano w wyniku wszystkich operacji, zawierających w tym przypadku co najmniej:

  • ReadFilejak opisano tutaj
  • ...
próbników
źródło
1
Trudno powiedzieć dokładnie, czego chcesz, co to saved_model.pbjest tf.GraphDef, czy SavedModelwiadomość typu protobuf? Jeśli masz tf.GraphDefsprawdzone gd, możesz uzyskać listę używanych operacji z sorted(set(n.op for n in gd.node)). Jeśli masz załadowany model, możesz to zrobić sorted(set(op.type for op in tf.get_default_graph().get_operations())). Jeśli jest to SavedModel, możesz tf.GraphDefz niego uzyskać (np saved_model.meta_graphs[0].graph_def.).
jdehesa
Chcę pobrać operacje z zapisanego modelu SavedModel. Tak więc ostatnia opcja, którą opisujesz. Jaka jest saved_modelzmienna w twoim ostatnim przykładzie? Wynik tf.saved_model.load('/path/to/model')lub ładowanie protobufa pliku save_model.pb.
sampers

Odpowiedzi:

1

Jeśli saved_model.pbjest to SavedModelkomunikat Protobuf, operacje są pobierane bezpośrednio z niego. Załóżmy, że tworzymy model w następujący sposób:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Możemy teraz znaleźć operacje używane przez ten model w następujący sposób:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin
Jdehesa
źródło
Próbowałem czegoś takiego, ale niestety nie spełnia moich oczekiwań: Powiedzmy, że mam model, który to robi: input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')W takim razie plik ReadFile Op wymieniony tutaj znajduje się, ale nie jest drukowany.
sampers
1
@sampers Zedytowałem odpowiedź z przykładem, który sugerujesz. Dostaję ReadFileoperację na wyjściu. Czy to możliwe, że w twoim przypadku ta operacja nie znajduje się pomiędzy wejściem a wyjściem zapisanego modelu? W takim przypadku myślę, że może zostać przycięty.
jdehesa
Rzeczywiście z danym modelem działa. Niestety dla modułu wykonanego w TF2 tak nie jest. Jeśli utworzę moduł tf.Module z 1 funkcją z adnotacją file_nameargumentu @tf.function, zawierającą wywołania wymienione w poprzednim komentarzu, daje następującą listę:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
sampers
dodał model do mojego pytania
sampers
@sampers Zaktualizowałem swoją odpowiedź. Wcześniej korzystałem z TF 1.x, nie znałem zmian w obiektach definicji wykresów w TF 2.x, myślę, że odpowiedź obejmuje teraz wszystko w zapisanym modelu. Myślę, że operacje odpowiadające napisanej funkcji Pythona są w saved_model.meta_graphs[0].graph_def.library.function[0]( node_defkolekcja w tym obiekcie funkcji).
jdehesa