W jaki sposób prosty model regresji logistycznej osiąga 92% dokładność klasyfikacji na MNIST?

68

Mimo że wszystkie obrazy w zestawie danych MNIST są wyśrodkowane, z podobną skalą i odkryte bez rotacji, mają znaczącą odmianę pisma ręcznego, która zastanawia mnie, w jaki sposób model liniowy osiąga tak wysoką dokładność klasyfikacji.

O ile jestem w stanie sobie wyobrazić, biorąc pod uwagę znaczną różnorodność pisma ręcznego, cyfry powinny być liniowo nierozdzielne w przestrzeni 784 wymiarów, tj. Powinna istnieć mała złożona (choć niezbyt złożona) nieliniowa granica oddzielająca różne cyfry , podobnie jak w dobrze cytowanym przykładzie którym klas dodatnich i ujemnych nie można oddzielić żadnym liniowym klasyfikatorem. Wydaje mi się zaskakujące, że regresja logistyczna wielu klas zapewnia tak wysoką dokładność przy całkowicie liniowych cechach (bez cech wielomianowych).XOR

Na przykład, biorąc pod uwagę dowolny piksel na obrazie, różne odręczne odmiany cyfr i mogą spowodować, że piksel ten zostanie podświetlony lub nie. Dlatego przy zestawie wyuczonych wag każdy piksel może sprawić, że cyfra będzie wyglądać zarówno jako jak i . Tylko z kombinacją wartości pikseli powinno być możliwe stwierdzenie, czy cyfra jest liczbą czy . Dotyczy to większości par cyfr. Jak więc regresja logistyczna, która ślepo opiera swoją decyzję niezależnie na wszystkich wartościach pikseli (bez uwzględnienia jakichkolwiek zależności między pikselami), jest w stanie osiągnąć tak wysokie dokładności.232323

Wiem, że gdzieś się mylę lub po prostu przeceniam zmienność obrazów. Byłoby jednak wspaniale, gdyby ktoś mógł mi pomóc w intuicji, w jaki sposób cyfry można „prawie” rozdzielić liniowo.

Nitish Agarwal
źródło
Zajrzyj do podręcznika Statystyka statystyczna ze sparsity: Lasso i uogólnienia 3.3.1 Przykład: Cyfry odręczne web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian
Byłem ciekawy: jak dobrze radzi sobie z problemem coś takiego jak karany model liniowy (tj. Glmnet)? O ile pamiętam, zgłaszasz nieokreśloną dokładność poza próbą.
Cliff AB

Odpowiedzi:

86

tl; dr Mimo że jest to zestaw danych klasyfikacji obrazów, pozostaje on bardzo łatwym zadaniem, dla którego można łatwo znaleźć bezpośrednie odwzorowanie danych wejściowych na przewidywania.


Odpowiedź:

To bardzo interesujące pytanie, a dzięki prostocie regresji logistycznej faktycznie można znaleźć odpowiedź.

Regresja logistyczna polega na tym, że dla każdego obrazu można zaakceptować dane wejściowe i pomnożyć je przez wagi, aby wygenerować prognozę. Interesujące jest to, że ze względu na bezpośrednie mapowanie między danymi wejściowymi i wyjściowymi (tj. Brak ukrytej warstwy) wartość każdej wagi odpowiada temu, ile każdego z danych wejściowych jest branych pod uwagę przy obliczaniu prawdopodobieństwa każdej klasy. Teraz, biorąc wagi dla każdej klasy i przekształcając je w (tj. Rozdzielczość obrazu), możemy stwierdzić, które piksele są najważniejsze dla obliczeń każdej klasy .78478428×28

Zauważ ponownie, że są to ciężary .

Teraz spójrz na powyższy obraz i skup się na pierwszych dwóch cyfrach (tj. Zero i jedna). Niebieskie wagi oznaczają, że intensywność tego piksela ma duży udział w tej klasie, a czerwone wartości oznaczają, że ma negatywny wpływ.

Teraz wyobraź sobie, jak osoba rysuje ? Rysuje między nimi okrągły kształt, który jest pusty. To właśnie nabierały ciężary. W rzeczywistości, jeśli ktoś narysuje środek obrazu, liczy się on ujemnie jako zero. Aby rozpoznać zera, nie potrzebujesz skomplikowanych filtrów i funkcji wysokiego poziomu. Możesz po prostu spojrzeć na narysowane lokalizacje pikseli i ocenić według tego.0

To samo dotyczy . Zawsze ma prostą pionową linię na środku obrazu. Wszystko inne liczy się negatywnie.1

Pozostałe cyfry są nieco bardziej skomplikowane, ale przy niewielkiej wyobraźni widać , , i . Reszta liczb jest nieco trudniejsza, co faktycznie ogranicza regresję logistyczną przed osiągnięciem lat 90-tych.2378

Dzięki temu widać, że regresja logistyczna ma bardzo duże szanse na uzyskanie dużej liczby zdjęć i dlatego osiąga tak wysokie wyniki.


Kod do odtworzenia powyższego rysunku jest nieco przestarzały, ale proszę bardzo:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)
Djib2011
źródło
12
Dzięki za ilustrację. Te obrazy wag pokazują, jak duża jest dokładność. Mnożenie kropki ręcznie napisanego obrazu cyfrowego z obrazem wagi odpowiadającym prawdziwej etykiecie obrazu „wydaje się” być najwyższe w porównaniu do produktu kropkowego z innymi etykietami wagi dla większości (nadal dla mnie 92% wygląda bardzo podobnie) zdjęć w MNIST. Mimo to trochę zaskakujące jest to, że i lub i rzadko są błędnie klasyfikowane jako siebie przy badaniu matrycy pomieszania. Tak czy inaczej, to jest to. Dane nigdy nie kłamią. :)2378
Nitish Agarwal
13
Oczywiście pomaga to, że próbki MNIST są wyśrodkowane, skalowane i znormalizowane kontrastowo, zanim klasyfikator je zobaczy. Nie musisz odpowiadać na pytania typu „co jeśli krawędź zera faktycznie przechodzi przez środek pola?” ponieważ procesor wstępny przeszedł już długą drogę, aby wszystkie zera wyglądały tak samo.
hobbs
1
@EricDuminil Dodałem wyróżnienie do skryptu z twoją sugestią. Wielkie dzięki za wkład! : D
Djib2011
1
@NitishAgarwal, jeśli uważasz, że ta odpowiedź jest odpowiedzią na twoje pytanie, rozważ oznaczenie go jako takiej.
sintax
11
Dla kogoś, kto jest zainteresowany, ale niezbyt zaznajomiony z tego rodzaju przetwarzaniem, ta odpowiedź stanowi fantastyczny, intuicyjny przykład mechaniki.
Chrylis -on strike-