概要
少し前に{tensorflow}というRからTensorFlowを使うパッケージがRStudio社から公開されたので、みんな大好きirisデータの分類をMNISTの例を参考に試してみました。
事前準備
実行環境によっては以前にインストールしたバージョンを削除して再度インストールする必要があります(自分はprotobufも再インストールしました)。
NOTE: If you are upgrading from a previous installation of TensorFlow < 0.7.1, you should uninstall the previous TensorFlow and protobuf using pip uninstall first to make sure you get a clean installation of the updated protobuf dependency.
https://www.tensorflow.org/versions/r0.11/get_started/os_setup.html
# 削除したprotobufを再インストール
$ pip install protobuf
# OS XでCPU利用の場合
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.11.0rc1-py2-none-any.whl
$ sudo pip install --upgrade $TF_BINARY_URL
# {tensorflow}のインストール時に{tensorflow}で利用するPythonを指定するため、パスを確認しておく
$ which python
/usr/local/bin/python
export PYTHON_PATH=`which python`
# {tensorflow}で利用するPythonのパスを環境変数に設定してインストールする
Sys.setenv(TENSORFLOW_PYTHON = "/usr/local/bin/python")
devtools::install_github("rstudio/tensorflow")
インストール後にパッケージを呼び出してエラーが起きなければOK(エラーが起きる場合は使用するライブラリの再インストールや、上記のPythonパスを確認するなど試す)。なお、エラー後に再度パッケージを読み込むとエラーメッセージは表示されないが、使用できないので注意。
定義・設定
処理で利用するライブラリの読み込みや定数・関数。加えて今回はTensorFlowのモデルの定義をここで行う。
library(tensorflow)
library(dplyr)
library(foreach)
library(caret)
SET_CV_NUM <- 5
SET_DATA_PARAM <- list(
CLASS_NUM = 3L, FEATURE_NUM = 4L
)
SET_SETP_NUM <- 3000
# confusion matrixからAccuracyを計算
calcAccuracy <- function(confusion_mat) {
return(sum(diag(x = confusion_mat)) / sum(confusion_mat))
}
# MNIST For ML Beginnersの例を参考
W <- tensorflow::tf$Variable(
initial_value = tensorflow::tf$zeros(shape = tensorflow::shape(SET_DATA_PARAM$FEATURE_NUM, SET_DATA_PARAM$CLASS_NUM))
)
x <- tensorflow::tf$placeholder(
dtype = tensorflow::tf$float32,
shape = tensorflow::shape(NULL, SET_DATA_PARAM$FEATURE_NUM)
)
b <- tensorflow::tf$Variable(
initial_value = tensorflow::tf$zeros(shape = tensorflow::shape(SET_DATA_PARAM$CLASS_NUM))
)
y <- tensorflow::tf$nn$softmax(logits = tensorflow::tf$matmul(a = x, b = W) + b)
y_ <- tensorflow::tf$placeholder(dtype = tensorflow::tf$float32, shape = tensorflow::shape(NULL, SET_DATA_PARAM$CLASS_NUM))
# 損失関数とオプティマイザーの設定
cross_entropy <- tensorflow::tf$reduce_mean(input_tensor = - tensorflow::tf$reduce_sum(input_tensor = y_ * tensorflow::tf$log(x = y), reduction_indices = 1L))
optimizer <- tensorflow::tf$train$GradientDescentOptimizer(learning_rate = 0.5)
train_step <- optimizer$minimize(loss = cross_entropy)
# 評価用
correct_prediction <- tensorflow::tf$equal(
x = tensorflow::tf$argmax(input = y, dimension = 1L), y = tensorflow::tf$argmax(input = y_, dimension = 1L)
)
accuracy <- tensorflow::tf$reduce_mean(
input_tensor = tensorflow::tf$cast(x = correct_prediction, dtype = tensorflow::tf$float32)
)
実行部
# irisデータ
data(iris)
> iris %>%
head(n = 10)
Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1 5.1 3.5 1.4 0.2 setosa
2 4.9 3.0 1.4 0.2 setosa
3 4.7 3.2 1.3 0.2 setosa
4 4.6 3.1 1.5 0.2 setosa
5 5.0 3.6 1.4 0.2 setosa
6 5.4 3.9 1.7 0.4 setosa
7 4.6 3.4 1.4 0.3 setosa
8 5.0 3.4 1.5 0.2 setosa
9 4.4 2.9 1.4 0.2 setosa
10 4.9 3.1 1.5 0.1 setosa
> iris %>%
dplyr::group_by(Species) %>%
dplyr::summarise_all(.fun = mean)
# A tibble: 3 × 5
Species Sepal.Length Sepal.Width Petal.Length Petal.Width
<fctr> <dbl> <dbl> <dbl> <dbl>
1 setosa 5.006 3.428 1.462 0.246
2 versicolor 5.936 2.770 4.260 1.326
3 virginica 6.588 2.974 5.552 2.026
set.seed(seed = 71)
cv_number <- sample(x = seq(from = 1, to = SET_CV_NUM), size = nrow(x = iris), replace = TRUE)
# N分割交差検証
> tf_result <- lapply(
X = seq(from = 1, to = SET_CV_NUM),
FUN = function (cv_counter) {
# 学習データ作成
trn_d <- iris %>%
dplyr::filter(cv_number != cv_counter)
trn_x <- trn_d %>%
dplyr::select(-Species) %>%
as.matrix()
trn_y <- trn_d %>%
caret::dummyVars(formula = ~ Species, sep = NULL) %>%
predict(object = ., newdata = trn_d) %>%
as.matrix()
# 初期化
tf_session <- tensorflow::tf$Session()
tf_session$run(fetches = tensorflow::tf$initialize_all_variables())
# パラメータが初期化されているか確認
print(
stringr::str_c(
stringr::str_c("CV:", cv_counter),
stringr::str_c("W:", sum(tf_session$run(W))),
stringr::str_c("b:", sum(tf_session$run(b))),
sep = " "
)
)
# 学習
foreach::times(n = SET_SETP_NUM) %do% {
step_logic <- sample(x = c(FALSE, TRUE), size = nrow(x = trn_x), replace = TRUE, prob = c(0.5, 0.5))
tf_session$run(
fetches = train_step,
feed_dict = tensorflow::dict(x = trn_x[step_logic, ], y_ = trn_y[step_logic, , drop = FALSE])
)
}
# 当てはめ結果
tf_fit_accuracy <- accuracy$eval(feed_dict = tensorflow::dict(x = trn_x, y_ = trn_y), session = tf_session)
fit_confusion_mat <- table(
predict = tf_session$run(fetches = y, feed_dict = tensorflow::dict(x = trn_x, y_ = trn_y)) %>%
apply(MARGIN = 1, FUN = which.max),
true = trn_y %>%
apply(MARGIN = 1, FUN = which.max)
)
# 評価データ作成
tst_d <- iris %>%
dplyr::filter(cv_number == cv_counter)
tst_x <- tst_d %>%
dplyr::select(-Species) %>%
as.matrix()
tst_y <- tst_d %>%
dummyVars(formula = ~ Species, sep = NULL) %>%
predict(object = ., newdata = tst_d) %>%
as.matrix()
# 予測結果
tf_predict_accuracy <- accuracy$eval(feed_dict = tensorflow::dict(x = tst_x, y_ = tst_y), session = tf_session)
predict_confusion_mat <- table(
predict = tf_session$run(fetches = y, feed_dict = tensorflow::dict(x = tst_x, y_ = matrix(data = 0, nrow = nrow(x = tst_x), ncol = SET_DATA_PARAM$CLASS_NUM))) %>%
apply(MARGIN = 1, FUN = which.max),
true = tst_y %>%
apply(MARGIN = 1, FUN = which.max)
)
# 比較用
glmnet_mdl <- glmnet::glmnet(x = trn_x, y = trn_d$Species, family = "multinomial")
glmnet_confusion_mat <- table(
predict = predict(object = glmnet_mdl, newx = tst_x, type = "class", s = 0.01)[, 1],
true = tst_d$Species
)
return(
list(
tf_fit_accuracy = tf_fit_accuracy,
fit_confusion_mat = fit_confusion_mat,
tf_predict_accuracy = tf_predict_accuracy,
predict_confusion_mat = predict_confusion_mat,
glmnet_confusion_mat = glmnet_confusion_mat
)
)
}
)
[1] "CV:1 W:0 b:0"
[1] "CV:2 W:0 b:0"
[1] "CV:3 W:0 b:0"
[1] "CV:4 W:0 b:0"
[1] "CV:5 W:0 b:0"
学習前にパラメータの初期化もされています。
# 当てはめ結果を確認
> sapply(X = tf_result, FUN = "[[", "tf_fit_accuracy")
[1] 0.9487180 0.9920000 0.9663866 0.9842520 0.9732143
# TensorFlow上で評価用に算出したaccuracyとほぼ同じ結果
> sapply(
X = lapply(X = tf_result, FUN = "[[", "fit_confusion_mat"),
FUN = calcAccuracy
)
[1] 0.9487179 0.9920000 0.9663866 0.9842520 0.9732143
# 当てはめ結果の平均
> mean(x = sapply(X = tf_result, FUN = "[[", "tf_fit_accuracy"))
[1] 0.9729141
# 当てはめ結果のconfusion matrix
> lapply(X = tf_result, FUN = "[[", "fit_confusion_mat")
[[1]]
true
predict 1 2 3
1 40 0 0
2 0 31 0
3 0 6 40
[[2]]
true
predict 1 2 3
1 39 0 0
2 0 41 0
3 0 1 44
[[3]]
true
predict 1 2 3
1 42 0 0
2 0 33 0
3 0 4 40
[[4]]
true
predict 1 2 3
1 38 0 0
2 0 43 1
3 0 1 44
[[5]]
true
predict 1 2 3
1 41 0 0
2 0 37 0
3 0 3 31
# 予測結果を確認
> sapply(X = tf_result, FUN = "[[", "tf_predict_accuracy")
[1] 0.9696970 0.9200000 0.9677419 0.9565217 0.9736842
> sapply(
X = lapply(X = tf_result, FUN = "[[", "predict_confusion_mat"),
FUN = calcAccuracy
)
[1] 0.9696970 0.9200000 0.9677419 0.9565217 0.9736842
# 予測結果の平均
> mean(x = sapply(X = tf_result, FUN = "[[", "tf_predict_accuracy"))
[1] 0.957529
# 予測結果のconfusion matrix
> lapply(X = tf_result, FUN = "[[", "predict_confusion_mat")
[[1]]
true
predict 1 2 3
1 10 0 0
2 0 12 0
3 0 1 10
[[2]]
true
predict 1 2 3
1 11 0 0
2 0 6 0
3 0 2 6
[[3]]
true
predict 1 2 3
1 8 0 0
2 0 12 0
3 0 1 10
[[4]]
true
predict 1 2 3
1 12 0 0
2 0 6 1
3 0 0 4
[[5]]
true
predict 1 2 3
1 9 0 0
2 0 9 0
3 0 1 19
# {glmnet}の予測結果
> glmnet_ev <- sapply(
X = lapply(X = tf_result, FUN = "[[", "glmnet_confusion_mat"),
FUN = calcAccuracy
) %>%
print
[1] 1.0000000 0.9200000 0.9677419 0.9130435 0.9736842
> mean(x = glmnet_ev)
[1] 0.9548939
glmnetと差がほとんどなかったです。
まとめ
{tensorflow}を用いて、分析業界の"Hello World"であるirisデータの分類を試しました。以前は{PythonInR}を使ってTensorFlowを呼び出しましたが、RStudio社が公開するパッケージということでこちらの方が安心感があります。今回はとりあえず動かしてみただけですので、モデルの改良などはもう少しお勉強してからにします。
また、現在のところPreview版ですが、RStudioを1系にすると{tensorflow}のオブジェクトがサジェストされるので、とてもとてもオススメです。
参考
- とりいそぎ{PythonInR}でRからTensorFlowを動かしてみた
- AWSのGPUインスタンス上のDockerコンテナでTensorFlow for Rを試す
- glmnetで多クラスのロジスティック回帰
実行環境
> devtools::session_info()
Session info ----------------------------------------------------------------------------------------
setting value
version R version 3.3.1 (2016-06-21)
system x86_64, darwin13.4.0
ui RStudio (1.0.44)
language (EN)
collate ja_JP.UTF-8
tz Asia/Tokyo
date 2016-10-27
Packages --------------------------------------------------------------------------------------------
package * version date source
assertthat 0.1 2013-12-06 CRAN (R 3.3.1)
broom 0.4.1 2016-06-24 CRAN (R 3.3.0)
car 2.1-2 2016-03-25 CRAN (R 3.3.0)
caret * 6.0-70 2016-06-13 CRAN (R 3.3.0)
codetools 0.2-14 2015-07-15 CRAN (R 3.3.1)
colorspace 1.2-6 2015-03-11 CRAN (R 3.3.1)
DBI 0.5 2016-08-11 cran (@0.5)
devtools 1.12.0 2016-06-24 CRAN (R 3.3.0)
digest 0.6.9 2016-01-08 CRAN (R 3.3.0)
dplyr * 0.5.0 2016-06-24 CRAN (R 3.3.1)
foreach * 1.4.3 2015-10-13 CRAN (R 3.3.1)
ggplot2 * 2.1.0 2016-03-01 CRAN (R 3.3.1)
glmnet 2.0-5 2016-03-17 CRAN (R 3.3.0)
gtable 0.2.0 2016-02-26 CRAN (R 3.3.1)
iterators 1.0.8 2015-10-13 CRAN (R 3.3.1)
janeaustenr 0.1.1 2016-06-20 CRAN (R 3.3.0)
lattice * 0.20-33 2015-07-14 CRAN (R 3.3.1)
lme4 1.1-12 2016-04-16 CRAN (R 3.3.0)
magrittr 1.5 2014-11-22 CRAN (R 3.3.1)
MASS 7.3-45 2016-04-21 CRAN (R 3.3.1)
Matrix 1.2-6 2016-05-02 CRAN (R 3.3.1)
MatrixModels 0.4-1 2015-08-22 CRAN (R 3.3.1)
memoise 1.0.0 2016-01-29 CRAN (R 3.3.0)
mgcv 1.8-12 2016-03-03 CRAN (R 3.3.1)
minqa 1.2.4 2014-10-09 CRAN (R 3.3.0)
mnormt 1.5-4 2016-03-09 CRAN (R 3.3.0)
munsell 0.4.3 2016-02-13 CRAN (R 3.3.1)
nlme 3.1-128 2016-05-10 CRAN (R 3.3.1)
nloptr 1.0.4 2014-08-04 CRAN (R 3.3.1)
nnet 7.3-12 2016-02-02 CRAN (R 3.3.1)
pbkrtest 0.4-6 2016-01-27 CRAN (R 3.3.0)
plyr 1.8.4 2016-06-08 CRAN (R 3.3.1)
psych 1.6.6 2016-06-28 CRAN (R 3.3.0)
quantreg 5.26 2016-06-07 CRAN (R 3.3.0)
R6 2.1.3 2016-08-19 cran (@2.1.3)
Rcpp 0.12.7 2016-09-05 cran (@0.12.7)
readr 1.0.0 2016-08-03 cran (@1.0.0)
reshape2 1.4.1 2014-12-06 CRAN (R 3.3.1)
scales 0.4.0 2016-02-26 CRAN (R 3.3.1)
SnowballC 0.5.1 2014-08-09 CRAN (R 3.3.1)
SparseM 1.7 2015-08-15 CRAN (R 3.3.0)
stringi 1.1.1 2016-05-27 CRAN (R 3.3.1)
stringr 1.1.0 2016-08-19 cran (@1.1.0)
tensorflow * 0.3.0 2016-10-25 Github (rstudio/tensorflow@dfe2f1a)
tibble 1.2 2016-08-26 cran (@1.2)
tidyr 0.6.0 2016-08-12 cran (@0.6.0)
tidytext 0.1.1 2016-06-25 CRAN (R 3.3.0)
tokenizers 0.1.4 2016-08-29 CRAN (R 3.3.0)
withr 1.0.2 2016-06-20 CRAN (R 3.3.0)