LoginSignup
5
4

More than 5 years have passed since last update.

生存時間分析をStanで実行してみた

Posted at

はじめに

確率分布の定義

  • 生存時間の確率分布$f_t(t=1,\cdots\infty)$
  • 死亡関数(時刻$t$までに死亡している確率) $F_t=\sum_{i=1}^tf_i$
  • 生存関数(時刻$t$まで生存している確率) $S_t=1-F_t=\sum_{i=t+1}^\infty f_i$
  • 瞬間死亡率(ハザード) $h_t=\frac{f_t}{S_{t-1}}\quad(0\le h_t\le1, S_0=1)$
\begin{align}
f_t &= S_{t-1} h_t = F_t-F_{t-1}= S_{t-1}-S_t\\
S_t &= S_{t-1} (1-h_t)\\
S_t &= \prod_{i=1}^t(1-h_i),\quad
f_t = h_t\prod_{i=1}^{t-1}(1-h_i)
\end{align}
  • 時刻$t$でイベントが発生する確率は$f_t$となる。
  • 時刻$t$で打ち切りの場合、イベントは時刻$t+1$以降に発生するので、確率は$S_t$となる。

説明変数なしモデル

Stanプログラム

transformed parametersブロックで瞬間死亡率$h_t$から生存関数$S_t$と生存時間の確率分布$f_t$を計算する。

survival_0.stan
data {
  int<lower=1> N;
  int<lower=1> T;
  int<lower=1, upper=T> Time[N];
  int<lower=0, upper=1> Cens[N];
}

parameters {
  real<lower=0> sigma_log_hazard;
  vector<upper=0>[T] log_hazard;
}

transformed parameters {
  vector[T] log_S;
  vector[T] log_f;
  log_S[1] = log1m_exp(log_hazard[1]);
  log_f[1] = log_hazard[1];
  for (t in 2:T) {
    log_S[t] = log_S[t-1] + log1m_exp(log_hazard[t]);
    log_f[t] = log_S[t-1] + log_hazard[t];
  }
}

model {
  sigma_log_hazard ~ cauchy(0, 1);
  log_hazard ~ cauchy(0, 10);
  for (t in 2:T) {
    log_hazard[t] ~ normal(log_hazard[t-1], sigma_log_hazard);
  }
  for (n in 1:N) {
    if (Cens[n] == 1) {
      target += log_f[Time[n]];
    }
    else {
      target += log_S[Time[n]];
    }
  }
}

generated quantities {
  vector[T] hazard;
  vector[T] S;
  hazard = exp(log_hazard);
  S = exp(log_S);
}

分析対象データ

パッケージMASSに含まれているgehanという生存時間データを分析する。

library(ggplot2)
library(scales)
library(ggfortify)
library(gridExtra)
library(rstan)
library(survival)
library(MASS)
library(muhaz)

rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

> head(gehan)
  pair time cens   treat
1    1    1    1 control
2    1   10    1    6-MP
3    2   22    1 control
4    2    7    1    6-MP
5    3    3    1 control
6    3   32    0    6-MP
> nrow(gehan)
[1] 42

Stanの実行

treatが6-MPとcontrolについて別々に実行する。

model_stan = stan_model("survival_0.stan")
tstart = proc.time()
treats = c("6-MP","control")
result = lapply(seq_along(treats), function(i){
  ge = gehan[gehan$treat == treats[i],]
  data_stan = list(N=nrow(ge), T=max(ge$time), Time=ge$time, Cens=ge$cens)
  stan_fit = sampling(model_stan, data=data_stan,
    chains=4, iter=4000, warmup=2000, thin=1, seed=3,
    control=list(adapt_delta=0.8, max_treedepth=10)
  )
  stan_fit
})
proc.time() - tstart # 23sec

実行結果の抽出

生存関数と瞬間死亡率の95%信用区間を抽出する。

my_extract = function(x,name){
  tb = summary(x, probs=c(0.025,0.975),pars=name)$summary
  df = data.frame(tb[,c("mean","2.5%","97.5%"), drop=F])
  colnames(df) = c("mean","lower","upper")
  df
}

df_S =lapply(result,function(x){
  df = my_extract(x,"S")
  df$time = 1:nrow(df)
  df
})

df_h =lapply(result,function(x){
  df = my_extract(x,"hazard")
  df$time = 1:nrow(df)
  df
})

グラフの作成

