LoginSignup
4
3

More than 3 years have passed since last update.

fast.aiでギターを識別するモデル

Posted at

今の時代なら、オンラインコースで新しい技術を身に付けることができるようになりました。最近CourseraでAndrew Ngが教えるDeep Learning specializationを完了して、非常に勉強になりました。数学の説明が多くて、最初のモデルは完全にpython・numpyで作ります。レッスンを続くと、Tensorflow、Kerasも使うことがあります。理論の根拠をしっかり勉強して、活用する習い方のスタイルになります。

ただし、数年前から話題になっているfast.aiはまったく逆のやり方にです。「まずやろうぜ」のアプローチです。やってみることによって、fast.aiの基本のライブラリでどんなに簡単に人工知能モデルを作れるか刺激的です。

fast.aiのレッスン 1 のコードに基づいて、ギターを識別できるモデルを作ってみました。
ギターはあまりわからない方に申し訳ないですが、この例はあまり役に立たない可能性が高いです!

まず、モデルを準備する前に、4つのタイプのギターの画像データを収集しました。
- Gibson explorer
- Gibson Les Paul
- Fender Stratocaster
- Fender Telecaster

なぜその4つのタイプにしたかというと、ボディの形もヘッドの形も特徴があるギターたちです。また、大人気のギターたちなので、画像収集は簡単でした。Google検索、ありがとうございました。
それぞれのギターのタイプに対してフォルダーを作って、数百枚を集めました。

まずfastai.visionをインポートします。
データを保存したパスを設定し、クラスを定義します。クラス名はフォルダ名と一致します。


from fastai.vision import *

path = Path('data/guitars')
classes = ['gibson_les_paul', 'fender_telecaster', 'fender_stratocaster', 'explorer']

fast.ai はデータを簡単に準備するメソッドを提供しています。一つのメソッドで画像に反映する変更(たとえばdata augmentationのための回転など)、サイズ変更、またトレーニングデータセットと検証データセットの割合を設定できます。
詳細についてAPI詳細のページをご参照してください。


data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
                                 ds_tfms=get_transforms(), size=224, num_workers=4, bs=16).normalize(imagenet_stats)

データの例も表示することが可能です!


data.show_batch(rows=3, figsize=(7,8))

Images.png

学習自体は非常に簡単です。コード2行で学習済みのCNNモデルに基づいて、学習させることとおができます。


learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)

学習の結果は下記です。90%の制度!ギター識別エンジンのState of the Artじゃないですか?唯一のモデルかもしれませんが。

TrainingDone.PNG

また、学習時の問題や検証データセットの結果も見ることができます。


interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

ConfusionMatrix.PNG

Confusion Matrixです!TelecasterとStratocasterは一番混合されることは確かにそうでしょう。
また、モデルが適切に予測できなかった画像も表示できます。


losses,idxs = interp.top_losses()
interp.plot_top_losses(9, figsize=(15,15))

Guitars.PNG

確かに判断しにくい画像が多いです。左上の画像そはそもそもギターケースです。真ん中の一番右はおそらくStratocasterでしょうけど、画像がクロップされすぎてモデルには難しいでしょう。右下の画像は人間にも判断しにくいです。その他のギターですね。つまり、学習データと検証データにノイズがあります。「Garbage in, Garbage out」。今後の改善として、データをさらにきれいにして学習します。また、インポートしたモデルのパラメータ固定を解除にすることによってさらに学習することが可能ですが、当然に時間がかかります。

モデルで予測してみましょう。


img = open_image(path/'gibson_les_paul'/'00000059.jpg')
pred_class,pred_idx,outputs = learn.predict(img)

このギターです。

download.png

予測結果は: Category gibson_les_paul

:v:ピポンピポン:v:

マイアプリで作ったモデルを使うには、モデルをエクスポートして簡単なpythonコードでAPI化できます。

上記の内容はほとんどFast.ai レッスン 1 の内容になります。そんなに簡単にそんなに早くなかなかいい結果出せることはずごいですね。しかもレッスンは無料です!おすすめです!Fast.aiのレッスンにはこのリンクをクリックしてください!

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3