Scikit przewiduje interpretację wyników wyjściowych

12

Pracuję z biblioteką scikit-learn w Pythonie. W poniższym kodzie przewiduję prawdopodobieństwo, ale nie wiem, jak odczytać wynik.

Testowanie danych

from sklearn.ensemble import RandomForestClassifier as RF
from sklearn import cross_validation

X = np.array([[5,5,5,5],[10,10,10,10],[1,1,1,1],[6,6,6,6],[13,13,13,13],[2,2,2,2]])
y = np.array([0,1,1,0,1,2])

Podziel zestaw danych

X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size=0.5, random_state=0) 

Oblicz prawdopodobieństwo

clf = RF()
clf.fit(X_train,y_train)
pred_pro = clf.predict_proba(X_test)
print pred_pro

Wyjście

[[ 1.  0.]
 [ 1.  0.]
 [ 0.  1.]]

Lista X_test zawiera 3 tablice (mam 6 próbek i rozmiar_testu = 0,5), więc dane wyjściowe również mają 3.

Ale przewiduję 3 wartości (0,1,2), więc dlaczego otrzymuję tylko 2 elementy w każdej tablicy?

Jak mam odczytać wynik?

Zauważyłem również, że kiedy modyfikuję liczbę różnych wartości w y, liczba kolumn w danych wyjściowych jest zawsze różną liczbą y -1.

HonzaB
źródło
Witamy w CrossValidated. Widziałeś moją odpowiedź poniżej? Jeśli to rozwiązało twoje pytanie, zaznacz je jako poprawną odpowiedź. W przeciwnym razie daj mi znać, czego brakuje, a ja postaram się to dla ciebie wyjaśnić.
Ben

Odpowiedzi:

5

Spójrz na y_train. Jest array([0, 0, 1]). Oznacza to, że twój podział nie wziął próbki, gdzie y = 2. Twój model nie ma pojęcia, że ​​istnieje klasa y = 2.

Potrzebujesz więcej próbek, aby zwrócić coś znaczącego.

Sprawdź także dokumentację, aby zrozumieć, jak interpretować dane wyjściowe.

Ben
źródło
1
To jest poprawne. Jeśli ustawisz y = np.array([0,2,1,0,1,2])i random_state=2zobaczysz teraz 3 kolumny wyników
tdc
Odpowiedź rozwiązała moje pytanie. Dziękuję Ci bardzo. A w jakiej kolejności są kolumny? Zawsze rośnie?
HonzaB,
Uruchom clf.classes_. Kolumny będą w tej kolejności.
Ben
Po prostu tak: clf.fit(X_train,y_train).classes_?
HonzaB
1
Myślę, że to zadziała, ale możesz po prostu biegać clf.classes_ po bieguclf.fit(X_train,y_train)
Ben