0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

久しぶりにIRIS分類でニューラルネットワークを学び直す

Last updated at Posted at 2023-10-07

こんにちは。
株式会社クラスアクト インフラストラクチャ事業部の大塚です。

導入?

最近、ChatGPTが流行っていますね。
これ本当に凄いですよね。簡単なプログラミングコードなら一瞬で出してくれますし、お目当ての回答をバシッと提示してくれる(合っているかは判断しないといけないとは思いますが)ので、とりあえずChatGPTに聞いてみて、想定通りの答えならそのまま。想定と違う答えが返ってきたらさらにやり取りをしてみたり、そこでググってみたりするような開発ライフを過ごしています。自分で調べるとお目当ての事を記載してくださっている記事に到達する前に時間がかかったり、そもそも面倒みたいなことがあるので非常に助かっております。

このChatGPTというのは人工知能・DeepLearningにより成立している、具体的にはLLM(Large Language Model)というやつで出来ているみたいです。個人的にはどういう仕組みであれが出来ているのかかなり気になる!
そんなこんなで勉強をぼちぼち開始しているのですが、DLの超基本的な事を理解しておかないといけないなと思い、3,4年ぶりにニューラルネットワークを構築するコードについて、自分の為にまとめました。

参考にしているもの

コードの大部分は以下を参考にしてます。ライブラリ(?)が古くなっているようなものは気付いた範囲で新しくしています。

用意したコード

以下にまとめています。

メモ

IRISデータセットの構造

sklearnにライブラリ(?)としてあるIRISのデータセットを今回は読み込み、使用しているのですが、まずはこのデータセットの中身を見てみたいと思います。iris変数にロードしているので、それをpprintで出力してみます。

import numpy as np
import pprint
from sklearn import datasets

iris = datasets.load_iris()
pprint.pprint(iris)

出力結果は以下となりました。注目すべきはdataとtargetになります。

  • dataに格納されいている情報
    IRISのがくの長さと幅、花弁の長さと幅を示す。(feature_namesで記載されている)
  • targetに格納されている情報
    IRISの種類。0から順番にsetosa,versicolor,virginicaという花の種類を示している。
{'DESCR': '.. _iris_dataset:\n'
          '\n'
          'Iris plants dataset\n'
          '--------------------\n'
          '\n'
          '**Data Set Characteristics:**\n'
          '\n'
          '    :Number of Instances: 150 (50 in each of three classes)\n'
          '    :Number of Attributes: 4 numeric, predictive attributes and the '
          'class\n'
          '    :Attribute Information:\n'
          '        - sepal length in cm\n'
          '        - sepal width in cm\n'
          '        - petal length in cm\n'
          '        - petal width in cm\n'
          '        - class:\n'
          '                - Iris-Setosa\n'
          '                - Iris-Versicolour\n'
          '                - Iris-Virginica\n'
          '                \n'
          '    :Summary Statistics:\n'
          '\n'
          '    ============== ==== ==== ======= ===== ====================\n'
          '                    Min  Max   Mean    SD   Class Correlation\n'
          '    ============== ==== ==== ======= ===== ====================\n'
          '    sepal length:   4.3  7.9   5.84   0.83    0.7826\n'
          '    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n'
          '    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n'
          '    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)\n'
          '    ============== ==== ==== ======= ===== ====================\n'
          '\n'
          '    :Missing Attribute Values: None\n'
          '    :Class Distribution: 33.3% for each of 3 classes.\n'
          '    :Creator: R.A. Fisher\n'
          '    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n'
          '    :Date: July, 1988\n'
          '\n'
          'The famous Iris database, first used by Sir R.A. Fisher. The '
          'dataset is taken\n'
          "from Fisher's paper. Note that it's the same as in R, but not as in "
          'the UCI\n'
          'Machine Learning Repository, which has two wrong data points.\n'
          '\n'
          'This is perhaps the best known database to be found in the\n'
          "pattern recognition literature.  Fisher's paper is a classic in the "
          'field and\n'
          'is referenced frequently to this day.  (See Duda & Hart, for '
          'example.)  The\n'
          'data set contains 3 classes of 50 instances each, where each class '
          'refers to a\n'
          'type of iris plant.  One class is linearly separable from the other '
          '2; the\n'
          'latter are NOT linearly separable from each other.\n'
          '\n'
          '.. topic:: References\n'
          '\n'
          '   - Fisher, R.A. "The use of multiple measurements in taxonomic '
          'problems"\n'
          '     Annual Eugenics, 7, Part II, 179-188 (1936); also in '
          '"Contributions to\n'
          '     Mathematical Statistics" (John Wiley, NY, 1950).\n'
          '   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and '
          'Scene Analysis.\n'
          '     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page '
          '218.\n'
          '   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New '
          'System\n'
          '     Structure and Classification Rule for Recognition in Partially '
          'Exposed\n'
          '     Environments".  IEEE Transactions on Pattern Analysis and '
          'Machine\n'
          '     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n'
          '   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE '
          'Transactions\n'
          '     on Information Theory, May 1972, 431-433.\n'
          '   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s '
          'AUTOCLASS II\n'
          '     conceptual clustering system finds 3 classes in the data.\n'
          '   - Many, many more ...',
 'data': array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.6, 1.4, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]]),
 'data_module': 'sklearn.datasets.data',
 'feature_names': ['sepal length (cm)',
                   'sepal width (cm)',
                   'petal length (cm)',
                   'petal width (cm)'],
 'filename': 'iris.csv',
 'frame': None,
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10')}