生存関数と瞬間死亡率をsurvfit関数による結果と比較する。
グラフの色を合わせて表示したいのでggplot2が使用する色コードを取得。

# ggplot2のデフォルトの色を知りたい
# https://qiita.com/hoxo_b/items/c569da6dbf568032e04a
ggColorHue <- function(n, l=65) {
  hues <- seq(15, 375, length=n+1)
  hcl(h=hues, l=l, c=100)[1:n]
}
ggcols = ggColorHue(n=2)

kphaz = kphaz.fit(time=gehan$time, status=gehan$cens, strata=gehan$treat)
df_kphaz = lapply(1:2,function(i){
  f = (kphaz$strata==i)
  data.frame(time=kphaz$time[f],haz=kphaz$haz[f])
})

fit_sv = survfit(Surv(time, cens) ~ treat, data = gehan)
summary(fit_sv)

plot_S = autoplot(fit_sv) +
  geom_line(data=df_S[[1]],aes(x=time,y=mean),size=0.5,linetype=2,color=ggcols[1]) +
  geom_ribbon(data=df_S[[1]],aes(x=time,ymin=lower,ymax=upper),alpha=0.3,fill=ggcols[1],inherit.aes=F) +
  geom_line(data=df_S[[2]],aes(x=time,y=mean),size=0.5,linetype=2,color=ggcols[2]) +
  geom_ribbon(data=df_S[[2]],aes(x=time,ymin=lower,ymax=upper),alpha=0.3,fill=ggcols[2],inherit.aes=F)

plot_haz = ggplot() +
  geom_line(data=df_kphaz[[1]],aes(x=time,y=haz),size=0.5,linetype=1,color=ggcols[1]) +
  geom_line(data=df_kphaz[[2]],aes(x=time,y=haz),size=0.5,linetype=1,color=ggcols[2]) +
  geom_line(data=df_h[[1]],aes(x=time,y=mean),size=0.5,linetype=2,color=ggcols[1]) +
  geom_ribbon(data=df_h[[1]],aes(x=time,ymin=lower,ymax=upper),alpha=0.3,fill=ggcols[1]) +
  geom_line(data=df_h[[2]],aes(x=time,y=mean),size=0.5,linetype=2,color=ggcols[2]) +
  geom_ribbon(data=df_h[[2]],aes(x=time,ymin=lower,ymax=upper),alpha=0.3,fill=ggcols[2]) +
  scale_y_continuous(labels = scales::percent)

grid.arrange(plot_haz,plot_S,ncol=2,widths = c(0.8,1))

実線がsurvfit関数で推定した結果、破線がStanで推定した結果。
fit_sv_0.png

説明変数ありモデル

  • 対象とする説明変数は0または1の値を取るものとする。
  • Cox比例ハザードモデルを適用する。説明変数の値=1の効果を$\beta$とすると、$\beta$は時刻に依存せず、瞬間死亡率を$\beta$倍にすると仮定する。
  • データが長期間に渡る場合、全ての時刻で瞬間死亡率を推定すると計算時間がかかるため、イベント発生時刻・打ち切り時刻の間では、瞬間死亡率は一定と仮定する。
  • $S_l=S_k(1-h_l)^{l-k}, f_l=S_k(1-h_l)^{l-k-1}h_l$
  • ここで$h_l$は、時刻$k+1$から$l$までの瞬間死亡率とする。
  • 計算時間を短縮するため説明変数の値の組合せでユニークなものを予めグループとしてまとめておく。データには対応するグループの番号を指定する。

Stanプログラム

transformed parametersブロックで瞬間死亡率$h_t$からグループ毎の生存関数$S_t$と生存時間の確率分布$f_t$を計算する。
generated quantitiesブロックでベースラインを計算する。

survival_4.stan
data {
  int<lower=1> N;
  int<lower=1> T;
  int<lower=1> G;
  int<lower=1> K;
  int<lower=1, upper=T> TimeID[N];
  int<lower=0, upper=1> Cens[N];
  int<lower=0, upper=G> GroupID[N];
  int<lower=1> Time[T];
  matrix<lower=0, upper=1>[G, K] Group;
}

parameters {
  real<lower=0> sigma_log_hazard;
  vector<upper=0>[T] log_hazard;
  vector[K] beta;
}

