14
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AIが学習するしくみ 〜 コードを添えて

Posted at

AIブーム、ますます加熱してますね。ChatGPTなどのAIサービスを使っている人も多いと思いますが、エンジニアとしては使うだけでなく、中身とか原理も知りたいですよね?そういうエンジニアのために、AIが学習するしくみを簡単に説明します。理論だけでなくコードを添えて、より具体的に想像できるように書きたいと思います。

なおここで「AI」はニューラルネットワーク(以下NN)を使った機能を指すこととします。想定する読者は、プログラミングの基礎知識を持っているけれども機械学習ライブラリは使ったことがない人です。

AIサービスの大雑把な中身

チャットや画像生成など、AIを使った様々なサービスが登場していますが、多くのサービスは複数のNNを組み合わせた構造になっています。NNについては下で詳しく説明しますが、入力に対して出力を生成するシステムであり、一種の関数のように振る舞います。それを組み合わせるので、例えばユーザーからリクエストを受け取ったら最初のNNでリクエストの種類を判別して次に使うべきNNを選んだり、入力を変えながら同じNNを繰り返し使ったりといった構造が考えられます(図1)。

AIサービスの構造例.png

図1. AIサービスの構造の例(注:こういう構造も有り得るというだけで、特定のサービスを想定したものではありません)

ニューラルネットワークとは

NNはその名の通り脳細胞を模したたくさんのノードのネットワークになっています(図2)。そのうちいくつかのノードが入力を受け取り、接続された後続のノードに信号を伝え(あるいは遮断し)、最終段のノードの状態から出力を決めます。この「段」をレイヤと呼びます。また各ノードの処理を左右するパラメータがあり、同じ入力でもパラメータを変えると出力が変わります。

NN.png

図2. 3層の密結合レイヤで構成されたNN

NNの特性はレイヤやノードの数、接続方法などで変わります。最も基本的な接続方法は、各レイヤの全ノードが前のレイヤの全ノードから信号を受け取る形で、そのような接続のレイヤを密結合(dense)レイヤとかリニアレイヤと呼びます。例えば前のレイヤのノードが100個で、次のレイヤのノードも100個の場合、レイヤ間に10,000の接続があることになります。

パラメータにも種類がありますが、単に「パラメータ」と言ったときは重み(weight)バイアスを指すことが多いと思います。重みは信号の増幅率で、すべての接続で独立した値です。つまり10,000の接続があったら10,000の重みパラメータがあることになります。バイアスは信号の処理結果を一方向へスライドするパラメータで、1ノードに1つです。

また実際のNNでは計算結果をそのまま次のノードへ送らず、活性化関数(activation function) と呼ばれる何らかの関数を通します。例えばReLUという関数は正の値はそのまま通しますが負の値を0に矯正します。活性化関数はNNの目的に応じて選択されます。

例として、3レイヤに3個ずつノードを持つNNをコードで表現してみます。機械学習ライブラリとしてPyTorchを使います。

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.layer1 = nn.Linear(3, 3)
        self.layer2 = nn.Linear(3, 3)
        self.layer3 = nn.Linear(3, 3)
    
    def forward(self, x):
        x = self.layer1(x)
        x = torch.relu(x)
        x = self.layer2(x)
        x = torch.relu(x)
        x = self.layer3(x)
        return x

__init__がこのNNのコンストラクタで、nn.Linearによって3つのノードを持つレイヤが生成されます。各ノードが3つの重みと1つのバイアスを持っているので、NN全体で36個のパラメータを持つことになります。パラメータの初期化方法によってNNの学習効率も変わりますが、この例では一定の範囲の乱数になっています。

forwardはこのNNに値を通すメソッドで、活性化関数としてReLUを使っています。

このNNに適当な値を入れて出力を得るコードと、その実行例を下に示します。

model = SimpleNN()

input_data = torch.tensor([[1.0, 2.0, 3.0]])

output = model(input_data)

print("Input:", input_data)
print("Output:", output)
Input: tensor([[1., 2., 3.]])
Output: tensor([[-0.8375, -0.1730, -0.1965]], grad_fn=<AddmmBackward0>)

学習していないNNに適当な値を入れただけなので、出力に意味はありません。しかし36個のパラメータを調整すれば、単純な構造で複雑な計算を可能にする可能性を持っています。

なぜニューラルネットワーク?

そもそもなぜ普通のプログラムを書かずNNを使うのかというと、原因(入力)と結果(出力)の関係が複雑すぎて普通の(記号論理学的な)プログラムを書くのが難しいという理由が挙げられます。記号論理学的な実装ができなくても、入力と望ましい結果さえわかっていれば、NNは目的の関数を模倣できます。

