Średni ROC dla powtarzanej 10-krotnej walidacji krzyżowej z oszacowaniami prawdopodobieństwa

15

Planuję użyć powtarzanej (10 razy) stratyfikacji 10-krotnej krzyżowej walidacji na około 10 000 przypadków przy użyciu algorytmu uczenia maszynowego. Za każdym razem powtórzenie zostanie wykonane z innym losowym ziarnem.

W tym procesie tworzę 10 przypadków oszacowań prawdopodobieństwa dla każdego przypadku. 1 przypadek oszacowania prawdopodobieństwa dla każdego z 10 powtórzeń 10-krotnej walidacji krzyżowej

Czy mogę uśrednić 10 prawdopodobieństw dla każdego przypadku, a następnie utworzyć nową średnią krzywą ROC (reprezentującą wyniki powtarzanego 10-krotnego CV), którą można porównać z innymi krzywymi ROC za pomocą porównań w parach?

użytkownik97953
źródło

Odpowiedzi:

13

Z twojego opisu wydaje się to mieć sens: nie tylko możesz obliczyć średnią krzywą ROC, ale także wariancję wokół niej, aby zbudować przedziały ufności. Powinien dać ci wyobrażenie o tym, jak stabilny jest twój model.

Na przykład:

wprowadź opis zdjęcia tutaj

Tutaj umieszczam poszczególne krzywe ROC, a także krzywą średnią i przedziały ufności. Są obszary, w których krzywe się zgadzają, więc mamy mniejszą wariancję, i są obszary, w których się nie zgadzają.

W przypadku powtarzanego CV możesz po prostu powtórzyć go wiele razy i uzyskać całkowitą średnią dla wszystkich poszczególnych fałd:

wprowadź opis zdjęcia tutaj

Jest dość podobny do poprzedniego obrazu, ale daje bardziej stabilne (tj. Wiarygodne) oszacowania średniej i wariancji.

Oto kod, aby uzyskać fabułę:

import matplotlib.pyplot as plt
import numpy as np
from scipy import interp

from sklearn.datasets import make_classification
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

X, y = make_classification(n_samples=500, random_state=100, flip_y=0.3)

kf = KFold(n=len(y), n_folds=10)

tprs = []
base_fpr = np.linspace(0, 1, 101)

plt.figure(figsize=(5, 5))

for i, (train, test) in enumerate(kf):
    model = LogisticRegression().fit(X[train], y[train])
    y_score = model.predict_proba(X[test])
    fpr, tpr, _ = roc_curve(y[test], y_score[:, 1])

    plt.plot(fpr, tpr, 'b', alpha=0.15)
    tpr = interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tprs.append(tpr)

tprs = np.array(tprs)
mean_tprs = tprs.mean(axis=0)
std = tprs.std(axis=0)

tprs_upper = np.minimum(mean_tprs + std, 1)
tprs_lower = mean_tprs - std


plt.plot(base_fpr, mean_tprs, 'b')
plt.fill_between(base_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3)

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.axes().set_aspect('equal', 'datalim')
plt.show()

W przypadku powtarzającego się CV:

idx = np.arange(0, len(y))

for j in np.random.randint(0, high=10000, size=10):
    np.random.shuffle(idx)
    kf = KFold(n=len(y), n_folds=10, random_state=j)

    for i, (train, test) in enumerate(kf):
        model = LogisticRegression().fit(X[idx][train], y[idx][train])
        y_score = model.predict_proba(X[idx][test])
        fpr, tpr, _ = roc_curve(y[idx][test], y_score[:, 1])

        plt.plot(fpr, tpr, 'b', alpha=0.05)
        tpr = interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

Źródło inspiracji: http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html

Aleksiej Grigoriew
źródło
3

Średnie prawdopodobieństwo nie jest poprawne, ponieważ nie reprezentuje prognoz, które próbujesz zweryfikować, i wiąże się z zanieczyszczeniem próbek walidacyjnych.

Należy zauważyć, że może być wymagane 100 powtórzeń 10-krotnej walidacji krzyżowej w celu uzyskania odpowiedniej precyzji. Lub użyj bootstrapu optymizmu Efron-Gong, który wymaga mniejszej liczby iteracji dla tej samej precyzji (patrz np. Funkcje rmspakietu R validate).

Krzywe ROC nie są w żaden sposób wnikliwe dla tego problemu. Użyj właściwego wyniku dokładności i dołącz do niego indeks (prawdopodobieństwo zgodności; AUROC), z którym łatwiej jest sobie poradzić niż krzywa, ponieważ można go łatwo i szybko obliczyć przy użyciu statystyki Wilcoxona-Manna-Whitneya.c

Frank Harrell
źródło
Czy mógłbyś bardziej szczegółowo wyjaśnić, dlaczego uśrednianie jest nieprawidłowe?
DataD'oh,
Już stwierdzono. Musisz zweryfikować miarę, której będziesz używać w terenie.
Frank Harrell,