transformed parameters {
  matrix[G, T] log_S;
  matrix[G, T] log_f;
  for (g in 1:G) {
    vector[T] new_hazard = log_hazard + dot_product(Group[g], beta);
    log_S[g, 1] = log1m_exp(new_hazard[1]) * (Time[1]);
    log_f[g, 1] = log1m_exp(new_hazard[1]) * (Time[1] - 1) + new_hazard[1];
    for (t in 2:T) {
      log_S[g, t] = log_S[g, t-1] + log1m_exp(new_hazard[t]) * (Time[t] - Time[t-1]);
      log_f[g, t] = log_S[g, t-1] + log1m_exp(new_hazard[t]) * (Time[t] - Time[t-1] - 1) + new_hazard[t];
    }
  }
}

model {
  sigma_log_hazard ~ cauchy(0, 1);
  log_hazard ~ cauchy(0, 10);
  beta ~ cauchy(0, 10);
  for (t in 2:T) {
    log_hazard[t] ~ normal(log_hazard[t-1], sigma_log_hazard);
  }
  for (n in 1:N) {
    if (Cens[n] == 1) {
      target += log_f[GroupID[n], TimeID[n]];
    }
    else {
      target += log_S[GroupID[n], TimeID[n]];
    }
  }
}

generated quantities{
  vector[T] hazard;
  vector[T] log_S0;
  vector[T] S0;
  vector[K] exp_beta;
  hazard = exp(log_hazard);
  log_S0[1] = log1m_exp(log_hazard[1]) * (Time[1]);
  for (t in 2:T) {
    log_S0[t] = log_S0[t-1] + log1m_exp(log_hazard[t]) * (Time[t] - Time[t-1]);
  }
  S0 = exp(log_S0);
  exp_beta = exp(beta);
}

分析対象データ

パッケージsurvivalに含まれているkidneyという生存時間データを分析する。
説明変数としてsexとdiseaseを使用する。

> head(kidney)
  id time status age sex disease frail
1  1    8      1  28   1   Other   2.3
2  1   16      1  28   1   Other   2.3
3  2   23      1  48   2      GN   1.9
4  2   13      0  48   2      GN   1.9
5  3   22      1  32   1   Other   1.2
6  3   28      1  32   1   Other   1.2
> nrow(kidney)
[1] 76

時刻情報とグループ情報の抽出

df_kidney = kidney
df_kidney$sex = c("M","F")[kidney$sex]
df_kidney$sex = as.factor(df_kidney$sex)

dum = cbind(
  model.matrix(~df_kidney$sex)[,-1,drop=F],
  model.matrix(~df_kidney$disease)[,-1]
)
colnames(dum) = c("sexM","diseaseGN","diseaseAN","diseasePKD")

times = sort(unique(df_kidney$time))
time_id = match(df_kidney$time,times)

group = unique(dum)
rownames(group) = 1:nrow(group)
group_char = apply(group,1,paste,collapse="")
group_id = match(apply(dum,1,paste,collapse=""),group_char)
> times
 [1]   2   4   5   6   7   8   9  12  13  15  16  17  22  23  24  25  26  27  28  30  34  38  39
[24]  40  43  46  53  54  58  63  66  70  78  96 108 113 114 119 130 132 141 149 152 154 156 159
[47] 177 185 190 196 201 245 292 318 333 402 447 511 536 562
> time_id
 [1]  6 11 14  9 13 19 57 54 20  8 15 52  5  7 58 20 27 50 10 44  5 55 41  6 34 22 42 32 59 16 12
[32]  2 48 47 53 37 13 46 10 35 43 60 56 15  9 31 23 26  8 24 36 51 40 45 21 20  1 16 39 17 18 29
[63]  3 25 43 20 49  3 38  6 28 11  4 33 30  6
> group
  sexM diseaseGN diseaseAN diseasePKD
1    1         0         0          0
2    0         1         0          0
3    0         0         0          0
4    1         1         0          0
5    0         0         1          0
6    1         0         1          0
7    1         0         0          1
8    0         0         0          1
> group_id
 [1] 1 1 2 2 1 1 3 3 1 1 3 3 4 4 2 2 5 5 4 4 5 5 3 3 5 5 5 5 3 3 6 6 3 3 3 3 2 2 3 3 7 7 3 3 5 5 5
[48] 5 6 6 5 5 2 2 5 5 4 4 2 2 5 5 5 5 8 8 2 2 3 3 3 3 8 8 7 7

Stanの実行

