RNN: kiedy stosować BPTT i / lub aktualizować wagi?

16

Próbuję zrozumieć ogólne zastosowanie RNN do znakowania sekwencji za pomocą (między innymi) artykułu Gravesa z 2005 r. Na temat klasyfikacji fonemów.

Podsumowując problem: Mamy duży zestaw szkoleniowy składający się z (wejściowych) plików audio z pojedynczych zdań i (wyjściowych) opatrzonych znakiem eksperckim czasów rozpoczęcia, czasów zatrzymania i etykiet dla poszczególnych fonemów (w tym kilku „specjalnych” fonemów, takich jak cisza, tak, że każda próbka w każdym pliku audio jest oznaczona jakimś symbolem fonemu).

Celem pracy jest zastosowanie RNN z komórkami pamięci LSTM w ukrytej warstwie do tego problemu. (Dla porównania stosuje kilka wariantów i kilka innych technik. W tej chwili interesuje mnie TYLKO jednokierunkowy LSTM, dla uproszczenia.)

Wydaje mi się, że rozumiem architekturę sieci: warstwa wejściowa odpowiadająca 10 ms oknom plików audio, wstępnie przetworzona w sposób standardowy dla dźwięku; ukryta warstwa komórek LSTM i warstwa wyjściowa z jednym kodowaniem wszystkich możliwych 61 symboli telefonu.

Wierzę, że rozumiem (zawiłe, ale proste) równania przejścia do przodu i do tyłu przez jednostki LSTM. Są tylko rachunkiem i regułą łańcucha.

Po kilkakrotnym przeczytaniu tego artykułu i kilku podobnych nie rozumiem, kiedy dokładnie zastosować algorytm propagacji wstecznej i kiedy dokładnie zaktualizować różne wagi w neuronach.

Istnieją dwie wiarygodne metody:

1) Ramka korekcji i aktualizacja

Load a sentence.  
Divide into frames/timesteps.  
For each frame:
- Apply forward step
- Determine error function
- Apply backpropagation to this frame's error
- Update weights accordingly
At end of sentence, reset memory
load another sentence and continue.

lub,

2) Modyfikacja zwrotna i aktualizacja:

Load a sentence.  
Divide into frames/timesteps.  
For each frame:
- Apply forward step
- Determine error function
At end of sentence:
- Apply backprop to average of sentence error function
- Update weights accordingly
- Reset memory
Load another sentence and continue.

Zauważ, że jest to ogólne pytanie na temat treningu RNN z wykorzystaniem dokumentu Gravesa jako spiczastego (i osobiście istotnego) przykładu: Czy podczas szkolenia RNN na sekwencjach stosuje się backprop na każdym etapie? Czy wagi są dostosowywane za każdym razem? Czy też, w luźnej analogii do treningu wsadowego na architekturach ściśle sprzężonych, czy błędy są kumulowane i uśredniane dla określonej sekwencji przed zastosowaniem aktualizacji backprop i aktualizacji wagi?

A może jestem jeszcze bardziej zdezorientowany niż myślę?

Novak
źródło

Odpowiedzi:

26

Zakładam, że mówimy o powtarzających się sieciach neuronowych (RNN), które wytwarzają dane wyjściowe na każdym etapie (jeśli dane wyjściowe są dostępne tylko na końcu sekwencji, sensowne jest, aby uruchomić backprop na końcu). RNN w tym ustawieniu są często trenowane przy użyciu skróconej propagacji wstecznej w czasie (BPTT), działającej sekwencyjnie na „fragmentach” sekwencji. Procedura wygląda następująco:

  1. Przekazywanie w przód: przechodzenie przez kolejne kroki czasowe , obliczanie stanów wejściowych, ukrytych i wyjściowych.k1
  2. Oblicz stratę zsumowaną z poprzednich kroków czasowych (patrz poniżej).
  3. Przebieg wsteczny: oblicz gradient gradientu wrt wszystkich parametrów, kumulując się w poprzednich czasu (wymaga to zapisania wszystkich aktywacji dla tych kroków czasu). Przycinaj gradienty, aby uniknąć problemu z eksplodującym gradientem (zdarza się rzadko).k2)
  4. Zaktualizuj parametry (dzieje się to raz na porcję, nie przyrostowo za każdym razem).
  5. W przypadku przetwarzania wielu fragmentów o dłuższej sekwencji, zapisz stan ukryty na ostatnim etapie czasu (będzie on użyty do zainicjowania stanu ukrytego na początek następnego fragmentu). Jeśli doszliśmy do końca sekwencji, zresetuj pamięć / stan ukryty i przejdź na początek następnej sekwencji (lub początek tej samej sekwencji, jeśli jest tylko jedna).
  6. Powtórz od kroku 1.

Sposób zsumowania straty zależy od i . Na przykład, gdy , strata jest sumowana w ciągu ostatnich kroków czasowych, ale procedura jest inna, gdy (patrz Williams i Peng 1990).k1k2)k1=k2)k1=k2)k2)>k1

Obliczenia gradientowe i aktualizacje są wykonywane co kroków czasu, ponieważ jest to obliczeniowo tańsze niż aktualizacja za każdym razem. Aktualizowanie wiele razy na sekwencję (tj. Ustawienie mniejszej niż długość sekwencji) może przyspieszyć trening, ponieważ aktualizacje wagi są częstsze.k1k1

