LoginSignup
6

More than 5 years have passed since last update.

posted at

updated at

Organization

R6 をハックして dplyr + PostgreSQL で移動平均を計算する #rstatsj

dplyr の window function の説明を見ても、移動平均の計算の仕方がよく分かりませんでした。

たぶんまだ対応していないのでしょう。

ところで、PostgreSQL では、次のような SQL を投げれば移動平均は計算できます。

Compute a rolling average of games player:

MEAN(G) OVER (PARTITION BY playerID ORDER BY yearID BETWEEN 2 PRECEEDING AND 2 FOLLOWING)

でもでも、せっかく dplyr 使ってるのに、生 SQL 投げるのって、嫌ですよね。
どうにかして dplyr で移動平均を計算できないでしょうか?

まず、次のようなテーブルを用意します。

R
glimpse(table)
結果
Observations: 54
Variables:
$ date  (date) 2010-08-01, 2010-09-01, 2010-10-01, 2010-11-01, 2010-12-01, ...
$ value (dbl) 3171.130, 7861.920, 1017.060, 430.910, 1998.000, 1537.720, 10...

この value の移動平均を求めたいのですが、次のようにすると、

R
q <- table %>% mutate(moving_average=mean(value))
q$query
結果
<Query> SELECT "date", "value", "moving_average"
FROM (SELECT "date", "value", avg("value") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "moving_average"
FROM "ahlcnnxsht") AS "_W20"
<PostgreSQLConnection:(8288,0)> 

このようなクエリが出来上がります。(このトップレベルの SELECT は必要なのか?)

このクエリの UNBOUNDED の部分を数値に置き換えれば移動平均が計算できそうです。

OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)

クエリオブジェクトのクラスを調べてみると、

R
class(q$query)
結果
[1] "Query" "R6"

今話題の R6 ではありませんか!
さっそく print してみると、

R
print(q$query)
結果
<Query> SELECT "date", "value", "moving_average"
FROM (SELECT "date", "value", avg("value") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "moving_average"
FROM "ahlcnnxsht") AS "_W20"
<PostgreSQLConnection:(8288,0)> 

内部がどうなっているは隠蔽されています。

これでお手上げ。。となりそうですが、R6 が環境であるということが分かっているので、次のようなハックが可能です。

R
ls(envir = q$query)
結果
[1] "con"         "fetch"       "fetch_paged" "initialize"  "ncol"        "nrow"        "print"       "sql"         "vars"

ふむふむ、この sql という変数が怪しいですね。

R
print(q$query$sql)
結果
<SQL> SELECT "date", "value", "moving_average"
FROM (SELECT "date", "value", avg("value") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "moving_average"
FROM "ahlcnnxsht") AS "_W20"

にらんだ通りです。それでは sql のクラスを確認してみましょう。

R
class(q$query$sql)
結果
[1] "sql"       "character"

おおー! ただの character みたいです! これはいける!

sql を書き換えてしまいましょう。

R
q$query$sql <- q$query$sql %>% str_replace_all("UNBOUNDED", 3)
q$query
結果
<Query> SELECT "date", "value", "moving_average"
FROM (SELECT "date", "value", avg("value") OVER (ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING) AS "moving_average"
FROM "ahlcnnxsht") AS "_W20"
<PostgreSQLConnection:(8288,0)> 

キタ━━(゚∀゚)━━!! クエリオブジェクトが書き換わりました!

このクエリを DB に投げてみると、

R
q %>% collect
結果
Source: local data frame [54 x 3]

         date    value moving_average
1  2010-08-01 3171.130      3120.2539
2  2010-09-01 7861.920      2895.8023
3  2010-10-01 1017.060      2669.4554
4  2010-11-01  430.910      2438.4301
5  2010-12-01 1998.000      2131.3890
6  2011-01-01 1537.720      1064.9248
7  2011-02-01 1052.280       989.3451
8  2011-03-01 1021.840      1044.0888
9  2011-04-01  396.667       930.9410
10 2011-05-01  487.999       846.7288
..        ...      ...            ...

移動平均が求まりました!

まとめ

次のようにすれば、dplyr で移動平均を求めることが可能です。

R
glimpse(table)
q <- table %>% mutate(moving_average=mean(value))
q$query$sql <- q$query$sql %>% str_replace_all("UNBOUNDED", 3)
q %>% collect

RcppRoll と速度比較してみましょう。

R
library(RcppRoll)
library(microbenchmark)

only_dplyr <- function() {
  q <- table %>%
    mutate(moving_average=mean(value))
  q$query$sql <- q$query$sql %>% str_replace_all("UNBOUNDED", 3)
  data <- q %>% collect
  data
}
rcpp_roll <- function() {
  q <- table
  data <- q %>% collect
  data <- data %>% mutate(moving_average=roll_mean(value, 7, fill=NA))
  data
}

microbenchmark(
  only_dplyr(),
  rcpp_roll()
)
R
Unit: milliseconds
         expr       min        lq      mean    median        uq       max neval
 only_dplyr() 16.105506 17.338677 17.932148 17.830893 18.439324 21.690620   100
  rcpp_roll()  4.321732  4.523276  4.845213  4.645886  4.843053  8.093333   100

RcppRoll はええ! orz

関連

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
6