これらをわかりやすくリスト化しているGitHubがありました。これをみることで使おうとしているデータセットについて、イメージが付きやすくなると思います。

前処理

from sklearn import preprocessing
import tensorflow as tf
from sklearn.model_selection import train_test_split

scaler = preprocessing.StandardScaler()
scaler.fit(iris.data)
x = scaler.transform(iris.data)
t = tf.keras.utils.to_categorical(iris.target, num_classes=3)

x_train, x_test, t_train, t_test = train_test_split(x, t, train_size=0.75)
  • scaler.fit
    標準化のための数値を算出する
    標準化:平均を0にし、標準偏差を1にする操作
  • scaler.transform
    scaler.fitで算出した数値を使ってデータ(今回はiris.data)を標準化する
  • scaler.fit
    この関数は、整数のクラスラベルをバイナリのカテゴリカルデータに変換するために使用。具体的には、クラスの数を表す整数を、それぞれのクラスに対応する位置に1を持つバイナリの配列に変換する。[0, 1, 2, 1]を[[1. 0. 0.][0. 1. 0.][0. 0. 1.][0. 1. 0.]]に変換する。
  • train_test_split(x, t, train_size=0.75)
    機械学習やデータ分析において、データセットをトレーニングデータとテストデータに分割するために使用される一般的な関数または手法。トレーニングデータが全体の大部分(通常70%から80%)を占め、テストデータが残りの部分を占めることが一般的。このように分割することにより、モデルの過学習を防ぎ、一般化性能を評価するためのデータを確保できる。今回はx_trainとt_trainがセット、x_testとt_testがセットとなり、それぞれトレーニングデータセットとテストデータとなる。

モデルの作成

from keras.models import Sequential
from keras.layers import Dense, Activation

model = Sequential()
model.add(Dense(32,input_dim=4))
model.add(Activation("relu"))
model.add(Dense(32))
model.add(Activation("relu"))
model.add(Dense(3))
model.add(Activation("softmax"))
model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])

print(model.summary())
  • model = Sequential()
    シーケンシャルモデルをmodelという変数に格納している。
    シーケンシャルモデル:シーケンシャルモデルは、層を一直線に積み重ねる簡単なモデルの構築を支援するために設計されている。典型的には、前から後にデータが一方向に流れるフィードフォワードニューラルネットワーク(Feedforward Neural Network)を表現するために使用される。

  • model.add(Dense(32,input_dim=4))
    model.add(Activation("relu"))

    • シーケンシャルモデルに層を追加する方法を示している。具体的には全結合層(Dense層)を32ユニット(≒層にあるニューロンの数)と4つの入力特徴量を持つものとして追加し、その後にReLU活性化関数を適用している。4つとなっているのはIRISの『がくの長さと幅、花弁の長さと幅』から。このコードは、ニューラルネットワークのモデルに2つの層を追加することを意味する。
    • このことを図に起こしたのが以下(と思っている。)。ニューロン内部の状態も併せて図示してみた。重み・バイアス(これらをパラメータと呼ぶ)については、学習の段階で計算されていく。これが機械学習・DeepLearning
      DL.drawio.png
  • model.add(Dense(32))
    model.add(Activation("relu"))

    • 中間層(隠れ層ともいうんですか?)を追加しているコード。このコードを図示したのが以下(と思っている。)
      DL-ページ2.drawio.png
  • model.add(Dense(3))
    model.add(Activation("softmax"))

    • ソフトマックス活性化関数を使用した出力層を追加することを意味する。出力層のニューロンを3つにしているのはIRISの分類で答えとなる種類が3つであるから。Softmax関数が出力時に使用されると各クラスに属する確率を出力する。最終的に、確率が最も高いクラスが予測結果として選択されるようになる。
      DL-ページ3.drawio (1).png

モデルの学習と学習の推移

import matplotlib.pyplot as plt

history = model.fit(x_train,t_train,epochs=30,batch_size=8)

hist_loss = history.history['loss']
hist_acc = history.history['accuracy']

