1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

RとStanで実装するMMM

Last updated at Posted at 2024-08-25

はじめに

こんにちは、事業会社で働くデータサイエンティストです。

最近、業務を通じてMMM(Marketing Mix Modeling)という概念を知りました。MMMについてはその実装方法がネット上に多く存在します。ただ、その実装方法はPythonによるものがほとんどです。そこで、今回は主にPyMC-Marketingを参照し、MMMをRとStanで実装しました。

モデル構造

Jin et al.(2017)に基づき、以下のモデルに従う売上$Y_t$を考える。

$Y_t = Intercept + \sum_{m=1}^M \beta_m\times Hill(T_{t,m}^*;K_m,S_m) + \sum_{i=1}^N\gamma_i Z_{t,i} + Trend_t + Seasonality_t + \epsilon_t$

$t$は時間、$m$はメディア、$i$は説明変数を表している。
$Hill(T_{t,m}^*;K_m,S_m)$はヒル関数で、

$Hill(T_{t,m};K_m,S_m)=\frac{1} {1 + (\frac {T_{t,m}}{K_m})^{-S_m}},$

$where \quad T_{t,m}\geq0.$

また、$T_{t,m}^*$はアドストック関数で、

$T_{t,m}^* = adstock(T_{t-L+1,m}, \cdots, T_{t,m};w_m,L) = \frac {\sum_{l = 0}^{L-1}w_m(l)\times T_{t-l}} {\sum_{l=0}^{L-1}w_m(l)}.$

準備として、アドストック関数の分母を用意しておく。

#アドストック関数の分母
denominator <- function(alpha, theta, L) {
  l <- 0:L-1
  total_weight <- sum(alpha^((l - theta)^2))
  return(total_weight)
}

$w_m(l)$は重み関数で、

$w_m(l; \alpha_n, \theta_m) = \alpha_m^{(l- \theta_m)^2},$

$where \quad l = 0,\cdots L-1, ~ 0 < \alpha_m < 1, ~ 0_m \leq \theta \leq L-1.$

重み関数を以下のように定義しておく。ここでは、複数のメディアについて同一の重み関数を用いる。

#重み関数を定義

weight_function <- function(l, alpha = 0.2, theta = 0) {
  return(alpha^((l - theta)^2))
}

とくに、今回は2つのメディアAとメディアBが売上に与える影響を考えるので$m=2$である。また、今回は共変量$Z_{t,i}$の存在は無視する。現実には、ある特定の週に一時的なショックが起こるといったことが考えられるだろう。

Rで実装

売上データの生成

2020年4月1日から2024年3月31日までの各週の売上データを作成する。

df_sales <- data.frame(date = seq(from = as.Date("2020-04-01"),
                                  to = as.Date("2023-09-01"),
                                  by = "week")) |>
  dplyr::mutate(id = seq(from = 1,
                         to = dplyr::n())) |>

