Help us understand the problem. What is going on with this article?

iOSのMetalで畳み込みニューラルネットワーク - MPSCNNを用いた手書き数字認識の実装

More than 3 years have passed since last update.

MetalでCNNの計算を行うためのAPI群、MPSCNNを用いた手書き数字認識のサンプルを読む、という記事の続きです。

iOS 10でMetal Performance Shadersフレームワークに、CNN(Convolutional Neural Network)演算機能群が追加されました。iOSデバイスのGPUで畳み込みニューラルネットワークの計算をさせることができる、という代物です。

前編ではネットワークの中身には踏み込まず、オーバービューとして利用するアプリケーション側の実装について見ていきました。

後編となる本記事ではいよいよMetal Performance Shadersを用いたCNN(Convolutional Neural Network / 畳み込みニューラルネットワーク)の実装について見ていきます。

なお、CNN自体の解説はここでは省略しますが、概念をざっくり理解したい場合には下記記事が、

もうちょっとちゃんと学びたいときはCourseraの機械学習コース(無料)が大変わかりやすいです。

CNN各層を担うオブジェクトを生成

本サンプルではMNISTDeepCNNがネットワークの実装クラスになります。このクラスの初期化処理では、CNNの各層や活性化関数を担うオブジェクトを生成しています。

畳み込み層(Convolutional Layer)

畳み込み層を担うクラスとしてMPSCNNConvolutionというものが用意されています。

var conv1, conv2: MPSCNNConvolution

本サンプルでは、モデルデータ読み込み等をラップしたSlimMPSCNNConvolutionというサブクラス 1 を使用しています。

conv1 = SlimMPSCNNConvolution(kernelWidth: 5,
                              kernelHeight: 5,
                              inputFeatureChannels: 1,
                              outputFeatureChannels: 32,
                              neuronFilter: relu,
                              device: device,
                              kernelParamsBinaryName: "conv1")

conv2 = SlimMPSCNNConvolution(kernelWidth: 5,
                              kernelHeight: 5,
                              inputFeatureChannels: 32,
                              outputFeatureChannels: 64,
                              neuronFilter: relu,
                              device: device,
                              kernelParamsBinaryName: "conv2")

本ネットワークは畳み込み層を2つ持ち、

  • どちらもカーネルサイズは5x5
  • どちらも活性化関数はReLU
  • 1つ目の層は入力が1チャンネル(グレースケール画像)、出力が32チャンネル
  • 2つ目の層は入力が32チャンネル、出力が64チャンネル

ということが引数からわかります。"conv1", "conv2"は該当する学習済みパラメータの入ったバイナリデータファイルのプレフィックスです。

Rectified Linear Units(ReLU)

上で出てきた活性化関数ReLUを担うクラスが、MPSCNNNeuronReLUです。

var relu: MPSCNNNeuronReLU
relu = MPSCNNNeuronReLU(device: device, a: 0)

プーリング層(Pooling Layer)

MPSCNNPoolingMaxMPSCNNPoolingAverage というクラスが用意されています。本サンプルのネットワークでは最大値を取るプーリングを行っています。

var pool: MPSCNNPoolingMax
pool = MPSCNNPoolingMax(device: device, kernelWidth: 2, kernelHeight: 2, strideInPixelsX: 2, strideInPixelsY: 2)
pool.offset = MPSOffset(x: 1, y: 1, z: 0);
pool.edgeMode = MPSImageEdgeMode.clamp

ソフトマックス関数

MPSCNNSoftMax とそのまんまなクラス名です。

var softmax : MPSCNNSoftMax
softmax = MPSCNNSoftMax(device: device)

全結合層(fully-connected layer)

MPSCNNFullyConnected というクラスがMPSに用意されていますが、

var fc1, fc2: MPSCNNFullyConnected

本サンプルではモデル読み込み等の処理をラップした SlimMPSCNNFullyConnected というサブクラス 2 を使用しています。

fc1 = SlimMPSCNNFullyConnected(kernelWidth: 7,
                               kernelHeight: 7,
                               inputFeatureChannels: 64,
                               outputFeatureChannels: 1024,
                               neuronFilter: nil,
                               device: device,
                               kernelParamsBinaryName: "fc1")

fc2 = SlimMPSCNNFullyConnected(kernelWidth: 1,
                               kernelHeight: 1,
                               inputFeatureChannels: 1024,
                               outputFeatureChannels: 10,
                               neuronFilter: nil,
                               device: device,
                               kernelParamsBinaryName: "fc2")

各レイヤー間の入出力画像

ネットワーク全体の入力・出力、各層の出力(次の層への入力)となる画像の入れ物となる MPSImage インスタンスを、それぞれ別々に用意します。

var srcImage, dstImage : MPSImage
var c1Image, c2Image, p1Image, p2Image, fc1Image: MPSImage

MPSImageDescriptor でピクセルフォーマット、サイズ、特徴量の数を記述しておいて、

