LoginSignup
6
4

More than 5 years have passed since last update.

R Studio で Neural Network を試してみた

Last updated at Posted at 2017-06-21

irisデータ・セットを用いてNeural NetworkをR Studioで試してみました。

手順1: R StudioにNeural Networkを入れます。

install.packages('nnet’)  #既にインストール済みなら必要ありません。
library("nnet")

手順2: irisのデータ取り込みを行います。

data(iris)

irisのデータ・セットは以下のようになっています。

    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
1            5.1         3.5          1.4         0.2     setosa
2            4.9         3.0          1.4         0.2     setosa
3            4.7         3.2          1.3         0.2     setosa
...
50           5.0         3.3          1.4         0.2     setosa
51           7.0         3.2          4.7         1.4 versicolor
52           6.4         3.2          4.5         1.5 versicolor
...
100          5.7         2.8          4.1         1.3 versicolor
101          6.3         3.3          6.0         2.5  virginica
102          5.8         2.7          5.1         1.9  virginica
...
149          6.2         3.4          5.4         2.3  virginica
150          5.9         3.0          5.1         1.8  virginica

手順3: irisのデータ・セットはSpecisが3種類合計150のデータからなるため、

sを50回,cを50回,vを50回繰り返したクラスを作成します。
s:setosa,c:versicolor,v:virginicaにそれぞれに対応するとします。

targets <- class.ind( c(rep("s", 50), rep("c", 50), rep("v", 50)) )

手順4: irisデータに対応するindexを各種25個ずつランダムで指定します。

index <- c(sample(1:50, 25), sample(51:100, 25), sample(101:150, 25) )

手順5: 訓練データを作成します。

上記ランダムで取得したindexに対応するirisデータと、それに対応するclass(ここではtargets)を指定します。各引数の説明は以下です。

neural_network <- nnet(iris[index, 1:4], targets[index, ], size = 2, rang = 0.1, decay = 5e-4, maxit = 200)
引数 意味 / 内容
size hidden layer(隠れ層)のユニット数です。 skip層ユニットがある場合はゼロにすることができます。
rang 双曲線正接で考えた際の範囲を指定します。rang * max(x) が約1になるように選択する必要があります。
decay 重荷衰退です。1より小さい値をかける手法で、パラメータ値が大きくなりすぎることを防ぎます。
maxit 反復回数です。デフォルトで100が指定されています。

スクリーンショット 2017-06-27 11.53.01.png

手順6: 先に求めたニュラールネットワークの各重みを確認します。

summary(neural_network)

スクリーンショット 2017-06-27 11.53.10.png

手順7: ニュラールネットワーク図の出力方法は以下です

library("neuralnet")
network = neuralnet(c + s + v~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, cbind(iris[index, 1:4],targets[index, ]), hidden=2, err.fct = "ce", linear.output = FALSE)
plot(network)

スクリーンショット 2017-06-27 11.52.15.png

手順8: irisの訓練データに用いなかった残りをテストデータとしてクラス分け推論を実行します。

predict(neural_network, iris[-index, 1:4])

結果は以下のようになりました。

                 c              s              v
2   0.014587638374 0.986320887172 0.003540863552
3   0.013842952120 0.986980663902 0.003533837310
5   0.013429685128 0.987347646405 0.003529781843
6   0.013616361660 0.987181801117 0.003531628184
8   0.013884505592 0.986943797936 0.003534238708
12  0.014346863668 0.986534001705 0.003538629740
14  0.013740289224 0.987071771890 0.003532840688
18  0.013571846781 0.987221337141 0.003531190096
20  0.013419526514 0.987356675049 0.003529680671
22  0.013587764412 0.987207199030 0.003531346904
23  0.013159742434 0.987587691545 0.003527068426
25  0.016579416783 0.984565001799 0.003558108052
26  0.015879794665 0.985180380128 0.003552285738
27  0.014593021505 0.986316124670 0.003540913094
29  0.013584621063 0.987209990894 0.003531315952
31  0.015578427580 0.985445909724 0.003549702610
34  0.013103380168 0.987637845510 0.003526495223
35  0.014764288745 0.986164653128 0.003542480337
36  0.013529539848 0.987258918791 0.003530772472
40  0.013839363845 0.986983847681 0.003533802594
41  0.013443277039 0.987335566949 0.003529917094
42  0.021045497985 0.980667321842 0.003590574262
43  0.013986725395 0.986853134984 0.003535221303
45  0.014670369348 0.986247704984 0.003541623040
48  0.014216190164 0.986649745412 0.003537402486
51  0.993054699489 0.012442891285 0.005005659338
52  0.992764920208 0.012488568564 0.005197038515
53  0.992849966433 0.012244466899 0.005250278005
57  0.990796780576 0.012261989432 0.006779637396
59  0.993073244123 0.012347218818 0.005035654082
61  0.992741363303 0.012747570085 0.005095036458
62  0.991983743715 0.012563277795 0.005730808274
63  0.993042724394 0.012475355063 0.004999749339
64  0.991121245906 0.012177124507 0.006587350042
66  0.992917740218 0.012666904191 0.005004588386
67  0.972849055502 0.011926431592 0.021105079646
71  0.088920162560 0.010389535001 0.910885730976
74  0.992893078572 0.012244047203 0.005218228472
76  0.992961366164 0.012546725370 0.005026903832
78  0.964440153494 0.011746952365 0.028289671297
79  0.989107518919 0.012175283237 0.008114982699
80  0.991325503133 0.015369427691 0.004947472114
81  0.992814563211 0.012768591081 0.005033322456
83  0.992694059108 0.013022632390 0.005006906140
84  0.044353090716 0.010170181604 0.957076805767
87  0.992847375999 0.012317403598 0.005217192664
88  0.992625380038 0.012220120065 0.005430509545
95  0.992158422545 0.012384759278 0.005694616803
97  0.992672903218 0.012568309358 0.005226776347
100 0.992679970346 0.012616141565 0.005199334577
101 0.006509507053 0.009726265031 0.994209228320
102 0.006786440272 0.009741294695 0.993947955604
104 0.007749219997 0.009766755453 0.993052832838
105 0.006526682512 0.009727391648 0.994192924555
109 0.007112374300 0.009745678722 0.993648835259
113 0.007168842460 0.009751206635 0.993593440776
114 0.006557517030 0.009732095287 0.994161836891
115 0.006508396787 0.009729670696 0.994207956651
116 0.006577889798 0.009735961590 0.994140766277
117 0.013939199201 0.009899418860 0.987179749542
119 0.006510884130 0.009725490271 0.994208499842
120 0.037666104787 0.010126470437 0.963827334610
125 0.007001487859 0.009745823642 0.993749483085
128 0.049851102614 0.010218958098 0.951402597165
129 0.006554494240 0.009728780831 0.994166807172
130 0.913381205096 0.011466140312 0.072524288959
131 0.008753826131 0.009791114951 0.992113229372
132 0.152394413644 0.010472473335 0.844782748178
134 0.929789392381 0.011540816939 0.058000278580
137 0.006515457161 0.009729732570 0.994201521289
143 0.006786440272 0.009741294695 0.993947955604
144 0.006531998041 0.009727760950 0.994187863990
145 0.006513985809 0.009728212291 0.994203871006
146 0.006692045729 0.009740138934 0.994034402766
150 0.010454134773 0.009844735492 0.990493554835

手順9: 出力結果を表で確認します。

table(max.col(targets[-index, ]), max.col(predict(neural_network, iris[-index, 1:4])))

スクリーンショット 2017-06-27 11.53.32.png

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4