Próbuję zrozumieć, co robi opakowanie TimeDistributed w Keras.
Dostaję, że TimeDistributed „stosuje warstwę do każdego czasowego wycinka danych wejściowych”.
Ale przeprowadziłem pewien eksperyment i otrzymałem wyniki, których nie mogę zrozumieć.
Krótko mówiąc, w połączeniu z warstwą LSTM, TimeDistributed i po prostu Dense dają takie same wyniki.
model = Sequential()
model.add(LSTM(5, input_shape = (10, 20), return_sequences = True))
model.add(TimeDistributed(Dense(1)))
print(model.output_shape)
model = Sequential()
model.add(LSTM(5, input_shape = (10, 20), return_sequences = True))
model.add((Dense(1)))
print(model.output_shape)
Dla obu modeli otrzymałem kształt wyjściowy (Brak, 10, 1) .
Czy ktoś może wyjaśnić różnicę między warstwą TimeDistributed i Dense po warstwie RNN?
python
machine-learning
keras
neural-network
deep-learning
Buomsoo Kim
źródło
źródło
Dense
warstwą spłaszczającą dane wejściowe, a następnie przekształcającą, a więc łączącą różne kroki czasowe i posiadającą więcej parametrów, orazTimeDistributed
utrzymującą przedziały czasowe oddzielone (stąd mniej parametrów). W twoim przypadkuDense
powinien był mieć 500 parametrów,TimeDistributed
tylko 50Odpowiedzi:
W
keras
- przy budowaniu modelu sekwencyjnego - zwykle drugi wymiar (jeden po wymiarze przykładowym) - jest powiązany ztime
wymiarem. Oznacza to, że jeśli na przykład dane są5-dim
z(sample, time, width, length, channel)
tobą, możesz zastosować warstwę splotową przy użyciuTimeDistributed
(co ma zastosowanie do4-dim
with(sample, width, length, channel)
) wzdłuż wymiaru czasu (nakładając tę samą warstwę na każdy wycinek czasu) w celu uzyskania5-d
wyniku.Przypadek
Dense
jest taki, żekeras
od wersji 2.0Dense
jest domyślnie stosowany tylko do ostatniego wymiaru (np. Jeśli zastosujeszDense(10)
dane wejściowe z kształtem(n, m, o, p)
, otrzymasz dane wyjściowe z kształtem(n, m, o, 10)
), więc w twoim przypadkuDense
iTimeDistributed(Dense)
są równoważne.źródło
Input
tensora, czy jest jakaś różnica w porównaniu z wykonaniemmap
modelu zastosowanego do listy zawierającej każdy wycinekInput
?