Posted at

生存時間分析をStanで実行してみた

More than 1 year has passed since last update.


はじめに


確率分布の定義


  • 生存時間の確率分布$f_t(t=1,\cdots\infty)$

  • 死亡関数(時刻$t$までに死亡している確率) $F_t=\sum_{i=1}^tf_i$

  • 生存関数(時刻$t$まで生存している確率) $S_t=1-F_t=\sum_{i=t+1}^\infty f_i$

  • 瞬間死亡率(ハザード) $h_t=\frac{f_t}{S_{t-1}}\quad(0\le h_t\le1, S_0=1)$

\begin{align}

f_t &= S_{t-1} h_t = F_t-F_{t-1}= S_{t-1}-S_t\\
S_t &= S_{t-1} (1-h_t)\\
S_t &= \prod_{i=1}^t(1-h_i),\quad
f_t = h_t\prod_{i=1}^{t-1}(1-h_i)
\end{align}


  • 時刻$t$でイベントが発生する確率は$f_t$となる。

  • 時刻$t$で打ち切りの場合、イベントは時刻$t+1$以降に発生するので、確率は$S_t$となる。


説明変数なしモデル


Stanプログラム

transformed parametersブロックで瞬間死亡率$h_t$から生存関数$S_t$と生存時間の確率分布$f_t$を計算する。


survival_0.stan

data {

int<lower=1> N;
int<lower=1> T;
int<lower=1, upper=T> Time[N];
int<lower=0, upper=1> Cens[N];
}

parameters {
real<lower=0> sigma_log_hazard;
vector<upper=0>[T] log_hazard;
}

transformed parameters {
vector[T] log_S;
vector[T] log_f;
log_S[1] = log1m_exp(log_hazard[1]);
log_f[1] = log_hazard[1];
for (t in 2:T) {
log_S[t] = log_S[t-1] + log1m_exp(log_hazard[t]);
log_f[t] = log_S[t-1] + log_hazard[t];
}
}

model {
sigma_log_hazard ~ cauchy(0, 1);
log_hazard ~ cauchy(0, 10);
for (t in 2:T) {
log_hazard[t] ~ normal(log_hazard[t-1], sigma_log_hazard);
}
for (n in 1:N) {
if (Cens[n] == 1) {
target += log_f[Time[n]];
}
else {
target += log_S[Time[n]];
}
}
}

generated quantities {
vector[T] hazard;
vector[T] S;
hazard = exp(log_hazard);
S = exp(log_S);
}



分析対象データ

パッケージMASSに含まれているgehanという生存時間データを分析する。

library(ggplot2)

library(scales)
library(ggfortify)
library(gridExtra)
library(rstan)
library(survival)
library(MASS)
library(muhaz)

rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

> head(gehan)
pair time cens treat
1 1 1 1 control
2 1 10 1 6-MP
3 2 22 1 control
4 2 7 1 6-MP
5 3 3 1 control
6 3 32 0 6-MP
> nrow(gehan)
[1] 42


Stanの実行

treatが6-MPとcontrolについて別々に実行する。

model_stan = stan_model("survival_0.stan")

tstart = proc.time()
treats = c("6-MP","control")
result = lapply(seq_along(treats), function(i){
ge = gehan[gehan$treat == treats[i],]
data_stan = list(N=nrow(ge), T=max(ge$time), Time=ge$time, Cens=ge$cens)
stan_fit = sampling(model_stan, data=data_stan,
chains=4, iter=4000, warmup=2000, thin=1, seed=3,
control=list(adapt_delta=0.8, max_treedepth=10)
)
stan_fit
})
proc.time() - tstart # 23sec


実行結果の抽出

生存関数と瞬間死亡率の95%信用区間を抽出する。

my_extract = function(x,name){

tb = summary(x, probs=c(0.025,0.975),pars=name)$summary
df = data.frame(tb[,c("mean","2.5%","97.5%"), drop=F])
colnames(df) = c("mean","lower","upper")
df
}

df_S =lapply(result,function(x){
df = my_extract(x,"S")
df$time = 1:nrow(df)
df
})

df_h =lapply(result,function(x){
df = my_extract(x,"hazard")
df$time = 1:nrow(df)
df
})


グラフの作成

生存関数と瞬間死亡率をsurvfit関数による結果と比較する。

