0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

メタ学習( MAML )の1実装。

Last updated at Posted at 2024-10-24

勉強した動機

今まで、通常の機械学習で音声合成、音声認識、機械翻訳や画像キャプショニングで非自己回帰型のモデルについて学習を行い精度を測定しました。精度を上げるためには、学習データと学習パラメータを増やす方法が有効です。しかし、大企業や AI の有力スタートアップでもないかぎり、学習データ作成に投資したり、学習を行うマシンに投資したりすることはできません。そこで、メタ学習の少ないデータや少ない計算量で学習できるという特徴を確かめるべく、MAML を実装し、5-way 10-shot の画像分類について実際に学習を行い精度を測定しました。結果は、少ないデータ量、少ない計算量でそこそこの精度が得られました。

参考にさせていただいたページ

実装をするにあたり、勉強しました。インターネットページで勉強しました。特に、

のページ。この実装が置いてある

は、かなり参考になりました。この他に、

の実装を参考にさせていただきました。論文の解説は

を参考にさせていただきました。感謝いたします。

学習用データセット

データは、CIFAR100 から自前で学習用データセットを作りました。画像データは、[ outer_batch, num_task, inner_batch, N-way * k-shot, 3, 32, 32 ] の形で、ラベルデータは、[ outer_batch, num_task, inner_batch, N-way * k-shot ] の形で作るようにしました。 [ 3, 32, 32 ] は画像の次元です。

学習データは、CIFAR100 の学習用データのうち 10,000 データを用い、 outer_batch = 3, num_task = 20, inner_batch = 3, k_support=10, k_query=15, num_class = 5 としました。

学習結果

学習は 300 エポックで、train acc = 1.0, val acc = 0.63, test acc = 0.76 でした。学習にかかった時間は、RTX 6000 一枚で 30分程度、CPU でも3時間程度でした。

プログラムの要点

MAML の学習プログラムでは、タスクが重要な役割を果たす。画像認識では、CIFAR100 のデータで、0~99 のクラスについて、0番目のタスク 0~4, 2番目のタスク 5~9,・・・,19番目のタスク 95~99 と20個のタスクを考えた。あるタスクについて、データは support データと query データを作成する。加えて、勾配を二回計算する。あるタスクについて、初期値の学習パラメータで support データについて loss を計算し一つ目の勾配を求める。この勾配から lr を使って途中の学習パラメータを求める。途中の学習パラメータを query データに適用し、loss と二つ目の勾配を計算する。二つ目の勾配は1次近似で求める。この計算をすべてのタスクについて行い、二つ目の勾配の和を求める。二つ目の勾配の和を使って、モデルの学習パラメータ―を更新する。また、更新された学習パラメータを初期パラメーターとして、一つ目の勾配を求める。

pytorch の MAML でも、model パラメータを求めるのですが、validation や test を行うとき、保存した model パラメータで計算した loss と acc ではなく、保存したモデルパラメータについて、興味のある task を学習させたパラメータで、validation や test の loss と acc を計算するようです。

実装したプログラム

わたくしもそうだったのですが、ページであれこれ説明されても良く分からなかったです。分かるためには、実装したプログラムを使ってみて、自分が正しいと思うように修正すると分かったと思えました。ページ上で説明は控えさせていただき、理解したい方は、github のプログラムをダウンロードして使ってみて、自分なりに直してみてください。よろしくお願いいたします。

感情分析の MAML メタ学習プログラムも置いておきます。感情分析は、Transformer Encoder を使っています。<CLS>トークンは使わずに、テキストを tokeinze したあとの sequence の最大値を 128 として、Transformer Encoder の出力を [ batch, 128, 256] と固定しました。その結果、最終的な分類のための線形層を nn.Linear( 128*256, 2 )とすることができました。精度は train acc = 0.79, val acc = 0.74, test acc = 0.77 でした。学習にかかった時間は、300 epochs で RTX 6000 一枚で1時間程度でした。CPUでも3時間程度です。

このページを参考にしました。

感情分析について CLS トークンを用いたメタ学習も行ってみました。train acc = 0.72, val acc = 0.76, test acc = 0.70 でした。ソースを置いておきます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?