Zastosowanie stochastycznego wnioskowania wariacyjnego do Bayesian Mixture of Gaussian

9

Próbuję zaimplementować model mieszanki Gaussa z stochastycznym wnioskiem wariacyjnym, zgodnie z tym artykułem .

wprowadź opis zdjęcia tutaj

To jest pgm mieszanki Gaussa.

Według artykułu, pełny algorytm stochastycznego wnioskowania wariacyjnego to: wprowadź opis zdjęcia tutaj

I nadal jestem bardzo zdezorientowany co do metody skalowania go do GMM.

Po pierwsze, myślałem, że lokalny parametr wariacyjny jest po prostu qza inne są parametrami globalnymi. Popraw mnie, jeśli się myliłem. Co oznacza krok 6 as though Xi is replicated by N times? Co mam zrobić, aby to osiągnąć?

Czy możesz mi w tym pomóc? Z góry dziękuję!

użytkownik5779223
źródło
Mówi, że zamiast korzystać z całego zestawu danych, próbkuj jeden punkt danych i udawaj, że masz Npunkty danych tego samego rozmiaru. W wielu przypadkach będzie to równoznaczne z pomnożeniem oczekiwania przez jeden punkt danych przezN.
Daeyoung Lim
@DaeyoungLim Dziękujemy za odpowiedź! Rozumiem teraz, co masz na myśli, ale nadal nie rozumiem, które statystyki powinny być aktualizowane lokalnie, a które globalnie. Na przykład, oto implementacja mieszanki Gaussa, czy możesz mi powiedzieć, jak skalować do svi? Jestem trochę zagubiony. Wielkie dzięki!
user5779223,
Nie przeczytałem całego kodu, ale jeśli masz do czynienia z modelem mieszanki Gaussa, zmienne wskaźnikowe składnika mieszanki powinny być zmiennymi lokalnymi, ponieważ każda z nich jest powiązana tylko z jedną obserwacją. Tak więc ukryte zmienne składnika mieszaniny, które następują po rozkładzie Multinoulli (znanym również jako rozkład kategoryczny w ML), sązi,i=1,,Nw twoim opisie powyżej.
Daeyoung Lim,
@DaeyoungLim Tak, rozumiem, co powiedziałeś do tej pory. Zatem dla rozkładu wariacyjnego q (Z) q (\ pi, \ mu, \ lambda) q (Z) powinna być zmienną lokalną. Ale istnieje wiele parametrów związanych z q (Z). Z drugiej strony istnieje wiele parametrów związanych z q (\ pi, \ mu, \ lambda). I nie wiem, jak je odpowiednio zaktualizować.
user5779223,
Należy użyć założenia pola średniego, aby uzyskać optymalne rozkłady wariacyjne dla parametrów wariacyjnych. Oto referencja: maths.usyd.edu.au/u/jormerod/JTOpapers/Ormerod10.pdf
Daeyoung Lim

Odpowiedzi:

1

Po pierwsze, kilka notatek, które pomagają mi zrozumieć tekst SVI:

  • Przy obliczaniu wartości pośredniej parametru wariacyjnego parametrów globalnych próbkujemy jeden punkt danych i udajemy, że nasz cały zestaw danych wielkości N był ten jeden punkt, N czasy.
  • ηg jest naturalnym parametrem dla pełnego warunku zmiennej globalnej β. Notacja służy do podkreślenia, że ​​jest to funkcja zmiennych warunkowych, w tym obserwowanych danych.

W mieszaninie k Gaussianie, nasze parametry globalne to parametry średnie i precyzyjne (wariancja odwrotna) μk,τkparametry dla każdego. To jest,ηg jest naturalnym parametrem tego rozkładu, normalną gamma formy

μ,τN(μ|γ,τ(2α1)Ga(τ|α,β)

z η0=2α1, η1=γ(2α1) i η2=2β+γ2(2α1). (Bernardo i Smith, teoria bayesowska ; zwróć uwagę, że różni się ona nieco od czteroparametrowej normalnej gamma, którą zwykle widzisz .) Użyjemya,b,m odnosić się do parametrów wariacyjnych dla α,β,μ

Pełny warunek μk,τk jest normalną gamma z parametrami η˙+Nzn,k, Nzn,kxN, Nzn,kxn2, gdzie η˙jest przeorem. (Thezn,ktam może być również mylące; ma sens, zaczynając odexpln(p)) sztuczka zastosowana do Np(xn|zn,α,β,γ)=NK(p(xn|αk,βk,γk))zn,ki kończąc na sporej ilości algebry pozostawionej czytelnikowi).

Dzięki temu możemy wykonać krok (5) pseudokodu SVI za pomocą:

