LoginSignup
0
0

The Curse of Recursion: Training on Generated Data Makes Models Forget「再帰の呪い:生成データでの訓練がモデルに忘却をもたらす」

Last updated at Posted at 2023-08-05

今回は、査読前論文を取り扱ったarXivから興味を持った論文として、
The Curse of Recursion: Training on Generated Data Makes Models Forget「再帰の呪い:生成データでの訓練がモデルに忘却をもたらす」 を紹介します。

近年急速な発展を遂げている生成系AIが生成したデータは、人間が生成した文章や、自然界を写した実際の画像と見分けがつかないまでになっています。

しかしこのようなデータがインターネット上にあふれることにより、それらのデータが再帰的に学習に使用される事につながった場合、元のデータに見られたようなデータの特徴が失われ、固定化されたデータのみを生成するようになったり、意味のあるデータを作成出来なくなるという、「再帰の呪い」について書かれたものです。

この論文については私が個人的に説明するよりも、非常に正確な翻訳と説明をしてらっしゃる方がいますので、詳しい説明は その方の解説 をご覧下さい。

ここでは、実際に生成系AIが生成したデータをn次の形で再度生成系AIが学習に用いることにより、実際のデータに見られる特徴が失われていく過程を非常に単純化されたモデルを使用して、視覚的にわかりやすく解説することを意図しています。

論文の中ではデータの崩壊のプロセスについていくつかの原因と、データが崩壊した場合のモードについての異なる結論があり得る可能性について述べていますが、今回は生成系AIが生成したデータが複合ガウス分布に載るという過程を置いた上で

「生成系AIの生成するデータにはノイズが載る」という前提

「ノイズはサンプル平均の偏差よりも大きいオーダーで有界」

という前提を満たすように調整した上で、複合ガウス分布に乗るサンプルにノイズを加えています。
具体的には以下のコードで、2種類の複合ガウス分布を算出し、その2種類の複合ガウス分布の平均の偏差よりも大きいオーダー、かつ有界(その集合の要素全てが特定の最小値と最大値の間に存在すること)を満たすように調整しています。

複合ガウス分布のデータの生成コード

以下が、データの崩壊のプロセスを極単純化し、視覚的にわかりやすく解説することを意図したコードになります。


# オリジナルの複合ガウス分布データのパラメータ
means = [2, 10]
std_devs = [1, 1]

# 過程を再現する世代数
total_generations = 100

# 1世代での変化が少ないため、10世代ごとのグラフを描画する
sampling_generations = 10

# サンプル数
samples_per_generation = 1000

# オリジナルデータ分布の再現

# データを保持するために空のリストを初期化する。
original_data = []

# 平均と標準偏差の各ペアを繰り返し処理する。
for mean, std_dev in zip(means, std_devs):
    # 正規分布からのサンプルデータを生成する。
    data = np.random.normal(mean, std_dev, samples_per_generation // 2)

    # サンプリングされたデータをリストに追加する
    original_data.extend(data)

# リストをnumpy配列に変換する。
original_data = np.array(original_data)


# グラフの設定を行う。
fig, axes = plt.subplots(total_generations // sampling_generations, 1, figsize=(5, 30), sharex=True)
plt.subplots_adjust(top=0.95, bottom=0.02, hspace=0.01)

# 世代別データ分布のプロット
data = original_data.copy()
for generation in range(total_generations):
    if generation % sampling_generations == 0:
        ax = axes[generation // sampling_generations]
        # ヒストグラムを確率密度として正規化する(確率の合計は1になる)
        ax.hist(original_data, bins=50, density=True, alpha=0.3, color='red', label='Original Data')
        ax.hist(data, bins=50, density=True, alpha=0.5, label=f'Generation {generation} Generated Data')
        ax.legend()

    # ガウス混合モデルのフィッティング
    gmm = GaussianMixture(n_components=2)
    gmm.fit(data.reshape(-1, 1))

    # 適合モデルから新しいデータをサンプリングする
    data = gmm.sample(samples_per_generation)[0].flatten()

    # 付加的なノイズは有界であり、標本平均偏差より大きなオーダーである。
    sigma = 0.5
    noise = np.array([])

    # ノイズを生成する
    while len(noise) < samples_per_generation:
        temp_noise = sigma * np.random.randn(samples_per_generation)
        # 論文での前提となる、ノイズが有界という条件を満たしておきたい(実際はこの処理がなくても有界になるはず)
        temp_noise = temp_noise[(temp_noise >= -3 * sigma) & (temp_noise <= 3 * sigma)]
        noise = np.concatenate((noise, temp_noise))

    # ノイズの要素をサンプリング数に制限する
    noise = noise[:samples_per_generation]

    data += noise

plt.xlabel('Value')
plt.ylabel('Density')
plt.suptitle('Model Collapse Simulation')
plt.show()

生成される、データ分布のヒストグラムは以下のようになります。

複合ガウス分布の崩壊プロセスの簡易的例示.jpg

極単純化したものではありますが、いくつかの前提の元でオリジナルのデータ分布が失われていく様子が視覚的に理解出来るかと思います。

Papers with Code にもコードを添付しています(論文を取り上げたのは私ではありません)

この論文で触れられているようなデータやモデルの崩壊は、生成モデルの訓練における重要な問題であると共にそれを防ぐためには、真の人間が生成したコンテンツを学習に使用する事が不可欠であるという事です。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0