たとえばサバイバルゲームのようなチーム対チームの射撃戦を考えたとき、ある状況でどこへ移動すべきか、何を撃つべきか、論理的に考えてプレイするプログラムを書くのは非常に難しいでしょう。しかし自分の視野や武器、仲間の位置などを入力とし、被弾せず、相手を撃つことを望ましい結果とすれば、NNを学習させられます。

実際にDota 2というeスポーツゲームでは、2019年にAIが人間のチャンピオンチームを打ち負かしました

このように詳細なロジックがわからなくても欲しい結果だけ与えれば目的の関数を得られるのがNNの最大のメリットです。しかしそのためには構造(モデルと呼ばれます)を設計してデータを用意し、学習させる必要があります。ではその学習のしくみを見ていきましょう。

学習のしくみ

上では9個のノードからなるNNを例に挙げましたが、ここではさらに単純化し、1ノードが学習する様子について説明します。入力が1つで、重みとバイアスを持つとすると、ノードの処理は下の式で表せます。

$y = w * x + b$

$x$が入力、$y$が出力、$w$が重み、$b$がバイアスです。

前述の通り、AIの学習とは入力に対して出力を望ましい方向へ近づける最適化の工程です。分解すると下のようになります:

  1. パラメータの初期化
  2. 入力に対する出力の計算
  3. 望ましい値との差の計算
  4. パラメータの調整
  5. 2〜4の繰り返し

まずパラメータを初期化します。初期値によって学習にかかる時間も結果も変わるので、初期化の方法も研究されていますが、ここでは適当に$w=2、b=3$とします。

次に入力と望ましい結果(ここではターゲットと呼びます)が必要です。例として下の値の組み合わせを使います。

x target
1 4
2 5

出力を計算すると、下のようになります。

$x=1$のとき、$y=5$

$x=2$のとき、$y=7$

続いてターゲットとの差(損失(loss) と呼ばれます)を計算します。

$x=1$のとき、$4-5=-1$

$x=2$のとき、$5-7=-2$

差が大きいほど修正の必要も大きいことをシステムに反映するため、平均二乗誤差を使います:

$\frac{(-1)^2 + (-2)^2}2 = 2.5$

そしてパラメータの調整です。どのように調整すればいいでしょうか?当然、$y$がターゲットへ近づくように調整したいはずです。そのために各パラメータがどのように損失に影響したか解析します。面倒くさそうに聞こえますが、いわゆる機械学習ライブラリはこれを計算する機能を持っています(自動微分(autodiff) と呼ばれています)。またパラメータの影響を表すベクトルを勾配(gradient) と呼びます。$w$や$b$を、損失を求める関数のパラメータと考えると、その関数の偏微分によって勾配を得られます。

自動微分で得られる勾配は結果を増大させるベクトルです。よって損失のように値を減らしたい場合は反転して使います。例えば$w$が2のときの勾配が5だったら、0.05を引いて1.95にしたりします。この5に対する0.05のような割合を学習率(learning rate) と呼び、1回のパラメータの調整を更新(update) と呼びます。同様に$b$も更新します。

これで1回のサイクルが終わりました。学習率0.1で同様の更新を繰り返したところ、例題の場合は25回で損失が0.01以下になり、その時の$w$は約1.19、$b$は約2.69でした(パラメータを乱数で初期化しているので実行の度に変わります)。

以上のような最適化手法を勾配降下法(gradient descent) と呼びます。

コードで

上の手順をPyTorchで書くと下のようになります。

import torch

# 初期化
data = torch.tensor([[1.0, 4.0], [2.0, 5.0]])
x = data[:, 0]
target = data[:, 1]

w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

learning_rate = 0.1
epochs = 30

for epoch in range(epochs):
    # 順伝播
    y = w * x + b

    # 損失の計算
    loss = ((target - y) ** 2).mean()

    # 逆伝播
    loss.backward()

    # パラメータの更新
    with torch.no_grad():
        w.data = w.data - learning_rate * w.grad.data
        b.data = b.data - learning_rate * b.grad.data

    # 勾配をクリア
    w.grad.zero_()
    b.grad.zero_()

    print(f'Epoch {epoch+1}: loss = {loss.item()}')

print(f'Final parameters: w = {w.item()}, b = {b.item()}')

主要な部分を説明します。

まず初期化ですが、機械学習ライブラリにはスカラーやベクトル、多次元配列などを包括して共通のインターフェイスを提供するテンソル(tensor) と呼ばれる型があり、パラメータもテンソルで表現します。テンソルは値を表すと同時に、自動微分に必要な計算の過程(計算グラフと呼ばれます)も記憶できます。計算グラフが不要な変数でも様々な型を一元的に扱えて便利なのでテンソルを使うことが多いと思います。

