0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Keras的"hello, world"

Last updated at Posted at 2021-02-21

はじめに

KerasはPythonで動作する高水準なニューラルネットワークのライブラリである。バックエンドとしてTensorFlowやTheanoを使用している。TensorFlow1の頃はKerasを別にインストールする必要があったが、TensorFlow2ではKerasが内臓されるようになった。つまりTensorFlow2をインストールすれば、Kerasを使う準備ができたことになる。

hello, world

ブライアン・カーニハンとデニス・リッチーの著書「プログラミング言語C」では、C言語の最初の例題として「hello, world」を表示するプログラムを紹介している。これ以降、さまざまなプログラミング言語の教科書の冒頭で「hello, xxxx」を例題として使うことが通例となっている。では、Kerasではどうなるか?

C言語の"hello, world"では、次のことが学習できる。

  • printf関数で画面に文字列を表示する。
  • printf関数を使用するためにヘッダファイルをインクルードする。
  • プログラムはmain関数で始まる。

Kerasを使った最初のプログラムはどうあるべきか?

  • Kerasを使用するためにライブラリをインポートする。
  • モデルを作成する。
    • 最初に作るプログラムのモデルは単純なモデルがよい。
    • モデルの各階層で活性化関数を使用する。
  • モデルをコンパイルする。
    • コンパイル時に損失関数や最適化アルゴリズムを使用する。
  • トレーニングデータを学習させる。
  • テストデータで学習の内容を確認する。
  • 可能であれば結果をグラフで表示する。

ということで、私の考えるKerasの"hello, world"プログラムは、一次関数を学習させるプログラムが適しているのではないかと思っている(hello, worldの表示はないが)。

一次関数を学習してみる

一次関数だと分かっているので、入力データ1つに対して1つの値を出力する1階層の単純なモデルを作成する。活性化関数にはLinear(入力値をそのまま返す)を使 使用する。損失関数はMSE(平均二乗誤差)、最適化アルゴリズムはSGD(確率的勾配降下法)を使用する。

拙筆の記事「scikit-leanを使ってみる」の一次関数の例をKearasで学習させるプログラムは下記のようになる。

import matplotlib.pyplot as plt
import numpy as np
from tensorflow import keras
from sklearn.model_selection import train_test_split

# Y = 2X + 3を求める関数とする。
x = np.linspace(0, 10, num=50)
y = 2 * x + 3

# トレーニングデータと検証データに分割
x_train, x_test, y_train, y_test = train_test_split(x, y, shuffle=False)

# モデルを作成
# 1階層で活性化関数としてLinearを使用
model = keras.models.Sequential([
    keras.layers.Dense(1, input_shape=(1,), activation="linear"),
])

# オプティマイザはSGD、損失関数はMSEを使用
model.compile(optimizer='sgd', loss='mse')

# トレーニングデータを学習する
model.fit(x_train, y_train, epochs=500)

# 結果をグラフにする
plt.scatter(x, y, label='data', color='gray')
plt.plot(x_train, model.predict(x_train), label='train', color='blue')
plt.plot(x_test, model.predict(x_test), label='test', color='red')
plt.legend()
plt.show()

このプログラムを実行すると、下記のグラフが表示される。
linear_regression1.png

トレーニングデータ、テストデータともフィットしており、うまく学習できたようだ。

最後に

まだまだKeras初心者の身であるが、使える機能を見つけたときには(メモがてら)Qiitaに書いていきたいと思う。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?