plt.plot(np.arange(len(hist_loss)), hist_loss, label='loss')
plt.plot(np.arange(len(hist_acc)), hist_acc, label='acc')
plt.legend()
plt.show()
  • history = model.fit(x_train, t_train, epochs=30, batch_size=8)
    モデルをトレーニングするもの。x_train はトレーニングデータ、t_train はトレーニングデータのラベル(またはターゲット)。epochs パラメータはトレーニングエポックの数を指定し、batch_size パラメータはミニバッチのサイズを指定します。model.fit メソッドは、モデルのトレーニングを開始し、トレーニング履歴情報を history 変数に保存している。

  • hist_loss = history.history['loss']
    hist_acc = history.history['accuracy']
    トレーニング履歴から損失(loss)と正確度(accuracy)の情報を抽出し、それぞれ hist_loss と hist_acc 変数に格納。history.history ディクショナリには、各エポックの損失と正確度が保存されている。

    • 損失:損失は、モデルがトレーニングデータの予測と実際の目標との誤差を示す指標です。モデルがどれだけ予測が実際の値からずれているかを表す。トレーニング中、モデルは損失を最小化するように重みとバイアス(パラメータ)を調整する。損失が減少することは、モデルがデータに適応し、良い予測を行っていることを示す。
    • 正確度:正確度は、モデルのクラス分類タスクにおける性能を示す指標で、正しく分類されたサンプルの割合を表す。正確度は通常、分類タスクの評価に使用される。例えば、2つ以上のクラスがある多クラス分類タスクの場合、モデルが正確にどれだけのサンプルを正しいクラスに分類できるかを評価する。

評価と予測

loss, accuracy = model.evaluate(x_test,t_test)
print("以下はそれぞれテストデータにおける損失と正確度を示す")
print(loss,accuracy)

print("以下の数値はそれぞれsetosaの確率、versicolorの確率、virginicaの確率を示す")
model.predict(x_test)

出力結果

2/2 [==============================] - 0s 22ms/step - loss: 0.4568 - accuracy: 0.8158
以下はそれぞれ損失と正確度を示す
0.45681264996528625 0.8157894611358643
以下の数値はそれぞれsetosaの確率、versicolorの確率、virginicaの確率を示す
2/2 [==============================] - 0s 8ms/step
array([[0.02868483, 0.57167155, 0.3996436 ],
       [0.02295968, 0.16441241, 0.81262785],
       [0.03702914, 0.23847784, 0.7244929 ],
       [0.04994134, 0.47118318, 0.47887537],
       [0.12778834, 0.7555497 , 0.11666191],
       [0.04428819, 0.81123936, 0.14447248],
       [0.98523736, 0.00786163, 0.006901  ],
       [0.00575337, 0.05832081, 0.9359258 ],
       [0.04262898, 0.4080595 , 0.5493116 ],
       [0.9918903 , 0.00415106, 0.00395862],
       [0.08286932, 0.4849872 , 0.43214336],
       [0.9388504 , 0.0391735 , 0.02197619],
       [0.02935005, 0.7881833 , 0.18246664],
       [0.00241864, 0.03151884, 0.9660624 ],
       [0.00200983, 0.02525597, 0.97273415],
       [0.17232579, 0.647587  , 0.18008718],
       [0.16127275, 0.6568065 , 0.18192077],
       [0.10670465, 0.6254166 , 0.26787877],
       [0.04949111, 0.28522646, 0.6652825 ],
       [0.09261522, 0.7283675 , 0.17901723],
       [0.00367155, 0.03632863, 0.95999986],
       [0.97927374, 0.0119858 , 0.00874055],
       [0.03635976, 0.4148069 , 0.5488334 ],
       [0.06169634, 0.34298813, 0.59531546],
       [0.00834672, 0.09334946, 0.89830387],
       [0.08782964, 0.80527484, 0.10689548],
       [0.05395289, 0.73569304, 0.2103541 ],
       [0.0406357 , 0.3449983 , 0.61436605],
       [0.11929395, 0.70258194, 0.17812411],
       [0.9467625 , 0.03319071, 0.02004671],
       [0.17375658, 0.65617955, 0.17006378],
       [0.02197714, 0.2866666 , 0.6913562 ],
       [0.00194179, 0.02272333, 0.97533494],
       [0.06516223, 0.34572625, 0.5891115 ],
       [0.03315317, 0.47795698, 0.48888987],
       [0.11591238, 0.7584487 , 0.12563896],
       [0.9777915 , 0.01275867, 0.00944987],
       [0.96137637, 0.02342943, 0.0151942 ]], dtype=float32)

モデルの保存・呼び出し

from keras.models import load_model

model.save("model.h5")
from keras.models import load_model

load_model("model.h5")

モデルを保存し、それを呼び出すことでそのモデルを使って分類などを出来るようになります。
ザックリイメージに落とし込むと以下になるでしょうか。
DL-ページ4.drawio (1).png

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?