順伝播(forward propagation) とはNNに信号を流して出力を得ることを指します。それに対して勾配を計算するには損失からパラメータへ逆方向に計算するので逆伝播(backpropagation) と呼びます。loss.backward()を実行すると、lossの計算に関わった全テンソルについて偏微分し、結果(つまり勾配)を各テンソルに格納します。例えばwの勾配はw.grad.dataで参照できるので、それを更新に使っています。

backward()は呼ぶ度に勾配を積算するため、更新が終わって不要になったときzero_()を呼んで勾配をクリアしておきます。

下に実行例を示します。

Epoch 1: loss = 2.5
Epoch 2: loss = 0.2649998366832733
Epoch 3: loss = 0.043299973011016846
Epoch 4: loss = 0.02086453326046467
Epoch 5: loss = 0.01816387288272381
Epoch 6: loss = 0.01743026077747345
Epoch 7: loss = 0.016904836520552635
(中略)
Epoch 23: loss = 0.010560772381722927
Epoch 24: loss = 0.010254897177219391
Epoch 25: loss = 0.009957837872207165
Epoch 26: loss = 0.009669391438364983
Epoch 27: loss = 0.009389298036694527
Epoch 28: loss = 0.009117312729358673
Epoch 29: loss = 0.008853225968778133
Epoch 30: loss = 0.008596784435212612
Final parameters: w = 1.1778438091278076, b = 2.712242603302002

様々なNN

NNはモデルによって特性が異なるので、目的にあったモデルを選択または設計する必要があります。

例えばCNN (Convolutional Neural Network)は連続したデータから特徴を抽出できるため、音声認識や画像認識に使われます。囲碁やチェスの盤面の評価にも使われ、囲碁の世界チャンピオンに勝ったAIのAlphaGoは、12レイヤのCNNを備えていました。

AIチャットボットの中核となっているのはTransformerというモデルで、翻訳や要約のような自然言語処理の研究の中で生まれました。Transformerにも密結合レイヤは使われていますが、柔軟性を高めるための補助的な役割しか負っておらず、中心となっているのは注意機構(Attention mechanism) と呼ばれる機能です。Transformerは入力された文章を単語に似た単位であるトークンの列に変換しますが、注意機構はトークン同士の関連の強さを判定し、文脈の情報を集めます。たとえば「Transformerは入力された文章を単語に似た単位であるトークンの列に変換します」という文なら、「入力」と「文章」や、「単語」と「トークン」を表すトークンの関連が強い、というような判定をします。ノードの接続や処理は密結合レイヤと全然違いますが、トークン同士の関連の強さや意味を学習するしくみは同じです。

実際の問題に応用するには

現実社会の課題をAIを使って解決しようとしたとき、まず問題になるのはどのようなデータ(入力やターゲット)を使うかです。データが不適切だとどんなに学習を重ねても望んだ性能には到達できません。データの質で結果が変わることを示した例として2023年にTextbooks Are All You Needという論文がありました。この研究では「教科書品質」のデータセットを使ってPythonコードを生成するモデルを訓練しました。「教科書品質」は「clear, self-contained, instructive, and balanced(明快で自己完結しており、教育的でバランスが取れた)」ものと定義されており、基準を満たすコードを得るためにインターネット上のコードをGPT-4で評価して選別したそうです。そのように訓練されたモデルは従来よりも少ないパラメータで高い性能を示しました。

品質以前に、学習に使えるデータが簡単に得られないこともあるでしょう。私はゲーム開発に携わっていますが、ユーザーのデータを集めようとすれば、プライバシーに関わる問題やサーバへの通信といった技術的な問題が発生します。

データの収集に目処が立っても、それを使った開発環境の構築にも課題があります。現実社会の課題に取り組むようなサービスの場合、データ収集と学習の過程に人間の介入が必要になることがよくあります。たとえば囲碁のようなゲームなら機械的に勝ち負けを判定できますが、接客が適切だったかどうかといった判断では人間の方が信頼できます。つまり「望ましい出力」の一部を人間から得るのです。別の例としては機械やアプリの操作を人間が先にやってみせて、それをAIが学ぶという流れも使われています。どちらも人間が介入する分、システムが複雑になります。

またデータの収集・加工・学習には大量の演算が必要なので、現実的な時間で処理できるように計算のパイプラインを構築する必要があります。1台のPCで処理できるケースは考えづらく、規模の拡大が容易なクラウドを使って様々なデータベースと適切な計算ユニット、およびそれらを管理するシステムを用意して初めてサービスの開発が可能になります。従って上で示した機械学習のしくみはAIサービスの開発に必要な理論の中ではほんの小さなかけらのような存在です。

まとめ

AIを使ったサービスが次々に登場し、良くできていて魔法のように見えることもあると思いますが、そういったサービスの根底にあるNNの学習について簡単なコードで説明しました。魔法のように見えていたものが少しでも現実に近づき、すっきりした気分になってもらえたら嬉しいです。

14
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?