let sid = MPSImageDescriptor(channelFormat: .unorm8, width: 28, height: 28, featureChannels: 1)
let did = MPSImageDescriptor(channelFormat: .float16, width: 1, height: 1, featureChannels: 10)
let c1id  = MPSImageDescriptor(channelFormat: .float16, width: 28, height: 28, featureChannels: 32)
let p1id  = MPSImageDescriptor(channelFormat: .float16, width: 14, height: 14, featureChannels: 32)
let c2id  = MPSImageDescriptor(channelFormat: .float16, width: 14, height: 14, featureChannels: 64)
let p2id  = MPSImageDescriptor(channelFormat: .float16, width: 7 , height: 7 , featureChannels: 64)
let fc1id = MPSImageDescriptor(channelFormat: .float16, width: 1 , height: 1 , featureChannels: 1024)

ディスクリプタを渡してMPSImageを生成します。

srcImage    = MPSImage(device: device, imageDescriptor: sid)
dstImage    = MPSImage(device: device, imageDescriptor: did)
c1Image     = MPSImage(device: device, imageDescriptor: c1id)
p1Image     = MPSImage(device: device, imageDescriptor: p1id)
c2Image     = MPSImage(device: device, imageDescriptor: c2id)
p2Image     = MPSImage(device: device, imageDescriptor: p2id)
fc1Image    = MPSImage(device: device, imageDescriptor: fc1id)

フォワードプロパゲーション

forwardメソッドで、順方向伝播(forward propagation)の計算を行います。

初期化処理ではCNNの各層を担うクラス群や入出力画像の入れ物となるMPSImageを用意したわけですが、本メソッド内でこれらを実際に繋げて、入力画像から推測結果を出力する、ということを行います。

if let inputImage = inputImage {
    conv1.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: c1Image)
} else{
    conv1.encode(commandBuffer: commandBuffer, sourceImage: srcImage, destinationImage: c1Image)
}    
pool.encode   (commandBuffer: commandBuffer, sourceImage: c1Image   , destinationImage: p1Image)
conv2.encode  (commandBuffer: commandBuffer, sourceImage: p1Image   , destinationImage: c2Image)
pool.encode   (commandBuffer: commandBuffer, sourceImage: c2Image   , destinationImage: p2Image)
fc1.encode    (commandBuffer: commandBuffer, sourceImage: p2Image   , destinationImage: fc1Image)
fc2.encode    (commandBuffer: commandBuffer, sourceImage: fc1Image  , destinationImage: dstImage)
softmax.encode(commandBuffer: commandBuffer, sourceImage: dstImage  , destinationImage: finalLayer)

こうやってみるとなんともシンプルなネットワークです。各層を抜き出して並べると、

conv1 -> pool -> conv2 -> fc1 -> fc2 -> softmax

となっていることがわかります。

ここに、入出力画像も入れると、

(srcImage) -> conv1 -> (c1Image) -> pool -> (p1Image) -> conv2 -> (c2Image) -> fc1 -> (fc1Image) -> fc2 -> (dstImage) -> softmax -> (finalLayer)

という感じのネットワークになっています。

cnn.jpg

(本ネットワークとは別のものになりますが、イメージ図として、WWDC16のセッションスライドより)

ここで呼んでいる encode〜 というメソッドは、各層のクラスの基底クラスとなっているMPSCNNKernelが持っているメソッドで、Metalのコマンドバッファに処理を登録するものです。

MPSCNN
open func encode(commandBuffer: MTLCommandBuffer, sourceImage: MPSImage, destinationImage: MPSImage)

最終出力

最終出力(ソフトマックス関数の出力)となる finalLayer は次のように初期化されており、

let finalLayer = MPSImage(device: commandBuffer.device, imageDescriptor: did)

ここで渡されているディスクリプタ did は、(再掲になりますが)次のように生成されています。

let did = MPSImageDescriptor(channelFormat: .float16, width: 1, height: 1, featureChannels: 10)

したがって、0〜9の10種類のラベルについての確率が示された1x1の画像であることがわかります。

あとは、getLabelというメソッドで、この1x1の画像のピクセル値を読み取り、最終的な0〜9の数字を出力するのですが、この処理内容もまた別記事で書きたいと思います。

まとめ

MPSCNNのサンプルを題材に、どのように手書き数字認識のCNNが実装されているかを見てきました。

ざっくり&駆け足になってしまいましたが、細かい行列計算等をほとんど意識することもなく、GPU-AcceleratedなCNNを比較的簡単に構築できることが感じ取れたかと思います。

学習済みパラメータの読み込み等、今回書ききれなかったことはまた別記事で書きたいと思います。

(追記)書きました: MPSCNNに渡すモデルパラメータのフォーマット - Qiita


  1. この中身については長くなるので、また別記事で書きたいと思います 

  2. これも、中身については別記事で書きたいと思います。 

shu223
フリーランスiOSエンジニア 著書:『iOS×BLE Core Bluetooth プログラミング』『Metal入門』『実践ARKit』『Depth in Depth』『iOSアプリ開発 達人のレシピ100』他 GitHubの累計スター数23,000超
http://shu223.hatenablog.com/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした