Czy istnieje możliwość zmiany metryki używanej przez wywołanie zwrotne Early Stopping w Keras?

13

Podczas korzystania z wywołania zwrotnego wczesnego zatrzymania w Keras trening zatrzymuje się, gdy niektóre wskaźniki (zwykle utrata sprawdzania poprawności) nie rosną. Czy istnieje sposób na użycie innej miary (takiej jak precyzja, odwołanie, miara f) zamiast utraty sprawdzania poprawności? Wszystkie przykłady, które do tej pory widziałem, są podobne do tego: callbacks.EarlyStopping (monitor = 'val_loss', cierpliwość = 5, verbose = 0, mode = 'auto')

P.Joseph
źródło

Odpowiedzi:

11

Możesz użyć dowolnej funkcji metrycznej określonej podczas kompilacji modelu.

Załóżmy, że masz następującą funkcję metryczną:

def my_metric(y_true, y_pred):
     return some_metric_computation(y_true, y_pred)

Jedynym wymaganiem dla tej funkcji jest przyjęcie prawdziwej yi przewidywanej y.

Podczas kompilowania modelu podajesz tę metrykę, podobnie jak określasz metryki wbudowane, takie jak „dokładność”:

model.compile(metrics=['accuracy', my_metric], ...)

Zauważ, że używamy nazwy funkcji my_metric bez „” (w przeciwieństwie do wbudowanej „dokładności”).

Następnie, jeśli zdefiniujesz EarlyStopping, po prostu użyj nazwy funkcji (tym razem z ''):

EarlyStopping(monitor='my_metric', mode='min')

Upewnij się, że określono tryb (min, jeśli niższa jest lepsza, maks., Jeśli wyższa jest lepsza).

Możesz go używać tak, jak z dowolnych wbudowanych danych. Prawdopodobnie działa to również z innymi wywołaniami zwrotnymi, takimi jak ModelCheckpoint (ale tego nie testowałem). Wewnętrznie Keras dodaje nową metrykę do listy metryk dostępnych dla tego modelu, używając nazwy funkcji.

Jeśli określisz dane do sprawdzania poprawności w swoim modelu.fit (...), możesz także użyć ich do EarlyStopping, używając „val_my_metric”.

Michał
źródło
3

Oczywiście, po prostu stwórz swój własny!

class EarlyStopByF1(keras.callbacks.Callback):
    def __init__(self, value = 0, verbose = 0):
        super(keras.callbacks.Callback, self).__init__()
        self.value = value
        self.verbose = verbose


    def on_epoch_end(self, epoch, logs={}):
         predict = np.asarray(self.model.predict(self.validation_data[0]))
         target = self.validation_data[1]
         score = f1_score(target, prediction)
         if score > self.value:
            if self.verbose >0:
                print("Epoch %05d: early stopping Threshold" % epoch)
            self.model.stop_training = True


callbacks = [EarlyStopByF1(value = .90, verbose =1)]
model.fit(X, y, batch_size = 32, nb_epoch=nb_epoch, verbose = 1, 
validation_data(X_val,y_val), callbacks=callbacks)

Nie przetestowałem tego, ale powinien to być ogólny smak tego, jak sobie z tym radzisz. Jeśli to nie zadziała, daj mi znać, a spróbuję ponownie w weekend. Zakładam również, że masz już zaimplementowany własny wynik F1. Jeśli nie tylko importuj do sklearn.

Cylinder
źródło
+1 Nadal działa od 02.11.2020 przy użyciu najnowszych Keras i Python 3.7
Austin