1. Qiita
  2. 投稿
  3. iOS

iOSのMPSCNNによる手書き数字認識のサンプルを読む - 前編

  • 8
    いいね
  • 0
    コメント

iOS 10でMetal Performance Shadersフレームワークに、CNN(Convolutional Neural Network)演算機能群が追加されました。iOSデバイスのGPUで畳み込みニューラルネットワークの計算をさせることができる、という代物です。iOSデバイスのGPUも日々進歩しているとはいえ、モバイルデバイスなのでさすがに学習までは厳しいところですが、学習済みのネットワークの計算であれば使えるケースはありそうです。

このMPSCNNのApple公式サンプルコードとして、CNNで手書き文字認識を行うサンプルが公開されています。

ただこちらは一時期ずっとSwiftのバージョンの問題でビルドできなかったりした事情もあり、iOS-10-Samplerというサンプルコード集アプリに移植したものがあるので、今回はそちらを追っていきます。

多少リファクタリングしていますが、基本的な実装は同じです。

オーバービュー

まずはCNNの中身に踏み込む前に、大局を把握するために、そのネットワークを利用するアプリケーション側(iOS-10-SamplerではMetalCNNBasicViewController)の実装から見ていきます。

ネットワークの入力部分について

上でネットワークの中身にはまだ踏み込まない、と書いたばかりですが、入力部分がどうなっているかだけ見ておきます。

本サンプルでの手書き文字認識のネットワークは、MNISTDeepCNNというクラスとして実装されています。

これがsrcImageというMPSImage型のプロパティを持っていて、

MNISTDeepCNN.swift
var srcImage : MPSImage

ネットワーク初期化時にsrcImageも初期化されます。

MNISTDeepCNN.swift
srcImage = MPSImage(device: device, imageDescriptor: sid)

ここに、以下で説明する実装によって、手書きした内容が入力として渡されてきます。

1. 手書きした内容を、入力としてネットワークに渡す

まず、手書きしたビューの描画内容をCGContext(28x28のビットマップコンテキスト)として取得します。1

ViewController.swift
guard let context = digitView.getViewContext() else {return}

取得したビットマップコンテキストのdataプロパティからピクセルデータへのポインタ(UnsafeMutableRawPointer)を取得し、ネットワークに渡します。

ViewController.swift
network.srcImage.texture.replace(
    region: region,
    mipmapLevel: 0,
    slice: 0,
    withBytes: pixelData,
    bytesPerRow: inputWidth,
    bytesPerImage: 0)

ここで、networkは(本サンプルにおけるネットワーク実装クラスである)MNISTDeepCNNのインスタンスで、 srcImageは上で説明した通りMPSImage型のプロパティです。replace〜というメソッドを使って、手書き内容のピクセルデータをMPSImageのtestureプロパティにコピーする、というのが上記コードでやっていることになります。

ちなみに、MPSImagetextureプロパティは、

MPSImage
open var texture: MTLTexture { get }

この定義の通り、MPSImagetexture プロパティはread-onlyなので、直接代入するのではなく、replace〜というメソッドを利用してピクセルデータをMTLTextureにコピーする、ということをやっています。

MTLTexture
/*!
 @method replaceRegion:mipmapLevel:slice:withBytes:bytesPerRow:bytesPerImage:
 @abstract Copy a block of pixel data from the caller's pointer into a texture slice.
 */
public func replace(region: MTLRegion, mipmapLevel level: Int, slice: Int, withBytes pixelBytes: UnsafeRawPointer, bytesPerRow: Int, bytesPerImage: Int)

2. ネットワークの実行

入力画像を渡したら、あとはもう順方向伝播(forward propagation)の計算を実行して結果を受け取るだけ。

ViewController.swift
let label = network.forward()

アプリケーション側(畳み込みニューラルネットワークを利用する側)の実装としては非常にシンプルに済むことがわかります。

つづく

ネットワークの中身の実装については、後編『iOSのMetalで畳み込みニューラルネットワーク - MPSCNNを用いた手書き数字認識の実装』に続きます。


  1. Core Graphicsの話なのでgetViewContextの実装内容については省略します