jest wykonywana tylko dla kroków czasowych ponieważ jest obliczeniowo tańsza niż propagacja z powrotem na początek sekwencji (co wymagałoby przechowywania i wielokrotnego przetwarzania wszystkich kroków czasowych). Gradienty obliczone w ten sposób są przybliżeniem „prawdziwego” gradientu obliczonego dla wszystkich kroków czasowych. Ale z powodu znikającego problemu z gradientem gradienty będą miały tendencję do zbliżania się do zera po pewnej liczbie kroków czasowych; rozmnażanie poza tym limitem nie przyniosłoby żadnych korzyści. Ustawienie zbyt krótkiego może ograniczyć skalę czasową, w której sieć może się uczyć. Pamięć sieci nie jest jednak ograniczona do czasowych, ponieważ jednostki ukryte mogą przechowywać informacje poza tym okresem (npk2)k2)k2)).

Poza względami obliczeniowymi, odpowiednie ustawienia dla i zależą od statystyki danych (np. Skala czasowa struktur, które są istotne dla uzyskania dobrych wyników). Prawdopodobnie zależą również od szczegółów sieci. Na przykład istnieje wiele architektur, sztuczek inicjalizacyjnych itp. Zaprojektowanych w celu złagodzenia problemu rozpadającego się gradientu.k1k2)

Twoja opcja 1 („cofnij ramkę”) odpowiada ustawieniu na a liczby kroków czasowych od początku zdania do bieżącego punktu. Opcja 2 („zdanie zwrotne”) odpowiada ustawieniu zarówno długości jak i . Oba są poprawnymi podejściami (z uwzględnieniem obliczeń / wydajności, jak powyżej; # 1 byłby dość intensywny obliczeniowo w przypadku dłuższych sekwencji). Żadne z tych podejść nie byłoby nazywane „obciętym”, ponieważ propagacja wsteczna występuje w całej sekwencji. Możliwe są inne ustawienia i ; Poniżej wymienię kilka przykładów.k11k2)k1k2)k1k2)

Odnośniki opisujące skrócony BPTT (procedura, motywacja, kwestie praktyczne):

  • Sutskever (2013) . Szkolenie nawracających sieci neuronowych.
  • Mikolov (2012) . Statystyczne modele językowe oparte na sieciach neuronowych.
    • Używając waniliowych numerów RNN do przetwarzania danych tekstowych jako sekwencji słów, zaleca ustawienie na 10-20 słów i na 5 słówk1k2)
    • Wykonywanie wielu aktualizacji na sekwencję (tj. mniej niż długość sekwencji) działa lepiej niż aktualizacja na końcu sekwencjik1
    • Przeprowadzanie aktualizacji raz na porcję jest lepsze niż przyrostowe (co może być niestabilne)
  • Williams i Peng (1990) . Wydajny algorytm gradientowy do szkolenia on-line powtarzających się trajektorii sieciowych.
    • Oryginalna (?) Propozycja algorytmu
    • Dyskutują Wybór i (które nazywają i ). Rozważają tylko .k1k2)hhk2)k1
    • Uwaga: używają wyrażenia „BPTT (h; h”) ”lub„ ulepszonego algorytmu ”, aby odnieść się do tego, co inne odniesienia nazywają„ obciętym BPTT ”. Używają wyrażenia „skrócony BPTT”, co oznacza szczególny przypadek, w którym .k1=1

Inne przykłady z użyciem obciętego BPTT:

  • (Karpathy 2015). char-rnn.
    • Opis i kod
    • Waniliowe przetwarzanie tekstu RNN dokumentuje po jednym znaku na raz. Przeszkolony do przewidywania następnej postaci. znaków. Sieć służyła do generowania nowego tekstu w stylu dokumentu szkoleniowego, z zabawnymi wynikami.k1=k2)=25
  • Graves (2014) . Generowanie sekwencji za pomocą rekurencyjnych sieci neuronowych.
    • Zobacz rozdział dotyczący generowania symulowanych artykułów z Wikipedii. Sieć LSTM przetwarzająca dane tekstowe jako sekwencję bajtów. Przeszkolony do przewidywania następnego bajtu. bajtów. Reset pamięci LSTM co bajtów.k1=k2)=10010,000
  • Sak i in. (2014) . Rekurencyjne architektury sieci neuronowych oparte na pamięci krótkoterminowej do rozpoznawania mowy dużego słownictwa.
    • Zmodyfikowane sieci LSTM, przetwarzanie sekwencji cech akustycznych. .k1=k2)=20
  • Ollivier i in. (2015) . Szkolenie powtarzających się sieci online bez cofania się.
    • Celem tego artykułu było zaproponowanie innego algorytmu uczenia się, ale porównano go ze skróconym BPTT. Użyto waniliowych numerów RNN do przewidywania sekwencji symboli. Wspominając o tym tutaj, aby powiedzieć, że użyli .k1=k2)=15
  • Hochreiter i Schmidhuber (1997) . Długotrwała pamięć krótkotrwała.
    • Opisują zmodyfikowaną procedurę dla LSTM
user20160
źródło
To wybitna odpowiedź i chciałbym mieć taką pozycję na tym forum, aby przyznać jej sporą nagrodę. Szczególnie przydatne są konkretne omówienie k1 vs k2 w celu kontekstualizacji moich dwóch przypadków w kontekście bardziej ogólnego użycia i ich liczbowych przykładów.
Novak