グラフの色を合わせて表示したいのでggplot2が使用する色コードを取得。

# ggplot2のデフォルトの色を知りたい

# https://qiita.com/hoxo_b/items/c569da6dbf568032e04a
ggColorHue <- function(n, l=65) {
hues <- seq(15, 375, length=n+1)
hcl(h=hues, l=l, c=100)[1:n]
}
ggcols = ggColorHue(n=2)

kphaz = kphaz.fit(time=gehan$time, status=gehan$cens, strata=gehan$treat)
df_kphaz = lapply(1:2,function(i){
f = (kphaz$strata==i)
data.frame(time=kphaz$time[f],haz=kphaz$haz[f])
})

fit_sv = survfit(Surv(time, cens) ~ treat, data = gehan)
summary(fit_sv)

plot_S = autoplot(fit_sv) +
geom_line(data=df_S[[1]],aes(x=time,y=mean),size=0.5,linetype=2,color=ggcols[1]) +
geom_ribbon(data=df_S[[1]],aes(x=time,ymin=lower,ymax=upper),alpha=0.3,fill=ggcols[1],inherit.aes=F) +
geom_line(data=df_S[[2]],aes(x=time,y=mean),size=0.5,linetype=2,color=ggcols[2]) +
geom_ribbon(data=df_S[[2]],aes(x=time,ymin=lower,ymax=upper),alpha=0.3,fill=ggcols[2],inherit.aes=F)

plot_haz = ggplot() +
geom_line(data=df_kphaz[[1]],aes(x=time,y=haz),size=0.5,linetype=1,color=ggcols[1]) +
geom_line(data=df_kphaz[[2]],aes(x=time,y=haz),size=0.5,linetype=1,color=ggcols[2]) +
geom_line(data=df_h[[1]],aes(x=time,y=mean),size=0.5,linetype=2,color=ggcols[1]) +
geom_ribbon(data=df_h[[1]],aes(x=time,ymin=lower,ymax=upper),alpha=0.3,fill=ggcols[1]) +
geom_line(data=df_h[[2]],aes(x=time,y=mean),size=0.5,linetype=2,color=ggcols[2]) +
geom_ribbon(data=df_h[[2]],aes(x=time,ymin=lower,ymax=upper),alpha=0.3,fill=ggcols[2]) +
scale_y_continuous(labels = scales::percent)

grid.arrange(plot_haz,plot_S,ncol=2,widths = c(0.8,1))

実線がsurvfit関数で推定した結果、破線がStanで推定した結果。

fit_sv_0.png


説明変数ありモデル


  • 対象とする説明変数は0または1の値を取るものとする。

  • Cox比例ハザードモデルを適用する。説明変数の値=1の効果を$\beta$とすると、$\beta$は時刻に依存せず、瞬間死亡率を$\beta$倍にすると仮定する。

  • データが長期間に渡る場合、全ての時刻で瞬間死亡率を推定すると計算時間がかかるため、イベント発生時刻・打ち切り時刻の間では、瞬間死亡率は一定と仮定する。

  • $S_l=S_k(1-h_l)^{l-k}, f_l=S_k(1-h_l)^{l-k-1}h_l$

  • ここで$h_l$は、時刻$k+1$から$l$までの瞬間死亡率とする。

  • 計算時間を短縮するため説明変数の値の組合せでユニークなものを予めグループとしてまとめておく。データには対応するグループの番号を指定する。


Stanプログラム

transformed parametersブロックで瞬間死亡率$h_t$からグループ毎の生存関数$S_t$と生存時間の確率分布$f_t$を計算する。

generated quantitiesブロックでベースラインを計算する。


survival_4.stan

data {

int<lower=1> N;
int<lower=1> T;
int<lower=1> G;
int<lower=1> K;
int<lower=1, upper=T> TimeID[N];
int<lower=0, upper=1> Cens[N];
int<lower=0, upper=G> GroupID[N];
int<lower=1> Time[T];
matrix<lower=0, upper=1>[G, K] Group;
}

parameters {
real<lower=0> sigma_log_hazard;
vector<upper=0>[T] log_hazard;
vector[K] beta;
}