model_stan = stan_model("survival_4.stan")

data_stan = list(
  N = nrow(df_kidney),
  T = length(times),
  G = nrow(group),
  K = ncol(group),
  TimeID = time_id,
  Cens = df_kidney$status,
  GroupID = group_id,
  Time = times,
  Group = group
)

tstart = proc.time()
stan_fit = sampling(model_stan,data=data_stan,
  chains=4, iter=4000, warmup=2000, thin=1, seed=3,
  control=list(adapt_delta=0.8,max_treedepth=10)
)
proc.time() - tstart # 122sec

実行結果の抽出

生存関数と瞬間死亡率、説明変数の効果$\beta$の95%信用区間を抽出する。
瞬間死亡率の時刻情報を推定した区間の中央に設定する。

my_extract = function(x,name){
  tb = summary(x, probs=c(0.025,0.975),pars=name)$summary
  df = data.frame(tb[,c("mean","2.5%","97.5%"), drop=F])
  colnames(df) = c("mean","lower","upper")
  df
}

df_S0 = my_extract(stan_fit,"S0")
df_S0$time = times

df_h = my_extract(stan_fit,"hazard")
df_h$time = (function(){
  n = length(times)
  (times + c(0,times[-n]) + 1)/2
})()

df_exp_beta = my_extract(stan_fit,"exp_beta")
rownames(df_exp_beta) = c("sexM","diseaseGN","diseaseAN","diseasePKD")

beta_hist = extract(stan_fit,pars="exp_beta")$exp_beta
beta_hist = data.frame(beta_hist)
colnames(beta_hist) = c("sexM","diseaseGN","diseaseAN","diseasePKD")

グラフの作成

推定された説明変数の効果$\beta$の分布。

pvline = geom_vline(xintercept=1,color="red")
plot1 = ggplot(beta_hist,aes(x=sexM)) + geom_histogram(bins=30) + pvline
plot2 = ggplot(beta_hist,aes(x=diseaseGN)) + geom_histogram(bins=30) + pvline
plot3 = ggplot(beta_hist,aes(x=diseaseAN)) + geom_histogram(bins=30) + pvline
plot4 = ggplot(beta_hist,aes(x=diseasePKD)) + geom_histogram(bins=30) + pvline

grid.arrange(plot1,plot2,plot3,plot4,ncol=2,nrow=2)

fit_cox1.png

coxph関数による効果$\beta$の推定値との比較。

kidney.cox = coxph(Surv(time, status) ~ sex+disease, data=df_kidney)
summary(kidney.cox)
df_exp_beta
> kidney.cox = coxph(Surv(time, status) ~ sex+disease, data=df_kidney)
> summary(kidney.cox)
Call:
coxph(formula = Surv(time, status) ~ sex + disease, data = df_kidney)

  n= 76, number of events= 58 

              coef exp(coef) se(coef)      z Pr(>|z|)    
sexM        1.4774    4.3815   0.3569  4.140 3.48e-05 ***
diseaseGN   0.1392    1.1494   0.3635  0.383   0.7017    
diseaseAN   0.4132    1.5116   0.3360  1.230   0.2188    
diseasePKD -1.3671    0.2549   0.5889 -2.321   0.0203 *  
---
Signif. codes:  0 *** 0.001 ** 0.01 * 0.05 . 0.1   1

           exp(coef) exp(-coef) lower .95 upper .95
sexM          4.3815     0.2282   2.17689    8.8188
diseaseGN     1.1494     0.8700   0.56368    2.3437
diseaseAN     1.5116     0.6616   0.78245    2.9202
diseasePKD    0.2549     3.9238   0.08035    0.8084

Concordance= 0.696  (se = 0.045 )
Rsquare= 0.206   (max possible= 0.993 )
Likelihood ratio test= 17.56  on 4 df,   p=0.001501
Wald test            = 19.77  on 4 df,   p=0.0005533
Score (logrank) test = 19.97  on 4 df,   p=0.0005069

> df_exp_beta
                mean      lower     upper
M          4.3978555 2.06594992 8.0631262
diseaseGN  0.9693798 0.45369567 1.7495952
diseaseAN  1.3627789 0.68891947 2.4198428
diseasePKD 0.2758943 0.08059133 0.6748802
df_cox = data.frame(summary(kidney.cox)$conf.int)
colnames(df_cox) = c("mean","inv_mean","lower","upper")
df_cox$x = factor(1:4,levels=1:4,labels=c("sexM","diseaseGN","diseaseAN","diseasePKD"))
df_exp_beta$x = (1:4)+0.1