メディアAとメディアBの効果をそれぞれ追加する。

  #メディアA
  dplyr::mutate(a = runif(n = dplyr::n(), min = 0, max = 1),
                a = dplyr::case_when(
                  media_a > 0.9 ~ a,
                  TRUE ~ a/2
                )) |>
  dplyr::mutate(sum_a = sum(a)) |>
  dplyr::mutate(a1 = dplyr::lag(a,1),
                a2 = dplyr::lag(a,2),
                a3 = dplyr::lag(a,3),
                a4 = dplyr::lag(a,4),
                a5 = dplyr::lag(a,5),
                a6 = dplyr::lag(a,6),
                a7 = dplyr::lag(a,7)) |>
  dplyr::mutate(dplyr::across(c(a, a1, a2, a3, a4, a5, a6, a7), 
                              ~tidyr::replace_na(., 0))) |>
  dplyr::mutate(adstock_a_numerator = a * weight_function(l = 0) +
                  a1 * weight_function(l = 1) +
                  a2 * weight_function(l = 2) +
                  a3 * weight_function(l = 3) +
                  a4 * weight_function(l = 4) +
                  a5 * weight_function(l = 5) +
                  a6 * weight_function(l = 6) +
                  a7 * weight_function(l = 7)) |>
  dplyr::mutate(adstock_a_denominator = denominator(alpha = 0.2, theta = 0, L = 8)) |>
  dplyr::mutate(adstock_a = adstock_a_numerator/adstock_a_denominator) |>
  dplyr::mutate(hill_a = 1/(1 + ((adstock_a/0.3)^(-0.5)))) |>
  #メディアB
  dplyr::mutate(b = runif(n = dplyr::n(), min = 0, max = 1),
                b = dplyr::case_when(
                  b > 0.9 ~ b,
                  TRUE ~ b/2
                )) |>
  dplyr::mutate(sum_b = sum(b)) |>
  dplyr::mutate(b1 = dplyr::lag(b,1),
                b2 = dplyr::lag(b,2),
                b3 = dplyr::lag(b,3),
                b4 = dplyr::lag(b,4),
                b5 = dplyr::lag(b,5),
                b6 = dplyr::lag(b,6),
                b7 = dplyr::lag(b,7)) |>
  dplyr::mutate(dplyr::across(c(b, b1, b2, b3, b4, b5, b6, b7), 
                              ~tidyr::replace_na(., 0))) |>
  dplyr::mutate(adstock_b_numerator = b * weight_function(l = 0) +
                  b1 * weight_function(l = 1) +
                  b2 * weight_function(l = 2) +
                  b3 * weight_function(l = 3) +
                  b4 * weight_function(l = 4) +
                  b5 * weight_function(l = 5) +
                  b6 * weight_function(l = 6) +
                  b7 * weight_function(l = 7)) |>
  dplyr::mutate(adstock_b_denominator = denominator(alpha = 0.2, theta = 0, L = 8)) |>
  dplyr::mutate(adstock_b = adstock_b_numerator/adstock_b_denominator) |>
  dplyr::mutate(hill_b = 1/(1 + ((adstock_b/0.3)^(-0.5)))) |>

さらに、季節性とトレンドを追加する。

  dplyr::mutate(day_of_year = lubridate::yday(date),
                seasonality = 0.5*(-sin(2*2*pi*day_of_year/365.5) + 
                                     cos(1*2*pi*day_of_year/365.5))) |>
  dplyr::mutate(trend0 = seq(from = 0,
                             to = 50,
                             length.out = dplyr::n()),
                trend = (trend0 + 10)^(1/4) - 1) 

最後に切片を含めた係数と誤差項を生成し、変数を足し合わせることで売上が生成される。

dplyr::mutate(intercept = 2.0) |>
  dplyr::mutate(beta_a = 3.0,
                beta_b = 2.0) |>
  dplyr::mutate(epsilon = rnorm(n = dplyr::n(), mean = 0, sd = 0.25)) |>
  dplyr::mutate(sales = intercept + trend + seasonality + event1 + event2 + beta_a * hill_a + beta_b*hill_b + epsilon)

売上データの可視化

{ggplot2}を使って可視化した。いま、売上額の単位は重要でないので無視している。

ggplot2::ggplot(df_sales, ggplot2::aes(x = date, y = sales)) +
  ggplot2::geom_line() +
  ggplot2::scale_x_date(date_labels = "%y-%m", date_breaks = "1 month") +
  ggplot2::theme_bw() +
  ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 0.8)) 

sales.png


### 各メディアのROASを確認
各メディアのROAS(Return on Advetising Spend)を確認する。