transformed parameters {
matrix[G, T] log_S;
matrix[G, T] log_f;
for (g in 1:G) {
vector[T] new_hazard = log_hazard + dot_product(Group[g], beta);
log_S[g, 1] = log1m_exp(new_hazard[1]) * (Time[1]);
log_f[g, 1] = log1m_exp(new_hazard[1]) * (Time[1] - 1) + new_hazard[1];
for (t in 2:T) {
log_S[g, t] = log_S[g, t-1] + log1m_exp(new_hazard[t]) * (Time[t] - Time[t-1]);
log_f[g, t] = log_S[g, t-1] + log1m_exp(new_hazard[t]) * (Time[t] - Time[t-1] - 1) + new_hazard[t];
}
}
}

model {
sigma_log_hazard ~ cauchy(0, 1);
log_hazard ~ cauchy(0, 10);
beta ~ cauchy(0, 10);
for (t in 2:T) {
log_hazard[t] ~ normal(log_hazard[t-1], sigma_log_hazard);
}
for (n in 1:N) {
if (Cens[n] == 1) {
target += log_f[GroupID[n], TimeID[n]];
}
else {
target += log_S[GroupID[n], TimeID[n]];
}
}
}

generated quantities{
vector[T] hazard;
vector[T] log_S0;
vector[T] S0;
vector[K] exp_beta;
hazard = exp(log_hazard);
log_S0[1] = log1m_exp(log_hazard[1]) * (Time[1]);
for (t in 2:T) {
log_S0[t] = log_S0[t-1] + log1m_exp(log_hazard[t]) * (Time[t] - Time[t-1]);
}
S0 = exp(log_S0);
exp_beta = exp(beta);
}



分析対象データ

パッケージsurvivalに含まれているkidneyという生存時間データを分析する。

説明変数としてsexとdiseaseを使用する。

> head(kidney)

id time status age sex disease frail
1 1 8 1 28 1 Other 2.3
2 1 16 1 28 1 Other 2.3
3 2 23 1 48 2 GN 1.9
4 2 13 0 48 2 GN 1.9
5 3 22 1 32 1 Other 1.2
6 3 28 1 32 1 Other 1.2
> nrow(kidney)
[1] 76


時刻情報とグループ情報の抽出

df_kidney = kidney

df_kidney$sex = c("M","F")[kidney$sex]
df_kidney$sex = as.factor(df_kidney$sex)

dum = cbind(
model.matrix(~df_kidney$sex)[,-1,drop=F],
model.matrix(~df_kidney$disease)[,-1]
)
colnames(dum) = c("sexM","diseaseGN","diseaseAN","diseasePKD")

times = sort(unique(df_kidney$time))
time_id = match(df_kidney$time,times)

group = unique(dum)
rownames(group) = 1:nrow(group)
group_char = apply(group,1,paste,collapse="")
group_id = match(apply(dum,1,paste,collapse=""),group_char)

> times

[1] 2 4 5 6 7 8 9 12 13 15 16 17 22 23 24 25 26 27 28 30 34 38 39
[24] 40 43 46 53 54 58 63 66 70 78 96 108 113 114 119 130 132 141 149 152 154 156 159
[47] 177 185 190 196 201 245 292 318 333 402 447 511 536 562
> time_id
[1] 6 11 14 9 13 19 57 54 20 8 15 52 5 7 58 20 27 50 10 44 5 55 41 6 34 22 42 32 59 16 12
[32] 2 48 47 53 37 13 46 10 35 43 60 56 15 9 31 23 26 8 24 36 51 40 45 21 20 1 16 39 17 18 29
[63] 3 25 43 20 49 3 38 6 28 11 4 33 30 6
> group
sexM diseaseGN diseaseAN diseasePKD
1 1 0 0 0
2 0 1 0 0
3 0 0 0 0
4 1 1 0 0
5 0 0 1 0
6 1 0 1 0
7 1 0 0 1
8 0 0 0 1
> group_id
[1] 1 1 2 2 1 1 3 3 1 1 3 3 4 4 2 2 5 5 4 4 5 5 3 3 5 5 5 5 3 3 6 6 3 3 3 3 2 2 3 3 7 7 3 3 5 5 5
[48] 5 6 6 5 5 2 2 5 5 4 4 2 2 5 5 5 5 8 8 2 2 3 3 3 3 8 8 7 7


