$\def\bm{\boldsymbol}$
概要
- オンラインで開催している『モデルベース深層学習と深層展開』読み会で得られた知見や気づきをメモしていく
- ついでに、中身の理解がてらJuliaサンプルコードをPythonに書き直したコードを晒していく
- 自動微分ライブラリにはJAXを使用する
第9回
大まかな内容
- LDPC符号の復号化アルゴの微分可能化
- スパース信号再構成問題への深層展開適用例の確認
- 線形システム安定化への深層展開適用
議論になったこと
5.2.1節 あたり
- ”誤り訂正”自体の説明が薄い…
- モチベーションはデジタル信号の送信ミスを直すと、シンプルだが奥深い世界があるっぽい
- シャノンの限界にいかに近づけるか、みたいな?
- 情報理論とも関連してる?
- 身近な所だと、オーディオ分野とかで活用されている
- 最近はハイレゾオーディオも配信とかされだしたが、この辺の技術が支えている?
- モチベーションはデジタル信号の送信ミスを直すと、シンプルだが奥深い世界があるっぽい
5.2.2節 あたり
- サンプルコード眺めながら
- alistはポピュラーな形式なのか?
- ->誰かの論文で使われてるとか?
- 対数事前確率比$\bm{\lambda}$がどのように与えれるかのイメージがよくわからん
- 実際に得られるのはノイズで汚れた符号ではないのか?
- ->答え出ず
- 実際に得られるのはノイズで汚れた符号ではないのか?
- alistはポピュラーな形式なのか?
- 著者のガチの専門分野で、執筆書籍が別にある分、逆に前提や問題設定の説明が薄め?
- タナーグラフって単純に因子グラフの一種?
- この辺の分野のみで呼ばれているもの?
- ->この資料参考になるかも
5.3.2節 あたり
- TISTAの方が収束がいい理由は?
- ->少なくとも教科書には記載ないが、提案論文でも実験的にしか示されてないっぽい?
- 結局ISTAと本質は同じで、亜種の域を出ないから逆に論理的に際を示しにくい説
- ->少なくとも教科書には記載ないが、提案論文でも実験的にしか示されてないっぽい?
5.4.1節 あたり
- $\bm{A,B,C}$はどっから出てきた?
- ->ここでは、単純にベンチマーク的にテキトーに作ってるかもしれないが、通常は物理的な特性(運動方程式など)を元にモデリングをする
5.4.2節 あたり
- 図5.9(b)は1周辺に落ち着くのがふさわしいんだっけ?0じゃないんだっけ?
- ->答え出ず
個人的な気づきなど
- 通常のビリーフプロパゲーションがなぜ、自動微分しづらいのかピンと来てなかったが、議論の中で「式(5.15)~(5.17)中の$\Sigma$や$\Pi$が特定の集合要素のみに作用しているため、計算グラフが繋がらないから」ではないかと思った。
プログラムでの理解
- 今回の範囲では、低密度パリティ検査符号(LDPC)の複合をビリーフプロバゲーション法を用いて行う際に、計算機上の全ての計算を行列とベクトルの積和演算に直すことで、全体を微分可能にし、深層展開を導入できるようにする方法を学んだ。
- 和田山先生公開のサンプルプログラムをPythonで再現する
- プログラムの全文はこの辺に上げている
問題設定
以下は、正直あまり詳しくないので自信はない、、(和田山先生の『誤り訂正技術の基礎』にはもっと正確で詳しく書いてある)
誤り訂正技術の世界観
- デジタル信号の通信に於いて誤り訂正という技術がある。
- 例えば、純粋なデジタル信号通信は0か1かだけを送り合えればいいのだが、通信路にはどうやっても雑音源がいるので、そのまま通信しようとすると送信元の信号は受信側に正しく伝わらない事が多い。
- そこで、0や1に対してあえて冗長な表現に符号化し、それを受信側で復号化するといった事が行われているらしい。
- 簡単な例では、「0は"000"、1は"111"と送ります」と決めておいて、受信側では「0か1の数の多い方をそれぞれ、0と1とみなす」とすれば、送った"111"が多少雑音で汚れて"101"や"011"になってしまっても、誤りを訂正して元の信号を復元する事ができる。
- ただし、この方法では当然"001"まで汚れてしまうと復元できないし、じゃあ数字の羅列を増やそうとかすると、限られた通信リソースをロバスト性のために犠牲にしてしまう事になる。
- そこで、効率よく符号化・復号化を行う技術が研究されている
- ただし、この方法では当然"001"まで汚れてしまうと復元できないし、じゃあ数字の羅列を増やそうとかすると、限られた通信リソースをロバスト性のために犠牲にしてしまう事になる。
- 例えば、純粋なデジタル信号通信は0か1かだけを送り合えればいいのだが、通信路にはどうやっても雑音源がいるので、そのまま通信しようとすると送信元の信号は受信側に正しく伝わらない事が多い。
2元線形符号
- 0,1で表されるデジタル通信を扱うのには2元線形符号というのを使う
- 最初、これをよく分かってなかったので、その先が全然ピンと来なかった…
- まず、集合$\{0,1\}$に対して、加算$+$と乗算$\cdot$が以下を満たすものを二元有限体$\mathbb{F}_2$という
- この二元有限体上の元を$n$個要素に持つベクトルが張るベクトル空間$\mathbb{F}_2^n$の部分線形空間を2元線形符号という
- 低密度パリティ検査符号(LDPC符号)はこの2元線形符号の1種である
因子グラフとsum-productアルゴリズム
- 因子グラフは多変数関数の因子分解を表すグラフである。
- 例)$f(x,y)=xy+y^2$は$f_1(x,y)=x+y$と$f_2(y)=y$の2つの因子を使って$f=f_1(x,y)f_2(y)$と分解される
- この因子グラフ上で多変数関数の周辺化計算を効率的に行う方法としてsum-productアルゴリズムがある
- 因子グラフは、因子に確率関数、変数に確率変数をとると確率のグラフィカルモデルとなる
- 推論の一種として、任意の周辺確率の計算をするためのツールとなる
- このときの周辺化計算のための情報伝達をメッセージパッシングと呼んだりする
- 推論の一種として、任意の周辺確率の計算をするためのツールとなる
LDPC符号の誤り訂正のためのビリーフプロバゲーション複合法
- 受信側で複合を司る行列$\bm{H}\in\mathbb{F}_2^{m\times n}$を検査行列という
- 疎な$\bm{H}$を使って、定義される以下の集合$C$がLDPC符号
- $C(\bm{H})=\{ \bm{x}\in \mathbb{F}_2^n|\bm{H}\bm{x}=\bm{0}\}$
- 疎な$\bm{H}$を使って、定義される以下の集合$C$がLDPC符号
- このとき、$\bm{x}$を確率変数として、$\bm{H}$と受信ベクトルが与えられたときに、送信ベクトルの最尤推定を行うが誤り訂正のいち手法があり、これを因子グラフ上のメッセージパッシングの繰り返しで行うのがビリーフプロバゲーション複合法である(と思う)
- このとき使われる因子グラフを特にタナーグラフと呼んでいるようだ
- 具体的な計算は教科書の式(5.15)~(5.17)あたりを参照
微分可能化
- 元々のビリーフプロバゲーションは繰り返し計算であるが、微分可能な演算ではないため、深層展開と相性が悪い。
- そこで、全ての演算を行列とベクトルの演算で済むように書き直してあげたというのが教科書の内容
- (さらに、それを応用して、深層展開と組み合わせいるサンプルコードもあるのだが、なぜかその内容は教科書本文ではあまり解説されていない…)
- ざっくりいうと、タナーグラフを2つの接続行列に分解(図5.3)するのと、対数の性質をいい感じに利用するのがポイントっぽい(より詳細は教科書のアルゴリズム4あたりを参照)
- そこで、全ての演算を行列とベクトルの演算で済むように書き直してあげたというのが教科書の内容
Pythonで実装
- ここでは、jaxnumpyでビリーフプロバゲーションを実装するところまでなので注意。(学習はしてない)
必要ライブラリインポート
import alist_loader
import jax.numpy as jnp
- alist_loader.pyは適当に自作したalistファイルのローダー
- 詳しくはここを参照
alistファイルの読み込み
filename = "../../DU-Book/Chapter_5/6.3.alist"
H, U, V = alist_loader.load_alist(filename)
- ここでは本家の和田山先生のリポジトリにある.alistファイルを読み込んでいる
- .alist形式は、検査行列等を表現したフォーマットだと思うのだが、このサンプルコードではじめて知った。(詳しくはよく分かっていない)
与えられた対数事前確率比と分散
var = 0.794328
Lambda = 2.0*jnp.array([1.620803, 0.264281, -0.031637,
-0.127654, 0.746347, 1.003543])/var
αとβの初期化
esize = U.shape[-1]
alpha = jnp.zeros(esize)
beta = jnp.zeros(esize)
sum-productアルゴの繰り返し計算
for i in range(1,4):
beta = (U.T@U-jnp.eye(len(U.T))) @ alpha + U.T@Lambda
tmp = jnp.exp((V.T@V-jnp.eye(len(U.T))) @ jnp.log(jnp.abs(jnp.tanh(beta/2))))
alpha_abs = 2*jnp.arctanh(tmp)
tmp = 1 -2*V.T @ bmod(V@((-jnp.sign(beta) + 1)/2))
alpha_sign = tmp * jnp.sign(beta)
alpha = alpha_sign * alpha_abs
gamma = U@alpha + Lambda
- 中身は愚直に教科書の行列演算
答え合わせ
- 上の
gamma
をプリントすると- [4.397419 1.6925247 1.7110999 1.7111 2.083964 2.703277 ]となった
- 対数事後確率比の真地は以下らしいので、よく一致している
- 4.397419 1.692525 1.711101 1.711101 2.083965 2.703277
その他
- 参加者増やしたい…
バックナンバー
- 『モデルベース深層学習と深層展開』読み会 レポート(開催前準備編)
- 『モデルベース深層学習と深層展開』読み会レポート#0
- 『モデルベース深層学習と深層展開』読み会レポート#1
- 『モデルベース深層学習と深層展開』読み会レポート#2
- 『モデルベース深層学習と深層展開』読み会レポート#3
- 『モデルベース深層学習と深層展開』読み会レポート#4
- 『モデルベース深層学習と深層展開』読み会レポート#5
- 『モデルベース深層学習と深層展開』読み会レポート#6
- 『モデルベース深層学習と深層展開』読み会レポート#7
- 『モデルベース深層学習と深層展開』読み会レポート#8