dplyr の window function の説明を見ても、移動平均の計算の仕方がよく分かりませんでした。
たぶんまだ対応していないのでしょう。
ところで、PostgreSQL では、次のような SQL を投げれば移動平均は計算できます。
Compute a rolling average of games player:
MEAN(G) OVER (PARTITION BY playerID ORDER BY yearID BETWEEN 2 PRECEEDING AND 2 FOLLOWING)
でもでも、せっかく dplyr 使ってるのに、生 SQL 投げるのって、嫌ですよね。
どうにかして dplyr で移動平均を計算できないでしょうか?
まず、次のようなテーブルを用意します。
glimpse(table)
Observations: 54
Variables:
$ date (date) 2010-08-01, 2010-09-01, 2010-10-01, 2010-11-01, 2010-12-01, ...
$ value (dbl) 3171.130, 7861.920, 1017.060, 430.910, 1998.000, 1537.720, 10...
この value
の移動平均を求めたいのですが、次のようにすると、
q <- table %>% mutate(moving_average=mean(value))
q$query
<Query> SELECT "date", "value", "moving_average"
FROM (SELECT "date", "value", avg("value") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "moving_average"
FROM "ahlcnnxsht") AS "_W20"
<PostgreSQLConnection:(8288,0)>
このようなクエリが出来上がります。(このトップレベルの SELECT は必要なのか?)
このクエリの UNBOUNDED
の部分を数値に置き換えれば移動平均が計算できそうです。
OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
クエリオブジェクトのクラスを調べてみると、
class(q$query)
[1] "Query" "R6"
今話題の R6
ではありませんか!
さっそく print
してみると、
print(q$query)
<Query> SELECT "date", "value", "moving_average"
FROM (SELECT "date", "value", avg("value") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "moving_average"
FROM "ahlcnnxsht") AS "_W20"
<PostgreSQLConnection:(8288,0)>
内部がどうなっているは隠蔽されています。
これでお手上げ。。となりそうですが、R6 が環境であるということが分かっているので、次のようなハックが可能です。
ls(envir = q$query)
[1] "con" "fetch" "fetch_paged" "initialize" "ncol" "nrow" "print" "sql" "vars"
ふむふむ、この sql
という変数が怪しいですね。
print(q$query$sql)
<SQL> SELECT "date", "value", "moving_average"
FROM (SELECT "date", "value", avg("value") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "moving_average"
FROM "ahlcnnxsht") AS "_W20"
にらんだ通りです。それでは sql
のクラスを確認してみましょう。
class(q$query$sql)
[1] "sql" "character"
おおー! ただの character
みたいです! これはいける!
sql
を書き換えてしまいましょう。
q$query$sql <- q$query$sql %>% str_replace_all("UNBOUNDED", 3)
q$query
<Query> SELECT "date", "value", "moving_average"
FROM (SELECT "date", "value", avg("value") OVER (ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING) AS "moving_average"
FROM "ahlcnnxsht") AS "_W20"
<PostgreSQLConnection:(8288,0)>
キタ━━(゚∀゚)━━!! クエリオブジェクトが書き換わりました!
このクエリを DB に投げてみると、
q %>% collect
Source: local data frame [54 x 3]
date value moving_average
1 2010-08-01 3171.130 3120.2539
2 2010-09-01 7861.920 2895.8023
3 2010-10-01 1017.060 2669.4554
4 2010-11-01 430.910 2438.4301
5 2010-12-01 1998.000 2131.3890
6 2011-01-01 1537.720 1064.9248
7 2011-02-01 1052.280 989.3451
8 2011-03-01 1021.840 1044.0888
9 2011-04-01 396.667 930.9410
10 2011-05-01 487.999 846.7288
.. ... ... ...
移動平均が求まりました!
まとめ
次のようにすれば、dplyr で移動平均を求めることが可能です。
glimpse(table)
q <- table %>% mutate(moving_average=mean(value))
q$query$sql <- q$query$sql %>% str_replace_all("UNBOUNDED", 3)
q %>% collect
RcppRoll
と速度比較してみましょう。
library(RcppRoll)
library(microbenchmark)
only_dplyr <- function() {
q <- table %>%
mutate(moving_average=mean(value))
q$query$sql <- q$query$sql %>% str_replace_all("UNBOUNDED", 3)
data <- q %>% collect
data
}
rcpp_roll <- function() {
q <- table
data <- q %>% collect
data <- data %>% mutate(moving_average=roll_mean(value, 7, fill=NA))
data
}
microbenchmark(
only_dplyr(),
rcpp_roll()
)
Unit: milliseconds
expr min lq mean median uq max neval
only_dplyr() 16.105506 17.338677 17.932148 17.830893 18.439324 21.690620 100
rcpp_roll() 4.321732 4.523276 4.845213 4.645886 4.843053 8.093333 100
RcppRoll
はええ! orz