ϕn,kexp(ln(π)+Eqln(p(xn|αk,βk,γk))=exp(ln(π)+Eq[μkτk,τ2x,x2μ2τlnτ2)]

Aktualizacja parametrów globalnych jest łatwiejsza, ponieważ każdy parametr odpowiada liczbie danych lub jednej z jego wystarczających statystyk:

λ^=η˙+Nϕn1,x,x2

Oto, jak wygląda minimalne prawdopodobieństwo danych w wielu iteracjach, gdy są szkolone na bardzo sztucznych, łatwych do oddzielenia danych (kod poniżej). Pierwszy wykres pokazuje prawdopodobieństwo przy początkowych, losowych parametrach wariacyjnych i0iteracje; każde następne następuje po następnej potędze dwóch iteracji. W kodziea,b,m odnoszą się do parametrów wariacyjnych dla α,β,μ.

wprowadź opis zdjęcia tutaj

wprowadź opis zdjęcia tutaj

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 12 12:49:15 2018

@author: SeanEaster
"""

import numpy as np
from matplotlib import pylab as plt
from scipy.stats import t
from scipy.special import digamma 

# These are priors for mu, alpha and beta

def calc_rho(t, delay=16,forgetting=1.):
    return np.power(t + delay, -forgetting)

m_prior, alpha_prior, beta_prior = 0., 1., 1.
eta_0 = 2 * alpha_prior - 1
eta_1 = m_prior * (2 * alpha_prior - 1)
eta_2 = 2 *  beta_prior + np.power(m_prior, 2.) * (2 * alpha_prior - 1)

k = 3

eta_shape = (k,3)
eta_prior = np.ones(eta_shape)
eta_prior[:,0] = eta_0
eta_prior[:,1] = eta_1
eta_prior[:,2] = eta_2

np.random.seed(123) 
size = 1000
dummy_data = np.concatenate((
        np.random.normal(-1., scale=.25, size=size),
        np.random.normal(0.,  scale=.25,size=size),
        np.random.normal(1., scale=.25, size=size)
        ))
N = len(dummy_data)
S = 1

# randomly init global params
alpha = np.random.gamma(3., scale=1./3., size=k)
m = np.random.normal(scale=1, size=k)
beta = np.random.gamma(3., scale=1./3., size=k)

eta = np.zeros(eta_shape)
eta[:,0] = 2 * alpha - 1
eta[:,1] = m * eta[:,0]
eta[:,2] = 2. * beta + np.power(m, 2.) * eta[:,0]


phi = np.random.dirichlet(np.ones(k) / k, size = dummy_data.shape[0])

nrows, ncols = 4, 5
total_plots = nrows * ncols
total_iters = np.power(2, total_plots - 1)
iter_idx = 0

x = np.linspace(dummy_data.min(), dummy_data.max(), num=200)

while iter_idx < total_iters:

    if np.log2(iter_idx + 1) % 1 == 0:

        alpha = 0.5 * (eta[:,0] + 1)
        beta = 0.5 * (eta[:,2] - np.power(eta[:,1], 2.) / eta[:,0])
        m = eta[:,1] / eta[:,0]
        idx = int(np.log2(iter_idx + 1)) + 1

        f = plt.subplot(nrows, ncols, idx)
        s = np.zeros(x.shape)
        for _ in range(k):
            y = t.pdf(x, alpha[_], m[_], 2 * beta[_] / (2 * alpha[_] - 1))
            s += y
            plt.plot(x, y)
        plt.plot(x, s)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)

    # randomly sample data point, update parameters
    interm_eta = np.zeros(eta_shape)
    for _ in range(S):
        datum = np.random.choice(dummy_data, 1)

        # mean params for ease of calculating expectations
        alpha = 0.5 * ( eta[:,0] + 1)
        beta = 0.5 * (eta[:,2] - np.power(eta[:,1], 2) / eta[:,0])
        m = eta[:,1] / eta[:,0]

        exp_mu = m
        exp_tau = alpha / beta 
        exp_tau_m_sq = 1. / (2 * alpha - 1) + np.power(m, 2.) * alpha / beta
        exp_log_tau = digamma(alpha) - np.log(beta)


        like_term = datum * (exp_mu * exp_tau) - np.power(datum, 2.) * exp_tau / 2 \
            - (0.5 * exp_tau_m_sq - 0.5 * exp_log_tau)
        log_phi = np.log(1. / k) + like_term
        phi = np.exp(log_phi)
        phi = phi / phi.sum()

        interm_eta[:, 0] += phi
        interm_eta[:, 1] += phi * datum
        interm_eta[:, 2] += phi * np.power(datum, 2.)

    interm_eta = interm_eta * N / S
    interm_eta += eta_prior

    rho = calc_rho(iter_idx + 1)

    eta = (1 - rho) * eta + rho * interm_eta

    iter_idx += 1
Sean Easter
źródło