ベイズ統計モデリングとMCMC

Kanazawa.R #2

伊東宏樹

2024-11-23

内容

  • ベイズ統計モデリング

  • MCMC(マルコフ連鎖モンテカルロ)法

    • NIMBLEを使った例

    • Stanを使った例

統計モデルとは

  • 変数間の関係を何らかの確率分布を使って(パーツとして)記述して作成したモデル(模型)
  • 作成したモデルでシステムを説明したり、予測したりする

ベイズの定理

\[ P(Y \mid X) = \frac{P(X \mid Y)P(Y)}{P(X)} \]

\[ = \frac{P(X \mid Y)P(Y)}{P(X \mid Y)P(Y)+P(X \mid \overline{Y})P(\overline{Y})} \]

例題

1000人に1人がかかる病気があるとする。

検査をすると、この病気にかかっている場合には99%の確率で陽性となる。ただし、かかっていなくても5%の確率で誤って陽性になる。

ある人が検査を受けて陽性になった。このとき実際にこの人がこの病気にかかっている確率は何パーセントか。

\[ \frac{0.99 \times 0.001}{0.99 \times 0.001 + 0.05 \times 0.999} = 0.01943463 \]

実際にこの病気にかかっている確率はおよそ2%

ベイズ推定

  • 事前確率を、得られたデータで更新していく

  • 確率分布を推定するとき

    • 事前分布→データで更新→事後分布

MCMC(マルコフ連鎖モンテカルロ)法とは

  • 統計モデルのパラメータを推定する手法

  • MCMC = MC(マルコフ連鎖) + MC(モンテカルロ)

マルコフ連鎖

1期前の状態にのみ依存する確率変数列

例: ランダムウオーク

モンテカルロ法

乱数を使った推定法

例: 円周率を求める

MCMC

  • ベイズ統計モデルで、複雑な統計モデルのパラメータ推定に使われる

    • 解析的に解けない複雑なモデルのパラメータも推定できる
  • 乱数を使って、一定のアルゴリズム(Metropolis-Hastings法, Gibbsサンプリング, Hamiltonian Monte Carlo法など)により、事後分布からサンプリングしたと見なせるマルコフ連鎖を生成する

  • 短所: 計算に時間がかかる

MCMCのソフトウェア

など

いずれもRとは別のモデル記述言語で、モデルを記述する

NIMBLEを使った統計モデリングの例

データ

群ごとに切片が異なるが、群内では傾きはだいたい2くらいでどれも同程度

群を無視すると

lm(Y ~ X, data = df) |> summary()

Call:
lm(formula = Y ~ X, data = df)

Residuals:
    Min      1Q  Median      3Q     Max 
-3.3097 -0.8735  0.0257  0.8367  3.2818 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)   4.2243     0.4748   8.897 1.26e-15 ***
X             1.0163     0.1186   8.571 8.80e-15 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 1.245 on 158 degrees of freedom
Multiple R-squared:  0.3174,    Adjusted R-squared:  0.3131 
F-statistic: 73.47 on 1 and 158 DF,  p-value: 8.795e-15

傾きを過小評価してしまった

混合効果モデル

固定効果+変量効果

  • 固定効果: 説明変数による目的変数への効果

  • 変量効果(ランダム効果): 群の違いによる効果

    • 通常、変量効果は正規分布にしたがうとする

    • ベイズ統計モデリングでは、階層事前分布を設定→階層ベイズモデル

NIMBLEモデル

BUGS言語で統計モデルを記述

code <- nimbleCode({
  for (n in 1:N) {
    mu[n] <- alpha + beta * X[n] + epsilon[Gind[n]]
    Y[n] ~ dnorm(mu[n], tau[1])
  }
  for (g in 1:G) {
    epsilon[g] ~ dnorm(0, tau[2])
  }
  alpha ~ dnorm(0, 1e-4)
  beta ~ dnorm(0, 1e-4)
  for (i in 1:2) {
    tau[i] <- 1 / (sigma[i] * sigma[i])
    sigma[i] ~ dunif(0, 100)
  }
})

コンパイル・実行

G <- length(levels(df$Group))
out <- nimbleMCMC(code = code,
                  constants = list(N = nrow(df),
                                   G = G,
                                   Gind = as.numeric(df$Group)),
                  data = list(Y = df$Y, X = df$X),
                  inits = list(alpha = -2, beta = -2,
                               epsilon = rep(0, G),
                               sigma = c(4, 2)),
                  niter = 500, nburnin = 0,
                  samplesAsCodaMCMC = TRUE)

結果