ggplot() + 
  geom_pointrange(data=df_cox,aes(x=x,y=mean,ymin=lower,ymax=upper),size=0.5) +
  geom_pointrange(data=df_exp_beta,aes(x=x,y=mean,ymin=lower,ymax=upper),size=0.5,linetype=2) +
  labs(x="",y="beta")

実線がcoxph関数で推定した結果、破線がStanで推定した結果。
fit_cox2.png

推定された生存関数と瞬間死亡率のグラフ。

plot_S = ggplot() +
  geom_line(data=df_S0,aes(x=time,y=mean),size=0.5,linetype=2) +
  geom_ribbon(data=df_S0,aes(x=time,ymin=lower,ymax=upper),alpha=0.3,inherit.aes=F) +
  labs(y="surv") +
  scale_y_continuous(labels = scales::percent)

plot_h = ggplot() +
  geom_line(data=df_h,aes(x=time,y=mean),size=0.5,linetype=2) +
  geom_ribbon(data=df_h,aes(x=time,ymin=lower,ymax=upper),alpha=0.3,inherit.aes=F) +
  labs(y="hazard") +
  scale_y_continuous(labels = scales::percent)

grid.arrange(plot_h,plot_S,ncol=2)

fit_cox0.png

coxph関数によるベースラインとの比較。

kidney.fit = survfit(kidney.cox, newdata = data.frame(sex="F", disease="Other"))

# rough estimation of hazard from survival function
df_kphaz = (function(){
  n = length(kidney.fit$surv)
  times = (kidney.fit$time[2:n] + kidney.fit$time[1:(n-1)])/2
  interval = (kidney.fit$time[2:n] - kidney.fit$time[1:(n-1)])
  ratio = kidney.fit$surv[2:n] / kidney.fit$surv[1:(n-1)]
  data.frame(
    time = times,
    hazard = (1 - ratio) / interval
  )
})()

plot_S = autoplot(kidney.fit) +
  geom_line(data=df_S0,aes(x=time,y=mean),size=0.5,linetype=2) +
  geom_ribbon(data=df_S0,aes(x=time,ymin=lower,ymax=upper),alpha=0.3,inherit.aes=F)

plot_h = ggplot() +
  geom_line(data=df_kphaz,aes(x=time,y=hazard),size=0.5) +
  geom_line(data=df_h,aes(x=time,y=mean),size=0.5,linetype=2) +
  geom_ribbon(data=df_h,aes(x=time,ymin=lower,ymax=upper),alpha=0.3,inherit.aes=F) +
  labs(y="hazard") +
  scale_y_continuous(labels = scales::percent)

grid.arrange(plot_h,plot_S,ncol=2)

実線がcoxph関数で推定した結果、破線がStanで推定した結果。
fit_cox3.png

おわりに

  • geom_pointrange関数でx座標を上手くずらす方法を見つけられなかった。
  • stan_ac関数で確認すると、betaの自己相関は急速に減少しているが、sigma_log_hazardとlp__については、なかなか減少していない。Rhatは全て1.1以下となっている。
  • Stanの実行パラメータはデフォルトのままなので調整の余地はある。
v = c("sigma_log_hazard","beta","lp__")
stan_ac(stan_fit,pars=v)

fit_cox4.png

print(stan_fit,digits=3,pars=v,probs=c(0.025,0.5,0.975))
Inference for Stan model: survival_4.
4 chains, each with iter=4000; warmup=2000; thin=1; 
post-warmup draws per chain=2000, total post-warmup draws=8000.

                     mean se_mean     sd     2.5%      50%    97.5% n_eff  Rhat
sigma_log_hazard    0.203   0.012  0.109    0.053    0.185    0.452    89 1.059
beta[1]             1.422   0.005  0.345    0.726    1.428    2.087  5849 1.001
beta[2]            -0.090   0.006  0.346   -0.790   -0.081    0.559  3299 1.001
beta[3]             0.257   0.007  0.324   -0.373    0.257    0.884  2046 1.002
beta[4]            -1.431   0.008  0.545   -2.518   -1.414   -0.393  4618 1.000
lp__             -174.200   2.727 31.170 -226.916 -177.323 -106.646   131 1.071
5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4