Maximum Likelihood Estimation with TMB

Tokyo.R 115

伊東宏樹

2024-10-19

About Me

Kanazawa.R #2

TMB

Data

Model
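
The model can be read off the poismix.cpp template that follows. As a sketch in that template's notation (the subscript g(i), meaning the group of observation i, and the symbol sigma = exp(log_sigma) are notational assumptions, not taken from the slide):

```latex
\begin{aligned}
Y_i &\sim \mathrm{Poisson}(\lambda_i), \\
\log \lambda_i &= \alpha + \beta X_i + \varepsilon_{g(i)}, \\
\varepsilon_j &\sim \mathrm{Normal}(0, \sigma^2), \qquad \sigma = \exp(\mathtt{log\_sigma}).
\end{aligned}
```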

poismix.cpp
#include <TMB.hpp>

template<class Type>
Type objective_function<Type>::operator() ()
{
  DATA_VECTOR(Y);            // Response variable (counts)
  DATA_VECTOR(X);            // Explanatory variable
  DATA_IVECTOR(G);           // Group index (0-based)
  PARAMETER(alpha);          // Intercept
  PARAMETER(beta);           // Slope
  PARAMETER_VECTOR(epsilon); // Random effect (one per group)
  PARAMETER(log_sigma);      // Log SD of random effect
  Type nll = 0;              // Negative log-likelihood

  // Random effects: epsilon ~ Normal(0, exp(log_sigma))
  nll += -sum(dnorm(epsilon, Type(0.0), exp(log_sigma), true));
  // Observations: Y(i) ~ Poisson(lambda) with a log link
  for (int i = 0; i < Y.size(); i++) {
    Type lambda = exp(alpha + beta * X(i) + epsilon(G(i)));
    nll += -dpois(Y(i), lambda, true);
  }
  return nll;
}
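
To make explicit what the template returns, here is a minimal plain C++ sketch of the same joint negative log-likelihood, written without TMB or automatic differentiation. The helper names log_dnorm, log_dpois, and joint_nll are hypothetical and exist only for this illustration. TMB's own dnorm and dpois play the corresponding roles in the template.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Log density of Normal(mean, sd) at x (hypothetical helper).
double log_dnorm(double x, double mean, double sd) {
  const double log_sqrt_2pi = 0.5 * std::log(2.0 * std::acos(-1.0));
  const double z = (x - mean) / sd;
  return -log_sqrt_2pi - std::log(sd) - 0.5 * z * z;
}

// Log probability mass of Poisson(lambda) at count y (hypothetical helper).
double log_dpois(double y, double lambda) {
  return y * std::log(lambda) - lambda - std::lgamma(y + 1.0);
}

// Joint negative log-likelihood of the counts and the random effects,
// mirroring the structure of the poismix.cpp template.
double joint_nll(const std::vector<double>& Y,
                 const std::vector<double>& X,
                 const std::vector<int>& G,  // 0-based group index
                 double alpha, double beta,
                 const std::vector<double>& epsilon,
                 double log_sigma) {
  const double sigma = std::exp(log_sigma);
  double nll = 0.0;
  // Random-effect part: epsilon[j] ~ Normal(0, sigma)
  for (double e : epsilon) {
    nll -= log_dnorm(e, 0.0, sigma);
  }
  // Observation part: Y[i] ~ Poisson(lambda_i) with a log link
  for (std::size_t i = 0; i < Y.size(); ++i) {
    const double lambda = std::exp(alpha + beta * X[i] + epsilon[G[i]]);
    nll -= log_dpois(Y[i], lambda);
  }
  return nll;
}
```

This function still depends on epsilon. In the TMB workflow the random effects are integrated out before the fixed parameters are optimized.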

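Running the Model

Compile (compile) → load (dyn.load) → build the objective function and its gradient by automatic differentiation (MakeADFun, with the random effects named in the random argument) → optimize (nlminb)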

library(TMB)

model_name <- "poismix"
# Compile the C++ template and load the resulting shared library
file.path("models", paste(model_name, "cpp", sep = ".")) |>
  compile()
file.path("models", dynlib(model_name)) |>
  dyn.load()
# Shift group indices to 0-based for C++
data <- list(Y = Y, X = X, G = group - 1)
parameters <- list(alpha = 0, beta = 1,
                   epsilon = rep(0, N_group), log_sigma = 0)
# Parameters named in `random` are integrated out of the likelihood
obj <- MakeADFun(data, parameters, DLL = model_name,
                 random = "epsilon")
opt <- nlminb(obj$par, obj$fn, obj$gr)
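
TMB integrates out parameters listed in the random argument using the Laplace approximation. As a rough illustration of that idea, the sketch below handles a single group of this Poisson model in plain C++. The names LaplaceResult, group_joint_nll, and laplace_group_nll are hypothetical, and the fixed-effect linear predictors eta[i] = alpha + beta * X[i] are assumed to be precomputed. The derivatives are written out by hand here, whereas TMB obtains them by automatic differentiation and treats all random effects jointly.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

struct LaplaceResult {
  double eps_hat;  // mode of the joint nll with respect to the random effect
  double nll;      // Laplace approximation to the marginal nll
};

// Joint nll of one group's counts y and its random effect eps,
// where eps ~ Normal(0, sigma) and y[i] ~ Poisson(exp(eta[i] + eps)).
double group_joint_nll(const std::vector<double>& y,
                       const std::vector<double>& eta,
                       double eps, double sigma) {
  const double pi = std::acos(-1.0);
  double f = 0.5 * std::log(2.0 * pi) + std::log(sigma) +
             0.5 * eps * eps / (sigma * sigma);
  for (std::size_t i = 0; i < y.size(); ++i) {
    const double log_lambda = eta[i] + eps;
    f += std::exp(log_lambda) - y[i] * log_lambda + std::lgamma(y[i] + 1.0);
  }
  return f;
}

// Laplace approximation of -log of the integral of exp(-f(eps)) over eps.
LaplaceResult laplace_group_nll(const std::vector<double>& y,
                                const std::vector<double>& eta,
                                double sigma) {
  const double pi = std::acos(-1.0);
  const double prec = 1.0 / (sigma * sigma);
  double eps = 0.0;
  // Inner optimization: Newton iterations on the convex function f(eps)
  for (int iter = 0; iter < 100; ++iter) {
    double grad = eps * prec;
    double hess = prec;
    for (std::size_t i = 0; i < y.size(); ++i) {
      const double lambda = std::exp(eta[i] + eps);
      grad += lambda - y[i];
      hess += lambda;
    }
    const double step = grad / hess;
    eps -= step;
    if (std::fabs(step) < 1e-12) break;
  }
  // Second derivative evaluated at the mode
  double hess = prec;
  for (std::size_t i = 0; i < y.size(); ++i) {
    hess += std::exp(eta[i] + eps);
  }
  const double f_hat = group_joint_nll(y, eta, eps, sigma);
  return {eps, f_hat - 0.5 * std::log(2.0 * pi) + 0.5 * std::log(hess)};
}
```

Because the groups are independent given the fixed parameters, summing such per-group terms gives an approximate marginal negative log-likelihood in alpha, beta, and log_sigma. That is the kind of function nlminb then minimizes through obj$fn.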

Results

print(opt)
$par
     alpha       beta  log_sigma 
-1.9741825  0.5644704 -0.1077009 

$objective
[1] 182.9479

$convergence
[1] 0

$iterations
[1] 14

$evaluations
function gradient 
      19       15 

$message
[1] "relative convergence (4)"

glmmTMB

  • Uses TMB internally
  • Models can be written in GLMM formula syntax (no C++ knowledge required)
library(glmmTMB)
fit <- glmmTMB(Y ~ X + (1|group), data = df, family = poisson())

Results

summary(fit)
 Family: poisson  ( log )
Formula:          Y ~ X + (1 | group)
Data: df

     AIC      BIC   logLik deviance df.resid 
   371.9    381.1   -182.9    365.9      157 

Random effects:

Conditional model:
 Groups Name        Variance Std.Dev.
 group  (Intercept) 0.8062   0.8979  
Number of obs: 160, groups:  group, 8

Conditional model:
            Estimate Std. Error z value Pr(>|z|)    
(Intercept) -1.97418    0.40252  -4.905 9.36e-07 ***
X            0.56447    0.06469   8.726  < 2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
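
These estimates agree with the direct TMB fit shown earlier. The correspondence can be checked by hand from the two printed outputs, taking the three estimated parameters to be alpha, beta, and log_sigma:

```latex
\begin{aligned}
\hat\sigma &= \exp(-0.1077009) \approx 0.8979, \qquad \hat\sigma^2 \approx 0.8062, \\
\mathrm{logLik} &= -182.9479 \approx -182.9, \\
\mathrm{AIC} &= 2 \times 182.9479 + 2 \times 3 \approx 371.9, \\
\mathrm{BIC} &= 2 \times 182.9479 + 3 \log(160) \approx 381.1.
\end{aligned}
```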

tmbstan

  • Takes the TMB object and fits the model by MCMC with Stan
  • Can avoid getting trapped in a local optimum
library(tmbstan)
stanfit <- tmbstan(obj)

Results

print(stanfit, pars = c("alpha", "beta", "log_sigma"))
Inference for Stan model: poismix.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
alpha     -1.98    0.02 0.46 -2.92 -2.28 -1.96 -1.67 -1.12   888    1
beta       0.57    0.00 0.06  0.44  0.52  0.57  0.61  0.70  1989    1
log_sigma  0.03    0.01 0.32 -0.53 -0.19  0.01  0.23  0.72  1718    1

Samples were drawn using NUTS(diag_e) at Fri Oct 18 13:11:48 2024.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).