betaのマルコフ連鎖の軌跡

burn-in

初期値の影響が残っている部分は捨てる

サンプリング

out <- nimbleMCMC(code = code,
                  constants = list(N = nrow(df), G = G,
                                   Gind = as.numeric(df$Group)),
                  data = list(Y = df$Y, X = df$X),
                  inits = function() {
                    list(alpha = runif(1, -2, 2),
                         beta = runif(1, -2, 2),
                    epsilon = runif(G, -2, 2),
                    sigma = runif(2, 0, 2))},
                  nchains = 3, niter = 12000, nburnin = 2000,
                  samplesAsCodaMCMC = TRUE)

traceplot (alpha)

マルコフ連鎖の軌跡プロット(codaパッケージのtraceplot関数を使用)

traceplot(out[, "alpha"])

traceplot (sigma[1])

traceplot(out[, "sigma[1]"])

これくらいよく混ざっているのが望ましい

R-hat

MCMC計算が収束したかどうかの指標値。1.1以下ならOKとする場合が多い。

gelman.diag(out)
Potential scale reduction factors:

         Point est. Upper C.I.
alpha          1.01       1.03
beta           1.01       1.01
sigma[1]       1.00       1.00
sigma[2]       1.01       1.02

Multivariate psrf

1.01

結果

結果の要約

summary(out)

Iterations = 1:10000
Thinning interval = 1 
Number of chains = 3 
Sample size per chain = 10000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

           Mean      SD  Naive SE Time-series SE
alpha    0.6708 0.77857 0.0044951      0.0597093
beta     1.9314 0.15523 0.0008962      0.0097383
sigma[1] 1.0107 0.05936 0.0003427      0.0007638
sigma[2] 1.4014 0.53740 0.0031027      0.0154120

2. Quantiles for each variable:

            2.5%    25%    50%   75% 97.5%
alpha    -0.9282 0.1844 0.6763 1.172 2.196
beta      1.6077 1.8340 1.9381 2.036 2.219
sigma[1]  0.9024 0.9694 1.0077 1.050 1.135
sigma[2]  0.7282 1.0377 1.2858 1.628 2.792

傾き(beta)の事後平均値は2に近い値に推定された

密度グラフ

densplot(out[, "beta"])

Stanを使った統計モデリング

RからStanを使う方法

  • rstanパッケージ
  • cmdstanrパッケージ

今回は前者を使用

Stanのモデル

Stanで記述した同等のモデル。各パラメータの事前分布は弱情報事前分布とした。

lme.stan
/*
  stan model for linear mixed effects model
*/

data {
  int<lower=0> N;
  int<lower=0> G;
  array[N] int<lower=1,upper=G> Gind;
  vector[N] X;
  vector[N] Y;
}

parameters {
  real alpha;
  real beta;
  vector[G] epsilon;
  vector<lower=0>[2] sigma;
}

transformed parameters {
   vector[N] mu;
   
   for (n in 1:N)
     mu[n] = alpha + beta * X[n] + epsilon[Gind[n]];
}

model {
  Y ~ normal(mu, sigma[1]);
  epsilon ~ normal(0, sigma[2]);
  alpha ~ normal(0, 10);
  beta ~ normal(0, 10);
  sigma ~ normal(0, 5);
}

実行・サンプリング

fit <- stan(file = file.path("model", "lme.stan"),
            data = list(N = nrow(df),
                        G = G,
                        Gind = as.numeric(df$Group),
                        X = df$X, Y = df$Y),
            pars = c("alpha", "beta", "sigma"),
            iter = 2000, warmup = 1000)

結果

各パラメータの事後分布の要約

                mean     se_mean         sd        2.5%         25%         50%
alpha      0.6412861 0.024448384 0.79665518  -1.0266203   0.1307284   0.6574317
beta       1.9381728 0.003752638 0.15381555   1.6384333   1.8347968   1.9395739
sigma[1]   1.0095801 0.001209584 0.05791051   0.9049794   0.9688155   1.0061584
sigma[2]   1.3836016 0.013909480 0.50315936   0.7382298   1.0453934   1.2905886
lp__     -86.5115631 0.071971256 2.58104006 -92.4959050 -87.9672471 -86.1859056
                75%      97.5%    n_eff     Rhat
alpha      1.192614   2.105521 1061.794 1.001746
beta       2.046297   2.238077 1680.069 1.003625
sigma[1]   1.046782   1.130690 2292.149 1.000252
sigma[2]   1.602286   2.563278 1308.547 1.002752
lp__     -84.662346 -82.458634 1286.090 1.002217

参考文献