0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ARCコンペ:代数をEDAに適用する

Last updated at Posted at 2024-07-20

このノートブックでは、同じ次元を持つ正方行列の例のみを使用します。データセットには、他の分布も含まれています。

数字を使った例をプロットすることで、例の数学的分布を視覚化することができます。また、引き算によって、入力と出力の違いを明確に理解することができます。

ライブラリのインポート

まずは必要なライブラリをインポートします。

import json  # JSON データの読み込み用
import pandas as pd  # データの操作用
import numpy as np  # 数値計算用
import matplotlib.pyplot as plt  # グラフ描画用
from   matplotlib import colors  # 色の設定用

データの読み込み

データセットを読み込みます。

base_path = '/kaggle/input/arc-prize-2024/'

# JSON データを読み込む関数
def load_json(file_path):
    with open(file_path) as f:
        data = json.load(f)
    return data

# このノートブックでは、正方形で等しい行列の例のみを使用します。
# データセットには、異なる分布を持つ他のcsvファイルも含まれています。
df = pd.read_csv("/kaggle/input/arc-2024-training-explamples-by-form/equals_squared_train.csv")   
training_challenges = load_json(base_path + 'arc-agi_training_challenges.json')

入力、出力、引き算の抽出

入力と出力の行列を取得し、その差分を計算する関数を定義します。

# 行列を読み込む関数
def get_matrix_pair(challenge):
    x = pd.DataFrame(challenge['input'])  # 入力の行列
    y = pd.DataFrame(challenge['output']) # 出力の行列
    
    # 引き算を行う (他の演算に変更することも可能)
    z = y - x  
    return x, y, z

色の設定

ヒートマップに使用する色を設定します。

# 色のリストを定義
cmap = colors.ListedColormap(
    ['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
     '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'])

# 0から9までの値を色に正規化
norm = colors.Normalize(vmin=0, vmax=9)

# 色の見本を表示
plt.figure(figsize=(3, 1), dpi=150) 
plt.imshow([list(range(10))], cmap=cmap, norm=norm)
plt.xticks(list(range(10)))
plt.yticks([])
plt.show()

ヒートマップの作成

数字付きのヒートマップを作成するための関数を定義します。

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    NumPy 配列と2つのラベルリストからヒートマップを作成する関数

    Parameters
    ----------
    data: numpy.ndarray
        (M, N) 形状の2次元 NumPy 配列
    row_labels: list または numpy.ndarray
        行のラベルを含む、長さ M のリストまたは配列
    col_labels: list または numpy.ndarray
        列のラベルを含む、長さ N のリストまたは配列
    ax: matplotlib.axes.Axes, optional
        ヒートマップを描画する Axes インスタンス。指定しない場合は、現在の Axes を使用するか、新しい Axes を作成します。
    cbar_kw: dict, optional
        matplotlib.figure.Figure.colorbar に渡す引数の辞書。
    cbarlabel: str, optional
        カラーバーのラベル。
    **kwargs: 
        imshow に渡すその他の引数。
    """

    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    # ヒートマップを描画
    im = ax.imshow(data, **kwargs)

    # カラーバーを作成
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # すべての目盛りを表示し、それぞれのリストのエントリでラベル付けします。
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    # 横軸のラベルを上に表示します。
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # 目盛ラベルを回転させて配置します。
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # 枠線をオフにし、白いグリッドを作成します。
    ax.spines[:].set_visible(False)

    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar

# ヒートマップに数値のテキストを追加する関数
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    ヒートマップに注釈を付ける関数

    Parameters
    ----------
    im: matplotlib.image.AxesImage
        ラベル付けする AxesImage
    data: numpy.ndarray, optional
        注釈に使用するデータ。None の場合、画像のデータが使用されます。
    valfmt: str または matplotlib.ticker.Formatter, optional
        ヒートマップ内の注釈の書式。文字列の書式設定メソッド (例: "$ {x:.2f}") 
        または matplotlib.ticker.Formatter を使用します。
    textcolors: tuple, optional
        色のペア。最初はしきい値以下の値に使用され、2番目はしきい値以上の値に使用されます。
    threshold: float, optional
        textcolors からの色が適用されるデータ単位の値。
        None (デフォルト) の場合は、カラーマップの中間値が区切りとして使用されます。
    **kwargs: 
        テキストラベルの作成に使用される text の呼び出しごとに転送される、その他すべての引数。
    """

    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # しきい値を画像の色範囲に正規化します。
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # デフォルトの配置を中央に設定しますが、textkw で上書きできるようにします。
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # 文字列が指定されている場合は、フォーマッターを取得します
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # データをループ処理し、各「ピクセル」の Text を作成します。
    # データに応じてテキストの色を変更します。
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts

問題と解答の表示

問題、入力、出力、入力と出力の差を表示する関数を定義します。

def ploting_exercices(challenge, x, y, z):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))  # 3つのプロットを作成
    fig.suptitle(challenge) # タイトルに問題IDを表示
    cmap = colors.ListedColormap(['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
                                      '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25']) # 色のリスト
    norm = colors.Normalize(vmin=0, vmax=9) # 0~9の値を色に正規化

    # 入力行列を表示
    ax1.imshow(x, cmap=cmap, norm=norm)
    ax1.set_title('Input') # タイトル
    # 出力行列を表示
    ax2.imshow(y, cmap=cmap, norm=norm)
    ax2.set_title('Output') # タイトル
    # 入出力の差分を表示
    ax3.imshow(z, cmap=cmap, norm=norm)
    ax3.set_title('(output) - (input)') # タイトル
    
    # 入出力、差分の各行列の数字を表示
    for i in range(len(x[0])):
        for j in range(len(x[0])):
            text = ax1.text(i, j, x[i][j],
                           ha="center", va="center", color="r", size=15)
    
    for i in range(len(x[0])):
        for j in range(len(x[0])):
            text = ax2.text(i, j, y[i][j],
                           ha="center", va="center", color="r", size=15)
    
    for i in range(len(x[0])):
        for j in range(len(x[0])):
            text = ax3.text(i, j, z[i][j],
                           ha="center", va="center", color="r", size=15)
    
    
    fig.tight_layout()
    plt.show()

10個の例をプロット

10個の例について、問題、入力、出力、入力と出力の差を表示します。

# 最初の10個のデータを取得
batch_1 = df['id'][0:10]  
# 10個のデータについて、問題、入力、出力、入力と出力の差を表示
for challenge_id in batch_1:
    train_dic = training_challenges[challenge_id]['train']
    for pair in train_dic:
        x, y, z = get_matrix_pair(pair)
        ploting_exercices(challenge_id, x, y, z)

改善点

  • 転置や他の演算を試してみてください。
  • 入力と出力の関係を分類問題として分析するために、解答を含まないデータセットを作成してみてください。
  • 行列のパターンを探索するために、グラフを実装することを検討してみてください。

ノートブック

参考文献

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?