Stanの実行

model_stan = stan_model("survival_4.stan")

data_stan = list(
N = nrow(df_kidney),
T = length(times),
G = nrow(group),
K = ncol(group),
TimeID = time_id,
Cens = df_kidney$status,
GroupID = group_id,
Time = times,
Group = group
)

tstart = proc.time()
stan_fit = sampling(model_stan,data=data_stan,
chains=4, iter=4000, warmup=2000, thin=1, seed=3,
control=list(adapt_delta=0.8,max_treedepth=10)
)
proc.time() - tstart # 122sec


実行結果の抽出

生存関数と瞬間死亡率、説明変数の効果$\beta$の95%信用区間を抽出する。

瞬間死亡率の時刻情報を推定した区間の中央に設定する。

my_extract = function(x,name){

tb = summary(x, probs=c(0.025,0.975),pars=name)$summary
df = data.frame(tb[,c("mean","2.5%","97.5%"), drop=F])
colnames(df) = c("mean","lower","upper")
df
}

df_S0 = my_extract(stan_fit,"S0")
df_S0$time = times

df_h = my_extract(stan_fit,"hazard")
df_h$time = (function(){
n = length(times)
(times + c(0,times[-n]) + 1)/2
})()

df_exp_beta = my_extract(stan_fit,"exp_beta")
rownames(df_exp_beta) = c("sexM","diseaseGN","diseaseAN","diseasePKD")

beta_hist = extract(stan_fit,pars="exp_beta")$exp_beta
beta_hist = data.frame(beta_hist)
colnames(beta_hist) = c("sexM","diseaseGN","diseaseAN","diseasePKD")


グラフの作成

推定された説明変数の効果$\beta$の分布。

pvline = geom_vline(xintercept=1,color="red")

plot1 = ggplot(beta_hist,aes(x=sexM)) + geom_histogram(bins=30) + pvline
plot2 = ggplot(beta_hist,aes(x=diseaseGN)) + geom_histogram(bins=30) + pvline
plot3 = ggplot(beta_hist,aes(x=diseaseAN)) + geom_histogram(bins=30) + pvline
plot4 = ggplot(beta_hist,aes(x=diseasePKD)) + geom_histogram(bins=30) + pvline

grid.arrange(plot1,plot2,plot3,plot4,ncol=2,nrow=2)

fit_cox1.png

coxph関数による効果$\beta$の推定値との比較。

kidney.cox = coxph(Surv(time, status) ~ sex+disease, data=df_kidney)

summary(kidney.cox)
df_exp_beta

> kidney.cox = coxph(Surv(time, status) ~ sex+disease, data=df_kidney)

> summary(kidney.cox)
Call:
coxph(formula = Surv(time, status) ~ sex + disease, data = df_kidney)

n= 76, number of events= 58

coef exp(coef) se(coef) z Pr(>|z|)
sexM 1.4774 4.3815 0.3569 4.140 3.48e-05 ***
diseaseGN 0.1392 1.1494 0.3635 0.383 0.7017
diseaseAN 0.4132 1.5116 0.3360 1.230 0.2188
diseasePKD -1.3671 0.2549 0.5889 -2.321 0.0203 *
---
Signif. codes: 0 *** 0.001 ** 0.01 * 0.05 . 0.1 1

exp(coef) exp(-coef) lower .95 upper .95
sexM 4.3815 0.2282 2.17689 8.8188
diseaseGN 1.1494 0.8700 0.56368 2.3437
diseaseAN 1.5116 0.6616 0.78245 2.9202
diseasePKD 0.2549 3.9238 0.08035 0.8084

Concordance= 0.696 (se = 0.045 )
Rsquare= 0.206 (max possible= 0.993 )
Likelihood ratio test= 17.56 on 4 df, p=0.001501
Wald test = 19.77 on 4 df, p=0.0005533
Score (logrank) test = 19.97 on 4 df, p=0.0005069

> df_exp_beta
mean lower upper
M 4.3978555 2.06594992 8.0631262
diseaseGN 0.9693798 0.45369567 1.7495952
diseaseAN 1.3627789 0.68891947 2.4198428
diseasePKD 0.2758943 0.08059133 0.6748802

df_cox = data.frame(summary(kidney.cox)$conf.int)

