LoginSignup
19
17

More than 3 years have passed since last update.

テンソルネットワークを用いた量子インスパイアな機械学習

Posted at

はじめに

今回は、近年少しホットになっている量子インスパイア機械学習を紹介します。
ここでの"量子インスパイア"をもう少し正確に言うと、"量子系を古典計算でなるべく効率的にシミュレートするために用いられる手法から着想を得た"です。
具体的にはテンソルネットワークです。

背景

古典系による量子系のシミュレートにテンソルネットワークを用いる流れは以前からありました。多体量子系の状態を行列積状態を用いて効率的に計算する[1]、ゲート量子計算を無向グラフとテンソルネットワークの縮約の組み合わせで効率的にシミュレートする[2]、などです。

これらの手法が示すように、テンソルネットワークを用いることで量子系が持つ非常に高次元の空間をあくまで近似的にですが、古典計算で扱うことができます。
高次元であることは機械学習においては表現力の高さにつながります。

この特徴を量子系のシミュレートだけでなく古典機械学習問題に適用しようというのが今回紹介する"量子インスパイアな機械学習"となります。

行列積状態(復習)

行列積状態はMPSと呼ばれ、下記図のように表されます[1]。

image.png

各黒丸は"site"と呼ばれ、N qubit系であればN個のsiteが作られます。
$\sigma_i$は"physical index"と呼ばれ、qubitの場合は$0\ (|0\rangle)$ or $1\ (|1\rangle)$を指します。各siteは通常の2次元行列にphysical indexの次元を加えた3次元行列(テンソル)と考えます。

各siteはそれぞれphysical index=0に対応する2次元行列と、physical index=1に対応する2次元行列を持つとイメージすることもできます。
仮に全ての$i$について$\sigma_i = 0$の場合、全てのsiteのphysical index=0に対応する2次元行列同士の積を計算することができ、その結果はもとの量子状態における$|00...0>$の係数となります。

どのように機械学習を行うか

教師あり学習による分類問題を扱った代表的な論文[3],[4]をもとに紹介します。
全体的なフローを図に示します。

スクリーンショット 2020-10-20 12.04.03.png

まず、入力データ$x$を行列積状態の図の$\sigma_i\ (i\in 0,...,n-1)$にエンコードします。

エンコードしたデータとMPSとの間でテンソル縮約をとります。
さらにMPSのsite間のエッジを縮約するのですが、このままでは得られるのは1つのスカラー値なので分類に使えません。
そのためMPSにはあらかじめ、$x$が各クラスに属する確率に対応する値を出力するための"label index"を持たせておきます。label indexは既存のsite 1つ、またはlabel indexの保持用に新たに追加したsite 1つに持たせます。
このようにすると、$\sigma_i$とMPSのすべてのエッジの縮約を計算した結果、判別クラス数と要素数の等しいテンソルが残るためその値を損失関数に入力できます。

学習時は損失関数の出力が小さくなるよう、MPSの各要素を更新します。
更新の方針は大きく分けて2通りあります。
1つは[3]で採用されている方法で、DMRGという従来の手法の応用です。隣り合う2 siteのみを変数とした局所的な最適化による更新をsweepしながら繰り返します。
もう1つは[4]で採用されており、誤差逆伝搬法を用いてMPSの全要素を更新します。

前者の手法は更新の際にSVDを用いて余剰次元を動的に刈り込めるメリットがあります。
一方後者の手法は既存のDL, 自動微分フレームワークによる計算との相性が良く、またおそらくネットワーク構造や損失関数の定義などの自由度が高いです。

実装

今回は[4]で行われた、誤差逆伝搬法によるMNIST学習を実装しました。
実装には著者らが開発したTensornetworkというpythonモジュールを使用しました。

Tensornetworkは文字通りテンソルネットワークの計算に適したライブラリです。
バックエンドとして"tensorflow"と"jax"を選択できます。"tensorflow"を選択した場合はTensorflowフレームワークと組み合わせて学習できます。
Tensorflowの自動微分機能や組み込みの関数を使用できる点が便利なのですが、一方で書いてみるとほとんどがカスタムレイヤーで占められてしまうため、フレームワークに合わせて書く面倒さやフレームワーク自体のオーバーヘッドが気になる面もあります。

そこで今回はjaxバックエンドを採用しています。
実際に[4]に続く研究[5]ではjaxバックエンドを使用しているようです。
jaxもpythonフレームワークで、おおざっぱに言うと自動微分、JIT、ベクトル化による並列演算をサポートしたnumpyのようなものです。
Tensorflowの高速な自動微分だけシンプルに使いたい、という場合に良い選択肢なのではないでしょうか(JuliaのFluxなども似たような立ち位置だと思っていて、そういった需要はそれなりにあるのでしょう)。

私の実装は[4]とは以下の点でやや異なっています。
1. MNISTの画像データを2x2 average poolingしている。
2. optimizerは[4]で使用されたadamでなく単純な勾配降下法(それに伴い、学習率やEpoch数も調整)

1.については、オリジナルのサイズだと学習の難易度が高かったためです。縮約をとる際に行われるのは画素数分の行列のかけ算であり、かける行列の数が増えると出力値が発散 or 0収束しやすい、また勾配が消失しやすいなどのプラクティカルな難しさがあります。調整次第ではあると思うのですが今回は妥協しました。
また[5]では(ネットワーク構造やタスクがいくらか異なるためかもしれないですが)著者らもpoolingしています。

2.はサンプル実装をシンプルにするためです。
また手元でtensorflowバックエンドで書いた時にadam optimizerでトライした結果と比較して、勾配降下法が特に劣っていなかったという経緯もあります。

実装コードは以下に置きました。
https://github.com/ryuNagai/MPS/blob/master/TN_ML/MNIST_ML_jax.ipynb
※初回学習実行時、JITコンパイルが3分くらいかかります。どこか改善の余地ありかもしれません。

学習過程はこんな感じです。
image.png

最終的にtrain accuracy=0.962、test accuracy=0.952となりました。
[4]では50 epoch程度でtrain accuracy=0.98程度に到達しており、それを再現するには及ばない結果でした。
裏では[4]の結果が再現できないか条件近くして多少試したのですが、中々難しかったので一旦この値で良しとします。

まとめ

新たに流行る(かもしれない)テンソルネットワークを用いた量子インスパイア機械学習を実装しました。
量子コンピュータのハードウェア面に多くの制約がある現状で、こちらの手法は古典計算機上で実行できるため大きな問題も扱えます。
この手法を用いて従来の機械学習モデルより有用なモデルが多く発見されるかはまだこれからの研究次第だと思います。

加えて、古典機械学習と比較した優位性が量子空間を用いた機械学習にあるかどうか、このような手法を用いることでそれが近似的・間接的にでも見える、検証できる可能性があれば良いと思っています。

参考文献

[1] https://arxiv.org/abs/1008.3477
[2] https://arxiv.org/abs/1805.01450
[3] https://papers.nips.cc/paper/6211-supervised-learning-with-tensor-networks
[4] https://arxiv.org/abs/1906.06329
[5] https://arxiv.org/abs/2006.02516

19
17
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
17