LoginSignup
14
10

More than 5 years have passed since last update.

iOSのMPSCNNによる手書き数字認識のサンプルを読む - 前編

Last updated at Posted at 2016-12-23

iOS 10でMetal Performance Shadersフレームワークに、CNN(Convolutional Neural Network)演算機能群が追加されました。iOSデバイスのGPUで畳み込みニューラルネットワークの計算をさせることができる、という代物です。iOSデバイスのGPUも日々進歩しているとはいえ、モバイルデバイスなのでさすがに学習までは厳しいところですが、学習済みのネットワークの計算であれば使えるケースはありそうです。

このMPSCNNのApple公式サンプルコードとして、CNNで手書き文字認識を行うサンプルが公開されています。

ただこちらは一時期ずっとSwiftのバージョンの問題でビルドできなかったりした事情もあり、iOS-10-Samplerというサンプルコード集アプリに移植したものがあるので、今回はそちらを追っていきます。

多少リファクタリングしていますが、基本的な実装は同じです。

オーバービュー

まずはCNNの中身に踏み込む前に、大局を把握するために、そのネットワークを利用するアプリケーション側(iOS-10-SamplerではMetalCNNBasicViewController)の実装から見ていきます。

ネットワークの入力部分について

上でネットワークの中身にはまだ踏み込まない、と書いたばかりですが、入力部分がどうなっているかだけ見ておきます。

本サンプルでの手書き文字認識のネットワークは、MNISTDeepCNNというクラスとして実装されています。

これがsrcImageというMPSImage型のプロパティを持っていて、

MNISTDeepCNN.swift
var srcImage : MPSImage

ネットワーク初期化時にsrcImageも初期化されます。

MNISTDeepCNN.swift
srcImage = MPSImage(device: device, imageDescriptor: sid)

ここに、以下で説明する実装によって、手書きした内容が入力として渡されてきます。

1. 手書きした内容を、入力としてネットワークに渡す

まず、手書きしたビューの描画内容をCGContext(28x28のビットマップコンテキスト)として取得します。1

ViewController.swift
guard let context = digitView.getViewContext() else {return}

取得したビットマップコンテキストのdataプロパティからピクセルデータへのポインタ(UnsafeMutableRawPointer)を取得し、ネットワークに渡します。

ViewController.swift
network.srcImage.texture.replace(
    region: region,
    mipmapLevel: 0,
    slice: 0,
    withBytes: pixelData,
    bytesPerRow: inputWidth,
    bytesPerImage: 0)

ここで、networkは(本サンプルにおけるネットワーク実装クラスである)MNISTDeepCNNのインスタンスで、 srcImageは上で説明した通りMPSImage型のプロパティです。replace〜というメソッドを使って、手書き内容のピクセルデータをMPSImageのtestureプロパティにコピーする、というのが上記コードでやっていることになります。

ちなみに、MPSImagetextureプロパティは、

MPSImage
open var texture: MTLTexture { get }

この定義の通り、MPSImagetexture プロパティはread-onlyなので、直接代入するのではなく、replace〜というメソッドを利用してピクセルデータをMTLTextureにコピーする、ということをやっています。

MTLTexture
/*!
 @method replaceRegion:mipmapLevel:slice:withBytes:bytesPerRow:bytesPerImage:
 @abstract Copy a block of pixel data from the caller's pointer into a texture slice.
 */
public func replace(region: MTLRegion, mipmapLevel level: Int, slice: Int, withBytes pixelBytes: UnsafeRawPointer, bytesPerRow: Int, bytesPerImage: Int)

2. ネットワークの実行

入力画像を渡したら、あとはもう順方向伝播(forward propagation)の計算を実行して結果を受け取るだけ。

ViewController.swift
let label = network.forward()

アプリケーション側(畳み込みニューラルネットワークを利用する側)の実装としては非常にシンプルに済むことがわかります。

つづく

ネットワークの中身の実装については、後編『iOSのMetalで畳み込みニューラルネットワーク - MPSCNNを用いた手書き数字認識の実装』に続きます。


  1. Core Graphicsの話なのでgetViewContextの実装内容については省略します 

14
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
10