colnames(df_cox) = c("mean","inv_mean","lower","upper")
df_cox$x = factor(1:4,levels=1:4,labels=c("sexM","diseaseGN","diseaseAN","diseasePKD"))
df_exp_beta$x = (1:4)+0.1

ggplot() +
geom_pointrange(data=df_cox,aes(x=x,y=mean,ymin=lower,ymax=upper),size=0.5) +
geom_pointrange(data=df_exp_beta,aes(x=x,y=mean,ymin=lower,ymax=upper),size=0.5,linetype=2) +
labs(x="",y="beta")

実線がcoxph関数で推定した結果、破線がStanで推定した結果。

fit_cox2.png

推定された生存関数と瞬間死亡率のグラフ。

plot_S = ggplot() +

geom_line(data=df_S0,aes(x=time,y=mean),size=0.5,linetype=2) +
geom_ribbon(data=df_S0,aes(x=time,ymin=lower,ymax=upper),alpha=0.3,inherit.aes=F) +
labs(y="surv") +
scale_y_continuous(labels = scales::percent)

plot_h = ggplot() +
geom_line(data=df_h,aes(x=time,y=mean),size=0.5,linetype=2) +
geom_ribbon(data=df_h,aes(x=time,ymin=lower,ymax=upper),alpha=0.3,inherit.aes=F) +
labs(y="hazard") +
scale_y_continuous(labels = scales::percent)

grid.arrange(plot_h,plot_S,ncol=2)

fit_cox0.png

coxph関数によるベースラインとの比較。

kidney.fit = survfit(kidney.cox, newdata = data.frame(sex="F", disease="Other"))

# rough estimation of hazard from survival function
df_kphaz = (function(){
n = length(kidney.fit$surv)
times = (kidney.fit$time[2:n] + kidney.fit$time[1:(n-1)])/2
interval = (kidney.fit$time[2:n] - kidney.fit$time[1:(n-1)])
ratio = kidney.fit$surv[2:n] / kidney.fit$surv[1:(n-1)]
data.frame(
time = times,
hazard = (1 - ratio) / interval
)
})()

plot_S = autoplot(kidney.fit) +
geom_line(data=df_S0,aes(x=time,y=mean),size=0.5,linetype=2) +
geom_ribbon(data=df_S0,aes(x=time,ymin=lower,ymax=upper),alpha=0.3,inherit.aes=F)

plot_h = ggplot() +
geom_line(data=df_kphaz,aes(x=time,y=hazard),size=0.5) +
geom_line(data=df_h,aes(x=time,y=mean),size=0.5,linetype=2) +
geom_ribbon(data=df_h,aes(x=time,ymin=lower,ymax=upper),alpha=0.3,inherit.aes=F) +
labs(y="hazard") +
scale_y_continuous(labels = scales::percent)

grid.arrange(plot_h,plot_S,ncol=2)

実線がcoxph関数で推定した結果、破線がStanで推定した結果。

fit_cox3.png


おわりに


  • geom_pointrange関数でx座標を上手くずらす方法を見つけられなかった。

  • stan_ac関数で確認すると、betaの自己相関は急速に減少しているが、sigma_log_hazardとlp__については、なかなか減少していない。Rhatは全て1.1以下となっている。

  • Stanの実行パラメータはデフォルトのままなので調整の余地はある。

v = c("sigma_log_hazard","beta","lp__")

stan_ac(stan_fit,pars=v)

fit_cox4.png

print(stan_fit,digits=3,pars=v,probs=c(0.025,0.5,0.975))

Inference for Stan model: survival_4.

4 chains, each with iter=4000; warmup=2000; thin=1;
post-warmup draws per chain=2000, total post-warmup draws=8000.

mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
sigma_log_hazard 0.203 0.012 0.109 0.053 0.185 0.452 89 1.059
beta[1] 1.422 0.005 0.345 0.726 1.428 2.087 5849 1.001
beta[2] -0.090 0.006 0.346 -0.790 -0.081 0.559 3299 1.001
beta[3] 0.257 0.007 0.324 -0.373 0.257 0.884 2046 1.002
beta[4] -1.431 0.008 0.545 -2.518 -1.414 -0.393 4618 1.000
lp__ -174.200 2.727 31.170 -226.916 -177.323 -106.646 131 1.071