iOS 10でMetal Performance Shadersフレームワークに、CNN(Convolutional Neural Network)演算機能群が追加されました。iOSデバイスのGPUで畳み込みニューラルネットワークの計算をさせることができる、という代物です。iOSデバイスのGPUも日々進歩しているとはいえ、モバイルデバイスなのでさすがに学習までは厳しいところですが、学習済みのネットワークの計算であれば使えるケースはありそうです。
このMPSCNNのApple公式サンプルコードとして、CNNで手書き文字認識を行うサンプルが公開されています。
ただこちらは一時期ずっとSwiftのバージョンの問題でビルドできなかったりした事情もあり、iOS-10-Samplerというサンプルコード集アプリに移植したものがあるので、今回はそちらを追っていきます。
多少リファクタリングしていますが、基本的な実装は同じです。
##オーバービュー
まずはCNNの中身に踏み込む前に、大局を把握するために、そのネットワークを利用するアプリケーション側(iOS-10-SamplerではMetalCNNBasicViewController
)の実装から見ていきます。
###ネットワークの入力部分について
上でネットワークの中身にはまだ踏み込まない、と書いたばかりですが、入力部分がどうなっているかだけ見ておきます。
本サンプルでの手書き文字認識のネットワークは、MNISTDeepCNN
というクラスとして実装されています。
これがsrcImage
というMPSImage
型のプロパティを持っていて、
var srcImage : MPSImage
ネットワーク初期化時にsrcImage
も初期化されます。
srcImage = MPSImage(device: device, imageDescriptor: sid)
ここに、以下で説明する実装によって、手書きした内容が入力として渡されてきます。
###1. 手書きした内容を、入力としてネットワークに渡す
まず、手書きしたビューの描画内容をCGContext(28x28のビットマップコンテキスト)として取得します。1
guard let context = digitView.getViewContext() else {return}
取得したビットマップコンテキストのdata
プロパティからピクセルデータへのポインタ(UnsafeMutableRawPointer
)を取得し、ネットワークに渡します。
network.srcImage.texture.replace(
region: region,
mipmapLevel: 0,
slice: 0,
withBytes: pixelData,
bytesPerRow: inputWidth,
bytesPerImage: 0)
ここで、network
は(本サンプルにおけるネットワーク実装クラスである)MNISTDeepCNN
のインスタンスで、 srcImage
は上で説明した通りMPSImage
型のプロパティです。replace〜
というメソッドを使って、手書き内容のピクセルデータをMPSImageのtesture
プロパティにコピーする、というのが上記コードでやっていることになります。
ちなみに、MPSImage
のtexture
プロパティは、
open var texture: MTLTexture { get }
この定義の通り、MPSImage
の texture
プロパティはread-onlyなので、直接代入するのではなく、replace〜
というメソッドを利用してピクセルデータをMTLTextureにコピーする、ということをやっています。
/*!
@method replaceRegion:mipmapLevel:slice:withBytes:bytesPerRow:bytesPerImage:
@abstract Copy a block of pixel data from the caller's pointer into a texture slice.
*/
public func replace(region: MTLRegion, mipmapLevel level: Int, slice: Int, withBytes pixelBytes: UnsafeRawPointer, bytesPerRow: Int, bytesPerImage: Int)
###2. ネットワークの実行
入力画像を渡したら、あとはもう順方向伝播(forward propagation)の計算を実行して結果を受け取るだけ。
let label = network.forward()
アプリケーション側(畳み込みニューラルネットワークを利用する側)の実装としては非常にシンプルに済むことがわかります。
##つづく
ネットワークの中身の実装については、後編『iOSのMetalで畳み込みニューラルネットワーク - MPSCNNを用いた手書き数字認識の実装』に続きます。
-
Core Graphicsの話なので
getViewContext
の実装内容については省略します ↩