Próbuję zrozumieć, dlaczego wyniki regresji logistycznej tych dwóch bibliotek dają różne wyniki.
Używam zestawu danych z UCLA Idre poradnik , przewidywania admit
na podstawie gre
, gpa
i rank
. rank
jest traktowany jako zmienna kategorialna, dlatego najpierw jest konwertowany na zmienną fikcyjną rank_1
. Dodano także kolumnę przechwytującą.
df = pd.read_csv("https://stats.idre.ucla.edu/stat/data/binary.csv")
y, X = dmatrices('admit ~ gre + gpa + C(rank)', df, return_type = 'dataframe')
X.head()
> Intercept C(rank)[T.2] C(rank)[T.3] C(rank)[T.4] gre gpa
0 1 0 1 0 380 3.61
1 1 0 1 0 660 3.67
2 1 0 0 0 800 4.00
3 1 0 0 1 640 3.19
4 1 0 0 1 520 2.93
# Output from scikit-learn
model = LogisticRegression(fit_intercept = False)
mdl = model.fit(X, y)
model.coef_
> array([[-1.35417783, -0.71628751, -1.26038726, -1.49762706, 0.00169198,
0.13992661]])
# corresponding to predictors [Intercept, rank_2, rank_3, rank_4, gre, gpa]
# Output from statsmodels
logit = sm.Logit(y, X)
logit.fit().params
> Optimization terminated successfully.
Current function value: 0.573147
Iterations 6
Intercept -3.989979
C(rank)[T.2] -0.675443
C(rank)[T.3] -1.340204
C(rank)[T.4] -1.551464
gre 0.002264
gpa 0.804038
dtype: float64
Dane wyjściowe statsmodels
są takie same, jak pokazano na stronie idre, ale nie jestem pewien, dlaczego scikit-learn produkuje inny zestaw współczynników. Czy minimalizuje to jakąś inną funkcję utraty? Czy jest jakaś dokumentacja określająca wdrożenie?
źródło
glmnet
pakietu w języku R, ale nie mogłem uzyskać tego samego współczynnika. glmnet ma nieco inną funkcję kosztów w porównaniu do sklearn , ale nawet jeśli ustawićalpha=0
wglmnet
(czyli używać tylko L2-karny) i określić1/(N*lambda)=C
, nadal nie uzyskać ten sam wynik?glmnet
przezlambda
i ustawię nową stałą pod względem czcionki logarytmu prawdopodobieństwa, która jest1/(N*lambda)
równa temu wsklearn
, dwie funkcje kosztu staną się identyczne, czy coś mi brakuje?penalty='none'
.Kolejna różnica polega na tym, że ustawiłeś fit_intercept = False, który faktycznie jest innym modelem. Widać, że Statsmodel obejmuje przechwytywanie. Brak przechwycenia z pewnością zmienia oczekiwane wagi funkcji. Spróbuj wykonać następujące czynności i zobaczyć, jak to się porównuje:
model = LogisticRegression(C=1e9)
źródło