
Trying out {tensorflow} with the iris data


Overview

A little while ago, RStudio released {tensorflow}, a package for using TensorFlow from R, so I tried classifying everyone's favorite iris data, using the MNIST example as a reference.

Preparation

Depending on your environment, you may need to uninstall a previously installed version and reinstall it (I also reinstalled protobuf).

NOTE: If you are upgrading from a previous installation of TensorFlow < 0.7.1, you should uninstall the previous TensorFlow and protobuf using pip uninstall first to make sure you get a clean installation of the updated protobuf dependency.
 https://www.tensorflow.org/versions/r0.11/get_started/os_setup.html

Installing the required libraries
# Reinstall the protobuf that was removed
$ pip install protobuf

# For CPU-only use on OS X
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.11.0rc1-py2-none-any.whl
$ sudo pip install --upgrade $TF_BINARY_URL

# Check the Python path so it can be passed to {tensorflow} at install time
$ which python
/usr/local/bin/python
$ export PYTHON_PATH=`which python`
R setup
# Set the path of the Python that {tensorflow} will use as an environment variable, then install
Sys.setenv(TENSORFLOW_PYTHON = "/usr/local/bin/python")
devtools::install_github("rstudio/tensorflow")

After installation, load the package; if no error occurs, you are all set (if an error does occur, try reinstalling the libraries it depends on or double-checking the Python path above). Note that if you load the package again after an error, no error message is shown, but the package still cannot be used.
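As a quick sanity check, here is a minimal sketch along the lines of the package's "Hello, TensorFlow" example (adjust to your own environment); it should run without errors:

library(tensorflow)

# Build and evaluate a trivial graph; if the string is returned, the setup works
sess <- tf$Session()
hello <- tf$constant("Hello, TensorFlow!")
sess$run(hello)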

Definitions and settings

Load the libraries used in the processing and define the constants and functions. In addition, the TensorFlow model is defined here as well.

Constant definitions
library(tensorflow)
library(dplyr)
library(foreach)
library(caret)

SET_CV_NUM <- 5
SET_DATA_PARAM <- list(
  CLASS_NUM = 3L, FEATURE_NUM = 4L
)
SET_STEP_NUM <- 3000
Function definitions
# Compute accuracy from a confusion matrix
calcAccuracy <- function(confusion_mat) {
  return(sum(diag(x = confusion_mat)) / sum(confusion_mat))
}
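For example, applied to a hypothetical 2 x 2 confusion matrix, it returns the share of observations on the diagonal:

# Hypothetical confusion matrix: 18 + 15 correct predictions out of 40 observations
m <- matrix(data = c(18, 3, 4, 15), nrow = 2)
calcAccuracy(confusion_mat = m)  # (18 + 15) / 40 = 0.825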
Model definitions
# Based on the MNIST For ML Beginners example
W <- tensorflow::tf$Variable(
  initial_value = tensorflow::tf$zeros(shape = tensorflow::shape(SET_DATA_PARAM$FEATURE_NUM, SET_DATA_PARAM$CLASS_NUM))
)
x <- tensorflow::tf$placeholder(
  dtype = tensorflow::tf$float32,
  shape = tensorflow::shape(NULL, SET_DATA_PARAM$FEATURE_NUM)
)
b <- tensorflow::tf$Variable(
  initial_value = tensorflow::tf$zeros(shape = tensorflow::shape(SET_DATA_PARAM$CLASS_NUM))
)

y <- tensorflow::tf$nn$softmax(logits = tensorflow::tf$matmul(a = x, b = W) + b)
y_ <- tensorflow::tf$placeholder(dtype = tensorflow::tf$float32, shape = tensorflow::shape(NULL, SET_DATA_PARAM$CLASS_NUM))
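# Shapes: x holds the input features (N x 4), W the weights (4 x 3), b the bias (length 3);
# y is the predicted class-probability matrix and y_ the one-hot true labels (N x 3)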


# Loss function and optimizer settings
cross_entropy <- tensorflow::tf$reduce_mean(input_tensor = - tensorflow::tf$reduce_sum(input_tensor = y_ * tensorflow::tf$log(x = y), reduction_indices = 1L))
optimizer <- tensorflow::tf$train$GradientDescentOptimizer(learning_rate = 0.5)
train_step <- optimizer$minimize(loss = cross_entropy)


# For evaluation
correct_prediction <- tensorflow::tf$equal(
  x = tensorflow::tf$argmax(input = y, dimension = 1L), y = tensorflow::tf$argmax(input = y_, dimension = 1L)
)
accuracy <- tensorflow::tf$reduce_mean(
  input_tensor = tensorflow::tf$cast(x = correct_prediction, dtype = tensorflow::tf$float32)
)
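For reference, the graph defined above is plain multinomial logistic (softmax) regression, and the loss being minimized is the mean cross-entropy between the one-hot labels (written y' here, y_ in the code) and the predicted probabilities y:

$$
y = \operatorname{softmax}(xW + b), \qquad
\text{cross\_entropy} = -\frac{1}{N}\sum_{n=1}^{N}\sum_{i=1}^{3} y'_{n,i}\,\log y_{n,i}
$$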

Execution

Training and testing
# The iris data
data(iris)
> iris %>% 
  head(n = 10)
   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1           5.1         3.5          1.4         0.2  setosa
2           4.9         3.0          1.4         0.2  setosa
3           4.7         3.2          1.3         0.2  setosa
4           4.6         3.1          1.5         0.2  setosa
5           5.0         3.6          1.4         0.2  setosa
6           5.4         3.9          1.7         0.4  setosa
7           4.6         3.4          1.4         0.3  setosa
8           5.0         3.4          1.5         0.2  setosa
9           4.4         2.9          1.4         0.2  setosa
10          4.9         3.1          1.5         0.1  setosa

> iris %>% 
   dplyr::group_by(Species) %>% 
   dplyr::summarise_all(.fun = mean)
# A tibble: 3 × 5
     Species Sepal.Length Sepal.Width Petal.Length Petal.Width
      <fctr>        <dbl>       <dbl>        <dbl>       <dbl>
1     setosa        5.006       3.428        1.462       0.246
2 versicolor        5.936       2.770        4.260       1.326
3  virginica        6.588       2.974        5.552       2.026


set.seed(seed = 71)
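# Randomly assign each row of iris to one of SET_CV_NUM (= 5) folds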
cv_number <- sample(x = seq(from = 1, to = SET_CV_NUM), size = nrow(x = iris), replace = TRUE)


# N-fold cross-validation
> tf_result <- lapply(
  X = seq(from = 1, to = SET_CV_NUM),
  FUN = function (cv_counter) {

    # Build the training data
    trn_d <- iris %>% 
      dplyr::filter(cv_number != cv_counter)
    trn_x <- trn_d %>% 
      dplyr::select(-Species) %>% 
      as.matrix()
    trn_y <- trn_d %>% 
      caret::dummyVars(formula = ~ Species, sep = NULL) %>% 
      predict(object = ., newdata = trn_d) %>% 
      as.matrix()
    

    # Initialization
    tf_session <- tensorflow::tf$Session()
    tf_session$run(fetches = tensorflow::tf$initialize_all_variables())

    # Check that the parameters have been initialized
    print(
      stringr::str_c(
        stringr::str_c("CV:", cv_counter),
        stringr::str_c("W:", sum(tf_session$run(W))),
        stringr::str_c("b:", sum(tf_session$run(b))),
        sep = " "
      )
    )

    # Training
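    # Each step feeds a random subsample of roughly half the training rows as a minibatch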
    foreach::times(n = SET_STEP_NUM) %do% {
      step_logic <- sample(x = c(FALSE, TRUE), size = nrow(x = trn_x), replace = TRUE, prob = c(0.5, 0.5))
      tf_session$run(
        fetches = train_step,
        feed_dict = tensorflow::dict(x = trn_x[step_logic, ], y_ = trn_y[step_logic, , drop = FALSE])
      )
    }

    # Fitted results on the training data
    tf_fit_accuracy <- accuracy$eval(feed_dict = tensorflow::dict(x = trn_x, y_ = trn_y), session = tf_session)
    fit_confusion_mat <- table(
      predict = tf_session$run(fetches = y, feed_dict = tensorflow::dict(x = trn_x, y_ = trn_y)) %>%
        apply(MARGIN = 1, FUN = which.max),
      true = trn_y %>% 
        apply(MARGIN = 1, FUN = which.max)
    )

        
    # Build the evaluation (test) data
    tst_d <- iris %>% 
      dplyr::filter(cv_number == cv_counter)
    tst_x <- tst_d %>% 
      dplyr::select(-Species) %>% 
      as.matrix()
    tst_y <- tst_d %>% 
      caret::dummyVars(formula = ~ Species, sep = NULL) %>% 
      predict(object = ., newdata = tst_d) %>% 
      as.matrix()
    
    # Prediction results
    tf_predict_accuracy <- accuracy$eval(feed_dict = tensorflow::dict(x = tst_x, y_ = tst_y), session = tf_session)
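    # Note: y_ is fed an all-zero dummy matrix below; the true labels are not needed to compute the predictions y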
    predict_confusion_mat <- table(
      predict = tf_session$run(fetches = y, feed_dict = tensorflow::dict(x = tst_x, y_ = matrix(data = 0, nrow = nrow(x = tst_x), ncol = SET_DATA_PARAM$CLASS_NUM))) %>%
        apply(MARGIN = 1, FUN = which.max),
      true = tst_y %>% 
        apply(MARGIN = 1, FUN = which.max)
    )


    # For comparison
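    # Fit a multinomial {glmnet} model on the training folds and predict classes on the held-out fold at a small penalty (s = 0.01)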
    glmnet_mdl <- glmnet::glmnet(x = trn_x, y = trn_d$Species, family = "multinomial")
    glmnet_confusion_mat <- table(
      predict = predict(object = glmnet_mdl, newx = tst_x, type = "class", s = 0.01)[, 1],
      true = tst_d$Species
    )
    
    return(
      list(
        tf_fit_accuracy = tf_fit_accuracy,
        fit_confusion_mat = fit_confusion_mat,
        tf_predict_accuracy = tf_predict_accuracy,
        predict_confusion_mat = predict_confusion_mat,
        glmnet_confusion_mat = glmnet_confusion_mat
      )
    )
  }
)
[1] "CV:1 W:0 b:0"
[1] "CV:2 W:0 b:0"
[1] "CV:3 W:0 b:0"
[1] "CV:4 W:0 b:0"
[1] "CV:5 W:0 b:0"

The parameters are indeed initialized before training (W and b sum to 0 in every fold).

Evaluation results
# Check the fitted results
> sapply(X = tf_result, FUN = "[[", "tf_fit_accuracy")
[1] 0.9487180 0.9920000 0.9663866 0.9842520 0.9732143
# Almost the same as the accuracy computed for evaluation on the TensorFlow side
> sapply(
  X = lapply(X = tf_result, FUN = "[[", "fit_confusion_mat"),
  FUN = calcAccuracy
)
[1] 0.9487179 0.9920000 0.9663866 0.9842520 0.9732143

# Mean of the fitted accuracies
> mean(x = sapply(X = tf_result, FUN = "[[", "tf_fit_accuracy"))
[1] 0.9729141

# Confusion matrices of the fitted results
> lapply(X = tf_result, FUN = "[[", "fit_confusion_mat")
[[1]]
       true
predict  1  2  3
      1 40  0  0
      2  0 31  0
      3  0  6 40

[[2]]
       true
predict  1  2  3
      1 39  0  0
      2  0 41  0
      3  0  1 44

[[3]]
       true
predict  1  2  3
      1 42  0  0
      2  0 33  0
      3  0  4 40

[[4]]
       true
predict  1  2  3
      1 38  0  0
      2  0 43  1
      3  0  1 44

[[5]]
       true
predict  1  2  3
      1 41  0  0
      2  0 37  0
      3  0  3 31


# Check the prediction results
> sapply(X = tf_result, FUN = "[[", "tf_predict_accuracy")
[1] 0.9696970 0.9200000 0.9677419 0.9565217 0.9736842
> sapply(
  X = lapply(X = tf_result, FUN = "[[", "predict_confusion_mat"),
  FUN = calcAccuracy
)
[1] 0.9696970 0.9200000 0.9677419 0.9565217 0.9736842

# Mean of the prediction accuracies
> mean(x = sapply(X = tf_result, FUN = "[[", "tf_predict_accuracy"))
[1] 0.957529

# Confusion matrices of the predictions
> lapply(X = tf_result, FUN = "[[", "predict_confusion_mat")
[[1]]
       true
predict  1  2  3
      1 10  0  0
      2  0 12  0
      3  0  1 10

[[2]]
       true
predict  1  2  3
      1 11  0  0
      2  0  6  0
      3  0  2  6

[[3]]
       true
predict  1  2  3
      1  8  0  0
      2  0 12  0
      3  0  1 10

[[4]]
       true
predict  1  2  3
      1 12  0  0
      2  0  6  1
      3  0  0  4

[[5]]
       true
predict  1  2  3
      1  9  0  0
      2  0  9  0
      3  0  1 19

# Prediction results with {glmnet}
> glmnet_ev <- sapply(
  X = lapply(X = tf_result, FUN = "[[", "glmnet_confusion_mat"),
  FUN = calcAccuracy
) %>% 
  print
[1] 1.0000000 0.9200000 0.9677419 0.9130435 0.9736842
> mean(x = glmnet_ev)
[1] 0.9548939

The difference from glmnet was negligible (mean prediction accuracy of about 0.958 for {tensorflow} vs. about 0.955 for {glmnet}).

Summary

Using {tensorflow}, I tried classifying the iris data, the "Hello World" of the analytics world. I had previously called TensorFlow via {PythonInR}, but since this package is published by RStudio, it feels more reassuring. This time I only got things running, so improving the model will have to wait until I have studied a bit more.
Also, although it is still a preview release at the moment, upgrading to the RStudio 1.x series gives you autocomplete suggestions for {tensorflow} objects, which I highly recommend.


Execution environment

Session info
> devtools::session_info()
Session info ----------------------------------------------------------------------------------------
 setting  value                       
 version  R version 3.3.1 (2016-06-21)
 system   x86_64, darwin13.4.0        
 ui       RStudio (1.0.44)            
 language (EN)                        
 collate  ja_JP.UTF-8                 
 tz       Asia/Tokyo                  
 date     2016-10-27                  

Packages --------------------------------------------------------------------------------------------
 package      * version date       source                             
 assertthat     0.1     2013-12-06 CRAN (R 3.3.1)                     
 broom          0.4.1   2016-06-24 CRAN (R 3.3.0)                     
 car            2.1-2   2016-03-25 CRAN (R 3.3.0)                     
 caret        * 6.0-70  2016-06-13 CRAN (R 3.3.0)                     
 codetools      0.2-14  2015-07-15 CRAN (R 3.3.1)                     
 colorspace     1.2-6   2015-03-11 CRAN (R 3.3.1)                     
 DBI            0.5     2016-08-11 cran (@0.5)                        
 devtools       1.12.0  2016-06-24 CRAN (R 3.3.0)                     
 digest         0.6.9   2016-01-08 CRAN (R 3.3.0)                     
 dplyr        * 0.5.0   2016-06-24 CRAN (R 3.3.1)                     
 foreach      * 1.4.3   2015-10-13 CRAN (R 3.3.1)                     
 ggplot2      * 2.1.0   2016-03-01 CRAN (R 3.3.1)                     
 glmnet         2.0-5   2016-03-17 CRAN (R 3.3.0)                     
 gtable         0.2.0   2016-02-26 CRAN (R 3.3.1)                     
 iterators      1.0.8   2015-10-13 CRAN (R 3.3.1)                     
 janeaustenr    0.1.1   2016-06-20 CRAN (R 3.3.0)                     
 lattice      * 0.20-33 2015-07-14 CRAN (R 3.3.1)                     
 lme4           1.1-12  2016-04-16 CRAN (R 3.3.0)                     
 magrittr       1.5     2014-11-22 CRAN (R 3.3.1)                     
 MASS           7.3-45  2016-04-21 CRAN (R 3.3.1)                     
 Matrix         1.2-6   2016-05-02 CRAN (R 3.3.1)                     
 MatrixModels   0.4-1   2015-08-22 CRAN (R 3.3.1)                     
 memoise        1.0.0   2016-01-29 CRAN (R 3.3.0)                     
 mgcv           1.8-12  2016-03-03 CRAN (R 3.3.1)                     
 minqa          1.2.4   2014-10-09 CRAN (R 3.3.0)                     
 mnormt         1.5-4   2016-03-09 CRAN (R 3.3.0)                     
 munsell        0.4.3   2016-02-13 CRAN (R 3.3.1)                     
 nlme           3.1-128 2016-05-10 CRAN (R 3.3.1)                     
 nloptr         1.0.4   2014-08-04 CRAN (R 3.3.1)                     
 nnet           7.3-12  2016-02-02 CRAN (R 3.3.1)                     
 pbkrtest       0.4-6   2016-01-27 CRAN (R 3.3.0)                     
 plyr           1.8.4   2016-06-08 CRAN (R 3.3.1)                     
 psych          1.6.6   2016-06-28 CRAN (R 3.3.0)                     
 quantreg       5.26    2016-06-07 CRAN (R 3.3.0)                     
 R6             2.1.3   2016-08-19 cran (@2.1.3)                      
 Rcpp           0.12.7  2016-09-05 cran (@0.12.7)                     
 readr          1.0.0   2016-08-03 cran (@1.0.0)                      
 reshape2       1.4.1   2014-12-06 CRAN (R 3.3.1)                     
 scales         0.4.0   2016-02-26 CRAN (R 3.3.1)                     
 SnowballC      0.5.1   2014-08-09 CRAN (R 3.3.1)                     
 SparseM        1.7     2015-08-15 CRAN (R 3.3.0)                     
 stringi        1.1.1   2016-05-27 CRAN (R 3.3.1)                     
 stringr        1.1.0   2016-08-19 cran (@1.1.0)                      
 tensorflow   * 0.3.0   2016-10-25 Github (rstudio/tensorflow@dfe2f1a)
 tibble         1.2     2016-08-26 cran (@1.2)                        
 tidyr          0.6.0   2016-08-12 cran (@0.6.0)                      
 tidytext       0.1.1   2016-06-25 CRAN (R 3.3.0)                     
 tokenizers     0.1.4   2016-08-29 CRAN (R 3.3.0)                     
 withr          1.0.2   2016-06-20 CRAN (R 3.3.0)      