Regresja logistyczna: Scikit Learn vs Statsmodels

31

Próbuję zrozumieć, dlaczego wyniki regresji logistycznej tych dwóch bibliotek dają różne wyniki.

Używam zestawu danych z UCLA Idre poradnik , przewidywania admitna podstawie gre, gpai rank. rankjest traktowany jako zmienna kategorialna, dlatego najpierw jest konwertowany na zmienną fikcyjną rank_1. Dodano także kolumnę przechwytującą.

df = pd.read_csv("https://stats.idre.ucla.edu/stat/data/binary.csv")
y, X = dmatrices('admit ~ gre + gpa + C(rank)', df, return_type = 'dataframe')
X.head()
>  Intercept  C(rank)[T.2]  C(rank)[T.3]  C(rank)[T.4]  gre   gpa
0          1             0             1             0  380  3.61
1          1             0             1             0  660  3.67
2          1             0             0             0  800  4.00
3          1             0             0             1  640  3.19
4          1             0             0             1  520  2.93

# Output from scikit-learn
model = LogisticRegression(fit_intercept = False)
mdl = model.fit(X, y)
model.coef_
> array([[-1.35417783, -0.71628751, -1.26038726, -1.49762706,  0.00169198,
     0.13992661]]) 
# corresponding to predictors [Intercept, rank_2, rank_3, rank_4, gre, gpa]

# Output from statsmodels
logit = sm.Logit(y, X)
logit.fit().params
> Optimization terminated successfully.
     Current function value: 0.573147
     Iterations 6
Intercept      -3.989979
C(rank)[T.2]   -0.675443
C(rank)[T.3]   -1.340204
C(rank)[T.4]   -1.551464
gre             0.002264
gpa             0.804038
dtype: float64

Dane wyjściowe statsmodelssą takie same, jak pokazano na stronie idre, ale nie jestem pewien, dlaczego scikit-learn produkuje inny zestaw współczynników. Czy minimalizuje to jakąś inną funkcję utraty? Czy jest jakaś dokumentacja określająca wdrożenie?

hurrikale
źródło

Odpowiedzi:

28

Waszą wskazówką do rozwiązania tego problemu powinno być to, że oszacowania parametrów z oszacowania scikit-learn są jednakowo mniejsze pod względem wielkości niż odpowiednik statsmodels. Może to prowadzić do przekonania, że ​​scikit-learn stosuje pewną regularyzację parametrów. Możesz to potwierdzić, czytając dokumentację scikit-learn .

Nie ma sposobu, aby wyłączyć regularyzację w scikit-learn, ale możesz to uczynić nieskutecznym, ustawiając parametr strojenia Cna dużą liczbę. Oto jak to działa w twoim przypadku:

# module imports
from patsy import dmatrices
import pandas as pd
from sklearn.linear_model import LogisticRegression
import statsmodels.discrete.discrete_model as sm

# read in the data & create matrices
df = pd.read_csv("http://www.ats.ucla.edu/stat/data/binary.csv")
y, X = dmatrices('admit ~ gre + gpa + C(rank)', df, return_type = 'dataframe')

# sklearn output
model = LogisticRegression(fit_intercept = False, C = 1e9)
mdl = model.fit(X, y)
model.coef_

# sm
logit = sm.Logit(y, X)
logit.fit().params
tchakravarty
źródło
Dziękuję bardzo za wyjaśnienie! Przy tym uregulowanym wyniku próbowałem zduplikować wynik przy użyciu glmnetpakietu w języku R, ale nie mogłem uzyskać tego samego współczynnika. glmnet ma nieco inną funkcję kosztów w porównaniu do sklearn , ale nawet jeśli ustawić alpha=0w glmnet(czyli używać tylko L2-karny) i określić 1/(N*lambda)=C, nadal nie uzyskać ten sam wynik?
hurrikale
Moją intuicją jest to, że jeśli podzielę oba warunki funkcji kosztu glmnetprzez lambdai ustawię nową stałą pod względem czcionki logarytmu prawdopodobieństwa, która jest 1/(N*lambda)równa temu w sklearn, dwie funkcje kosztu staną się identyczne, czy coś mi brakuje?
hurrikale
@hurrikale Zadaj nowe pytanie i umieść je tutaj, a ja się obejrzę.
tchakravarty
Dzięki! Zadałem pytanie tutaj .
hurrikale
Myślę, że najlepszym sposobem na wyłączenie regulacji w scikit-learn jest ustawienie penalty='none'.
Nzbuu
3

Kolejna różnica polega na tym, że ustawiłeś fit_intercept = False, który faktycznie jest innym modelem. Widać, że Statsmodel obejmuje przechwytywanie. Brak przechwycenia z pewnością zmienia oczekiwane wagi funkcji. Spróbuj wykonać następujące czynności i zobaczyć, jak to się porównuje:

model = LogisticRegression(C=1e9)

Brian Dalessandro
źródło