Ten artykuł na temat Adaboost zawiera pewne sugestie i kod (strona 17) dotyczący rozszerzenia modeli 2-klasowych na problemy klasy K. Chciałbym uogólnić ten kod, tak że mogę łatwo podłączyć różne modele 2-klasowe i porównać wyniki. Ponieważ większość modeli klasyfikacji ma interfejs formuły i predict
metodę, niektóre z nich powinny być stosunkowo łatwe. Niestety nie znalazłem standardowego sposobu wyodrębnienia prawdopodobieństw klasowych z modeli 2-klasowych, więc każdy model będzie wymagał niestandardowego kodu.
Oto funkcja, którą napisałem, aby podzielić problem klasy K na problemy klasy 2 i zwrócić modele K:
oneVsAll <- function(X,Y,FUN,...) {
models <- lapply(unique(Y), function(x) {
name <- as.character(x)
.Target <- factor(ifelse(Y==name,name,'other'), levels=c(name, 'other'))
dat <- data.frame(.Target, X)
model <- FUN(.Target~., data=dat, ...)
return(model)
})
names(models) <- unique(Y)
info <- list(X=X, Y=Y, classes=unique(Y))
out <- list(models=models, info=info)
class(out) <- 'oneVsAll'
return(out)
}
Oto metoda przewidywania, którą napisałem w celu iteracji każdego modelu i wykonania prognoz:
predict.oneVsAll <- function(object, newX=object$info$X, ...) {
stopifnot(class(object)=='oneVsAll')
lapply(object$models, function(x) {
predict(x, newX, ...)
})
}
I na koniec, oto funkcja normalizacji data.frame
przewidywanych prawdopodobieństw i sklasyfikowania przypadków. Zwróć uwagę, że od Ciebie zależy zbudowanie kolumny K data.frame
prawdopodobieństw z każdego modelu, ponieważ nie ma jednolitego sposobu wyodrębnienia prawdopodobieństw klasowych z modelu 2-klasowego:
classify <- function(dat) {
out <- dat/rowSums(dat)
out$Class <- apply(dat, 1, function(x) names(dat)[which.max(x)])
out
}
Oto przykład z użyciem adaboost
:
library(ada)
library(caret)
X <- iris[,-5]
Y <- iris[,5]
myModels <- oneVsAll(X, Y, ada)
preds <- predict(myModels, X, type='probs')
preds <- data.frame(lapply(preds, function(x) x[,2])) #Make a data.frame of probs
preds <- classify(preds)
>confusionMatrix(preds$Class, Y)
Confusion Matrix and Statistics
Reference
Prediction setosa versicolor virginica
setosa 50 0 0
versicolor 0 47 2
virginica 0 3 48
Oto przykład użycia lda
(wiem, że lda może obsłużyć wiele klas, ale to tylko przykład):
library(MASS)
myModels <- oneVsAll(X, Y, lda)
preds <- predict(myModels, X)
preds <- data.frame(lapply(preds, function(x) x[[2]][,1])) #Make a data.frame of probs
preds <- classify(preds)
>confusionMatrix(preds$Class, Y)
Confusion Matrix and Statistics
Reference
Prediction setosa versicolor virginica
setosa 50 0 0
versicolor 0 39 5
virginica 0 11 45
Funkcje te powinny działać dla każdego modelu 2-klasowego z interfejsem formuły i predict
metodą. Zauważ, że musisz ręcznie podzielić komponenty X i Y, co jest trochę brzydkie, ale pisanie interfejsu formuły jest obecnie poza mną.
Czy takie podejście ma sens dla wszystkich? Czy jest jakiś sposób, aby go ulepszyć, czy też istnieje istniejący pakiet do rozwiązania tego problemu?
car
lub jeden z*lab
pakietów) zapewniłby taką funkcję. Przepraszam, nie mogę pomóc. Przeczytałem trochę o tym, jak działa k-way SVM i wydaje się, że było to bardziej skomplikowane, niż myślałem.predict
metodę.Odpowiedzi:
Jednym ze sposobów poprawy jest zastosowanie podejścia „ważone wszystkie pary”, które podobno jest lepsze niż „jeden przeciw wszystkim”, a jednocześnie jest skalowalne.
Jeśli chodzi o istniejące pakiety,
glmnet
obsługuje (regularyzowany) wielomianowy logit, który może być używany jako klasyfikator wielu klas.źródło
glmnet
jest takżemultinomial
funkcja utraty. Zastanawiam się, czy tę funkcję straty można by zastosować w innych algorytmach w języku R, takich jakada
lubgbm
?ada
jest „zarezerwowany” dla konkretnej (wykładniczej) funkcji utraty, ale można by rozszerzyć inną poprawę oparta na metodzie do obsługi funkcji wielomianowej straty - np. patrz strona 360 Elementów uczenia statystycznego, aby uzyskać szczegółowe informacje na temat wielomklasowego GBM - drzewa binarne K są budowane dla każdej iteracji przyspieszającej, gdzie K jest liczbą klas (tylko jedno drzewo na iterację jest potrzebny w przypadku binarnym).