Mam następujący CNN:
- Zaczynam od obrazu wejściowego o rozmiarze 5x5
- Następnie stosuję splot za pomocą jądra 2x2 i stride = 1, który tworzy mapę cech o rozmiarze 4x4.
- Następnie stosuję maksymalne łączenie 2x2 z krokiem = 2, co zmniejsza mapę obiektów do rozmiaru 2x2.
- Następnie stosuję sigmoid logistyczny.
- Następnie jedna w pełni połączona warstwa z 2 neuronami.
- I warstwa wyjściowa.
Dla uproszczenia załóżmy, że wykonałem już przejście do przodu i obliczyłem δH1 = 0,25 i δH2 = -0,15
Tak więc po zakończeniu pełnego przejścia do przodu i częściowo ukończonego przejścia do tyłu moja sieć wygląda następująco:
Następnie obliczam delty dla warstwy nieliniowej (sigmoid logistyczny):
Następnie propaguję delty na warstwę 4x4 i ustawiam wszystkie wartości, które zostały odfiltrowane przez maksymalne pule na 0, a mapa gradientu wygląda następująco:
Jak mogę stamtąd zaktualizować wagi jądra? A jeśli moja sieć miała inną warstwę splotową przed 5x5, jakich wartości powinienem użyć, aby zaktualizować jej wagi jądra? I ogólnie, czy moje obliczenia są prawidłowe?
machine-learning
convnet
backpropagation
cnn
kernel
koryakinp
źródło
źródło
Odpowiedzi:
Splot wykorzystuje zasadę podziału masy, która znacznie skomplikuje matematykę, ale spróbujmy przejść przez chwasty. Większość moich wyjaśnień czerpię z tego źródła .
Przekaż do przodu
Jak zauważyłeś, przejście do przodu warstwy splotowej można wyrazić jako
Propagacja wsteczna
Zakładając, że używasz średniego błędu kwadratu (MSE) zdefiniowanego jako
chcemy ustalić
To iteruje całą przestrzeń wyjściową, określa błąd, który przyczynia się do produkcji, a następnie określa współczynnik udziału ciężaru jądra w odniesieniu do tej produkcji.
Nazwijmy przyczynę błędu błędem przestrzeni wyjściowej dla uproszczenia i śledzenia błędu propagowanego wstecz,
Wkład odważników
Splot określa się jako
a zatem,
Wróćmy do naszego terminu błędu
Spadek gradientu stochastycznego
Obliczmy niektóre z nich
Daj mi znać, jeśli są błędy w pochodnej.
Aktualizacja: poprawiony kod
źródło
gradient = signal.convolve2d(np.rot90(np.rot90(d)), o, 'valid')