```R
roas <- df_sales |>
  dplyr::group_by() |>
  dplyr::summarise(a_contribution = sum(beta_a*hill_a),
                   a_sum = sum(a),
                   b_contribution = sum(beta_b*hill_b),
                   b_sum = sum(b)) |>
  dplyr::ungroup() |>
  dplyr::mutate(a_roas = a_contribution/a_sum,
                b_roas = b_contribution/b_sum)

メディアAのROASが約4.1なのに対し、メディアBのROASは約2.8であることがわかる。

> c(roas$a_roas, roas$b_roas)
4.181607 2.830847

Stanで実装

上記の売上をStanでモデリングする。
まずはdataブロックから。

data {
  int T;        // データ取得期間の長さ
  vector[T] y;  // 観測値
  vector[T] hill_a;
  vector[T] hill_b;
}

続いてparametersブロック。

parameters {
  vector[T] intercept;
  vector[T] trend;       // 水準+ドリフト成分の推定値
  vector[T] seasonality;    // 季節成分の推定値
  vector[T] beta_a;
  vector[T] beta_b;
  
  real<lower=0> sigma;
  real<lower=0> sigma_intercept;
  real<lower=0> sigma_trend;
  real<lower=0> sigma_seasonality;
}

transformed parametersブロックを以下のように設定する。

transformed parameters {
  vector[T] alpha;        // 各成分の和として得られる状態推定値
  vector[T] contribution_a; // メディアAの寄与率
  vector[T] contribution_b; // メディアBの寄与率
  
  for(i in 1:T) {
    alpha[i] = intercept[i] + beta_a[i]*hill_a[i] + beta_b[i]*hill_b[i] + trend[i] + seasonality[i];
  }
}

最後に、modelブロックを設定する。

model {
  sigma_intercept ~ gamma(0.1, 0.1);
  sigma_seasonality ~ gamma(0.1, 0.1);
  sigma_trend ~ gamma(0.1, 0.1);
  
  for (i in 2:T) {
    intercept[i] ~ normal(intercept[i-1], sigma_intercept);
    beta_a[i] ~ normal(beta_a[i-1], 1);
    beta_b[i] ~ normal(beta_b[i-1], 1);
    trend[i] ~ normal(trend[i-1], sigma_trend);
  }
  
  for(i in 1:T) {
    seasonality[i] ~ normal(0.5 * (-sin(2 * 2 * pi() * day_of_year[i] / 365.5) + 
                            cos(1 * 2 * pi() * day_of_year[i] / 365.5)), sigma_seasonality);
    y[i] ~ normal(alpha[i], sigma);
  }
}

これでStanファイルを書くことができたので、MCMCサンプリングに移る。

RでMCMCサンプリング

始めに、コンパイルと計算の並列化を行っておく。

## 計算の高速化
rstan::rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

リスト形式でデータをまとめる。

data_list <- list(
  T = nrow(df_sales),
  y = df_sales$sales, 
  hill_a = df_sales$hill_a,
  hill_b = df_sales$hill_b,
  day_of_year = df_sales$day_of_year
)

MCMCを実行し、そのサンプルを取り出す。

mcmc_result <- rstan::stan(
  file = "MMM.stan",
  data = data_list,
  seed = 1
)

mcmc_sample <- rstan::extract(mcmc_result, permuted = T)

取り出したサンプルから、売上を再度構成する。

sales_post <- df_sales |>
  dplyr::select(c(date, sales, hill_a, hill_b)) |>
  dplyr::mutate(beta_a = asplit(mcmc_sample$beta_a, 2) |>
                  purrr::map_dbl(mean),
                beta_b = asplit(mcmc_sample$beta_b, 2) |>
                  purrr::map_dbl(mean),
                trend = asplit(mcmc_sample$trend, 2) |>
                  purrr::map_dbl(mean),
                seasonality = asplit(mcmc_sample$seasonality, 2) |>
                  purrr::map_dbl(mean),
                intercept = asplit(mcmc_sample$intercept, 2) |>
                  purrr::map_dbl(mean)) |>
  dplyr::mutate(sales_post = intercept +  beta_a*hill_a + beta_b*hill_b + trend + seasonality,
                trend_seasonality = intercept + trend + seasonality,
                #トレンド成分と季節成分の可視化のため
                a = trend_seasonality + beta_a*hill_a)
                #メディアAの可視化のため

再構成した売上を可視化する。

ggplot2::ggplot(sales_post) +
  ggplot2::geom_point(ggplot2::aes(x = date, y = sales), color = "red") +
  ggplot2::geom_line(ggplot2::aes(x = date, y = sales_post), color = "blue") +
  ggplot2::scale_x_date(date_breaks = "1 month", date_labels = "%Y-%m") +
  ggplot2::theme_bw() +
  ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))

sales_post.png

赤色の点が実際の売上額、青線がStanによるモデリングで生成された売上である。Stanのモデリングがうまくいっていることがわかる。

それでは、各メディアがどれだけ売上に貢献しているのか可視化してみよう。

ggplot2::ggplot(sales_post) +
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = 0, ymax = trend_seasonality, fill = "トレンドと季節成分"), alpha = 0.8) + 
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = trend_seasonality, ymax = a, fill = "メディアA"), alpha = 0.8) +
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = a, ymax = sales_post, fill = "メディアB"), alpha = 0.8) + 
  ggplot2::scale_fill_manual(values = c("トレンドと季節成分" = "black", "メディアA" = "blue", "メディアB" = "red"),
                             breaks = c("メディアB", "メディアA", "トレンドと季節成分")) +
  ggplot2::labs(fill = "成分") +
  ggplot2::scale_x_date(date_breaks = "1 month", date_labels = "%Y-%m") +
  ggplot2::theme_bw() +
  ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))

media_contribution.png

時間を通じて、売上に占める割合はメディアBよりもメディアAの方が大きいと分かる。

RでROASの推定

最後に、得られたMCMCサンプルから各メディアのROASを推定する。

MCMCサンプリングからメディアAとメディアBの係数を取り出しておく(同時に、各メディアの広告費用をかけておく)。

a_contribution <- mcmc_sample$beta_a*df_sales$hill_a 
b_contribution <- mcmc_sample$beta_b*df_sales$hill_b

その上で、データフレームを作成する。

#可視化のためのデータフレームの作成
roas_estimates <- data.frame(sum_a_contribution = purrr::map_dbl(1:nrow(a_contribution), ~ sum(a_contribution[., ])),
                             sum_b_contribution = purrr::map_dbl(1:nrow(b_contribution), ~ sum(b_contribution[., ]))) |>
  dplyr::mutate(a = seq(from = df_sales$sum_a[1],
                         to = df_sales$sum_a[1]),
                b = seq(from = df_sales$sum_b[1],
                          to = df_sales$sum_b[1]),
                roas_a = sum_a_contribution/a,
                roas_b = sum_b_contribution/b) |>
  tidyr::pivot_longer(cols = c(roas_a, roas_b),
                      names_to = "media",
                      values_to = "rate")
  #あとで可視化しやすくするため、pivot_longerを用いる。
ggplot2::ggplot(roas_estimates, ggplot2::aes(x = rate)) +
  ggplot2::geom_histogram() +
  ggplot2::facet_wrap(~media, nrow = 2) +
  ggplot2::geom_vline(
    data = roas_estimates |>
      dplyr::group_by(media) |>
      dplyr::summarize(mean_rate = mean(rate, na.rm = TRUE)),
    ggplot2::aes(xintercept = mean_rate),
    color = "red", linetype = "dashed", size = 1) +
  ggplot2::scale_x_continuous(limits = c(0,8), n.breaks = 8) +
  ggplot2::theme_bw()

roas_post.png

(分布が本当にそうなのか怪しいが、)メディアAのROASの平均は約4.5で、生成された売上データで確認した4.1よりやや高くなった。一方、メディアBのROASの平均は約2.3で、生成データの2.8より少し小さくなっている。

参考文献

金本拓. 2024. 「因果推論 基礎から機械学習・時系列分析・因果探索を用いた意思決定のアプローチ」. オーム社.
MMM Example Notebook. PyMC-Marketing 0.8.0. https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html (2024年8月27日最終閲覧).
Jin et al. 2017. A Hierarchical Bayesian Approach to Improve Media Mix Models Using Category Data. Google Inc.

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?