注意書きなど
この記事の説明はかなり簡略化されているため、論文の本質とは若干異なるかもしれません。所々DeepL使って翻訳してる箇所もあるので日本語が怪しいかもしれません。流し読み用にご利用くださいませ。
このモデル(ACT)の何がすごいの?
あのかの有名なMobile ALOHAで使われてます
Action Chunking with Trasnformers (ACT) はわずか10分間のデモンストレーションで料理などの難しいタスクを80〜90%の成功率で実行可能
アーキテクチャ
全体の構造
このモデルの全体は 条件付き変分オートエンコーダー(CVAE) になります
- CVAEについては後ほど説明します
Action ChunkingとTemporal Ensemble
この論文では、心理学の「チャンキング」にを元にしたaction chunkingというものが使われています
-
大きなタスクをより小さいタスクに分割するイメージ
-
具体的には次のタイムステップだけを予測するのではなく、$k$タイムステップのアクションを予測
チャンクを一つずつ再生すると動きがぎくしゃくする可能性があるため、この論文ではTemporal Ensembleという工夫を施しています
- Temporal Ensembleでは:
- 毎タイムステップごとにチャンクを予測
- アクションチャンクが重なり合うため、重み付き平均をタイムステップごとに取ります
- 重みは以下の関数で分布されます: $w_i = \exp(-m * i)$ ここで $i = 0$ は最も古い予測になります
重み分布
Action ChunkingとTemporal Ensembleの図(原文より引用)
ACTの具体的な説明に移る前に、このモデルを構成する要素を簡単に見ていきます:
- 条件付き変分オートエンコーダー(CVAE)
- ResNet
- トランスフォーマー
CVAEとは
条件付き変分オートエンコーダー(CVAE) は、入力データを圧縮し、圧縮された表現(この表現は潜在空間と呼ばれます)から入力データを再現しようとするモデルです。CVAEを標準のVAEと区別する特別な点は、エンコーダーとデコーダーの両方が追加情報に条件付けられることです。
構造は2つの部分から成ります:
- エンコーダー
- エンコーダーは入力データを、潜在空間$z$の確率分布を表す$\mu$と$\sigma$に変換します。
- デコーダー
- デコーダーは、潜在空間の分布からサンプリングして元の入力を再構築します。
再パラメータ化トリック (Reparametrization Trick)
確率分布からのサンプリングは微分不可能なので、次の式を使用してサンプルを近似します:
$$
z = \mu + \epsilon\sigma \hspace{3px} (\epsilon \sim N(0,1))
$$
CVAEと再パラメータ化トリックの図(AutoEncoder, VAE, CVAEの比較 〜なぜVAEは連続的な画像を生成できるのか?〜より引用)
損失関数
CVAEの損失関数は、証拠下界(ELBO)と呼ばれます:
$$
\mathcal{L} = \mathbb{E}_ {q(z|X)}[\log{p(X|z)}] - D_{KL}[q(z|X)||p(z)]
$$
ResNetとは
残差ニューラルネットワーク(ResNet) は、CNNアーキテクチャ(簡単に言うと画像から特徴を抽出するモデル)であり、残差接続という技術を使用して消失勾配問題を軽減することで有名です
-
残差接続 - 入力を畳み込みブロックの出力に追加するスキップ接続を使用します。これにより、ネットワークは入力と目標出力の違い(残差)を学習することができ、効率的なトレーニングと消失勾配問題の軽減が可能になります
- 線形層で勾配消失問題が起きても大丈夫なようにスキップ接続で避難経路を用意してあげてるイメージ
残差接続の図
トランスフォーマーとは
トランスフォーマー は若干複雑ではありますが、重要なのはデータの注意・注目すべき部分を抽出/強調する注意機構(Attention)の集まりであることと、シーケンス(系列)を取り扱うことです
トランスフォーマー は通常、エンコーダー + デコーダー構造を持ちます:
- エンコーダーは、シーケンスの要素間の依存関係を捉えることで入力シーケンスを表現することを学習します。
- デコーダーは、エンコーダーで学習した表現を使用して出力シーケンスを生成します。
ACTの全体的なフロー
ACTがどのように機能するかの一般的なフローは次の通りです:
- デモデータセットから画像($480 \times 640 \times 3$)と関節位置データ($1 \times 14$)を取得
a. 関節位置データを使用して $k$ タイムステップのアクションシーケンス($k \times 14$)を作成 - 関節位置データを$1 \times 512$に変換
- アクションシーケンスを$k \times 512$に変換し、正弦波位置埋め込み(Positional Encoding)を行う
- 関節位置データ、アクションシーケンス、および[CLS]*を4つの自己注意ブロックを持つトランスフォーマーエンコーダーに通す
- トランスフォーマーの出力を線形層に通して$\mu$($1 \times 32$)と$\sigma$($1 \times 32$)を取得
- 再パラメータ化トリックを使用してスタイル変数$z$をサンプリング
- 画像特徴を抽出
a. 画像をResNet18に通す(出力は$15 \times 20 \times 728$)
b. $300 \times 728$にフラット化
c. $300 \times 512$に変換
d. 正弦波位置埋め込みを行う - 画像特徴、関節($1 \times 512$)、および$z$($1 \times 512$に変換)を4つの自己注意ブロックを持つトランスフォーマーエンコーダーに入力
- トランスフォーマーデコーダー(7つのクロスアテンションブロックで構成)を8.のトランスフォーマーエンコーダーの出力で条件付け
- 固定位置埋め込み($k \times 512$)をトランスフォーマーデコーダーに入力し、出力アクションシーケンスを取得
*[CLS]はシーケンス全体の集約表現として機能するトークンになります。
ACTアーキテクチャの図(原文より引用)
損失関数
このモデルの全体的な損失は次のように与えられます:
$$L=MSE( \hat{a}_ {t:t+k},a_ {t:t+k})+βD_{KL}(q(z∣a_{t:t+k},\bar{o}_t)∣N(0,1))$$
ここで、
$k$は各アクションチャンクの長さ
$\hat{a}_ {t+k}$は出力アクションシーケンス
$a_{t+k}$は入力アクションシーケンス
$MSE$は平均二乗誤差
$\bar{o}_t$は映像無しの観察(センサデータ)
$D_{KL}$はKLダイバージェンス
$\beta$はハイパーパラメータ
- 再構成損失と正則化損失のバランスを制御
- $z$を通じて伝達される情報の量を制御する感じ
$q(z|a_{t+k}, \bar{o}_t)$はCVAEエンコーダー
になります
学習アルゴリズム
モデルは次の手順でトレーニングされます:
学習アルゴリズムのpseudocode(原文より引用)
- 初期化
- ハイパーパラメータ、データセット、エンコーダー、デコーダー
- データセットからのサンプリング
- 観察(関節位置と画像)とアクションシーケンス
- エンコーダー処理
- 観察(関節位置のみ)とアクションシーケンスを使用して再パラメータ化により潜在変数$z$をサンプリング
- デコーダー処理
- 観察(関節位置と画像)とサンプリングされた潜在変数を使用してアクションシーケンスを予測
- 損失の計算
- $L=MSE( \hat{a}_ {t:t+k},a_ {t:t+k})+βD_{KL}(q(z∣a_{t:t+k},\bar{o}_t)∣N(0,1))$
- パラメータの更新
- ADAMを使用してエンコーダーとデコーダーのパラメータを更新
推論アルゴリズム
モデルは次の手順で推論を行います:
推論アルゴリズムのpseudocode(原文より引用)
- 初期化
a. 事前学習されたデコーダー、エピソードの長さ、Temporal Ensembleの重み
b. デコーダーが予測した各タイムステップでのアクションを保存する先入先出(FIFO)バッファ - 事前学習されたデコーダー処理
a. 観察(関節位置とカメラ)を使用してアクションシーケンスを予測(潜在変数はゼロに設定) - アクションシーケンスをFIFOバッファに保存
- Temporal Ensembleを使用してタイムステップごとの最終的なアクションを取得
まとめ
- ACTの全体的な構造はCVAEでエンコーダ・デコーダ部分にトランスフォーマーを使っている
- Action ChunkingとTemporal Ensembleで各タイムステップで一つではなく、アクションチャンク(行動系列)を予測する
最後に
次も流れでDiffusion Policyとか読みたいけど一旦強化学習に戻ります。次回はDecision Transformer辺りを解説しようと思います。
参考文献
https://mobile-aloha.github.io/
https://arxiv.org/abs/2304.13705
https://ijdykeman.github.io/ml/2016/12/21/cvae.html
https://github.com/tonyzhaozh/act
https://qiita.com/kenchin110100/items/7ceb5b8e8b21c551d69a