C
numpy
コードリーディング

実録!コードリーディング入門 ~ NumPy rollaxis & transpose 編 ~ (2/2)

この記事は、実録!コードリーディング入門 ~ NumPy rollaxis & transpose 編 ~(1/2) の続きです。

前回は、numpy.rollaxis の実態が numpy.transpose であることをまず突き止めました。
その後、実装を深く潜っていくと numpy.ndarray のC実装にたどり着き、 PyArray_Transpose() 関数が本体であるということがわかりました。
この記事では、 numpy.rollaxis = numpy.transpose の実装をより深く追っていきます。

その前に: 可変長引数の扱いについて

numpy.ndarray.transpose のドキュメント を見ると、この関数は axes という可変長の引数を取ります。
具体的には、次の3パターンを取ります。

  • 引数なし or None: 軸を逆順にする
  • 整数の tuple: tupleのj番目にiと書かれている場合、元配列のi番目の軸を、j番目に移動する
  • N個の整数: 意味合いはtupleと同じだが、引数の個数が1個と解釈されるか、N個と解釈されるかが異なる

このような可変長引数を扱っている処理が、 array_transpose() 関数に書かれています。
私が追記したコメントの通り、それぞれのケースに対応する処理が書かれていることがわかります。

numpy/core/src/multiarray/methods.c
static PyObject *
array_transpose(PyArrayObject *self, PyObject *args)
{
    PyObject *shape = Py_None;
    Py_ssize_t n = PyTuple_Size(args);
    PyArray_Dims permute;
    PyObject *ret;

    if (n > 1) {                              /* N個の整数の場合(3つ目のケース) */
        shape = args;
    }
    else if (n == 1) {                        /* 整数のタプルのケース(2つ目のケース) */
        shape = PyTuple_GET_ITEM(args, 0);
    }

    if (shape == Py_None) {                   /* 引数がないケース(1つ目のケース) */
        ret = PyArray_Transpose(self, NULL);
    }
    else {
        if (!PyArray_IntpConverter(shape, &permute)) {
            return NULL;
        }
        ret = PyArray_Transpose(self, &permute);
        npy_free_cache_dim_obj(permute);
    }

    return ret;
}

以上のコードから、 PyArray_Transpose() は以下の2つの引数を取ることがわかりました。

  • 多次元配列オブジェクトそのもの (self)
  • 軸の並び方を指定する配列オブジェクト (permute) なお、 NULL の場合は軸を逆順に並び替える

なお、細かい処理を完全に理解するためには PyArray_IntpConverter() などといった細々とした関数の中身も追っていきたいものですが、本質的に知りたい処理ではないことと、関数名から処理内容がだいたい予想がつく(オブジェクトの中身を整数の配列に型変換する)ことから、ここでは割愛しています。
コードリーディングにかけられる時間は有限ですので、このように本筋ではない関数はあえて追わないという姿勢も重要です。

PyArray_Transpose 関数を読む

前回の記事 で、 PyArray_Transpose()numpy/core/src/multiarray/shape.c の683行目にあるということまでわかっていました。
それでは、お目当てのコードを見てみましょう。
長いので、前半と後半に分けてみていきたいと思います。

前半: 軸入れ替え情報のチェック・生成

PyArray_Transpose() は、多次元配列オブジェクト ap の軸を、 permute で指定されたとおりに並び替える処理だということが、これまでのコードリーディングで読めてきました。
この permute という引数は NULL を取ることができ、この場合は軸を逆順に並び替える事になります。
このように、 permute は空の場合に何かしら軸情報を生成する必要がありますし、非 NULL だったとしても、適切な軸情報が渡されているか値チェックする必要があります。
そこで、この関数では permute をそのままその後の処理に使うのではなくて、一度 permutation という配列に変換して、以降の処理を行っています。

numpy/core/src/multiarray/shape.c
NPY_NO_EXPORT PyObject *
PyArray_Transpose(PyArrayObject *ap, PyArray_Dims *permute)
{
    npy_intp *axes;
    int i, n;
    int permutation[NPY_MAXDIMS], reverse_permutation[NPY_MAXDIMS];
    PyArrayObject *ret = NULL;
    int flags;

    if (permute == NULL) {
        n = PyArray_NDIM(ap);
        for (i = 0; i < n; i++) {
            permutation[i] = n-1-i;
        }
    }
    else {
        n = permute->len;
        axes = permute->ptr;
        if (n != PyArray_NDIM(ap)) {
            PyErr_SetString(PyExc_ValueError,
                            "axes don't match array");
            return NULL;
        }
        for (i = 0; i < n; i++) {
            reverse_permutation[i] = -1;
        }
        for (i = 0; i < n; i++) {
            int axis = axes[i];
            if (check_and_adjust_axis(&axis, PyArray_NDIM(ap)) < 0) {
                return NULL;
            }
            if (reverse_permutation[axis] != -1) {
                PyErr_SetString(PyExc_ValueError,
                                "repeated axis in transpose");
                return NULL;
            }
            reverse_permutation[axis] = i;
            permutation[i] = axis;
        }
    }

軸を逆順に並べ替える (permute == NULL の場合)

if文に書かれている通り、 permute == NULL の場合は、軸を逆順に並べ替えるための軸入れ替え情報を生成します。
この処理は、for文の書き方として教科書ではおなじみのようなものですので、見ればひと目でわかると思います。
permutation[i] = n-1-i; とあるように、Cの配列は0始まりなので一番最後の要素は n-1 になるところが気をつける点ですね。
(こういうところで -1 を付け忘れてしまうのが、実際コードを書いてるとついやってしまうのよね。。。)

permutation は、元配列の軸が transpose 後にどこを行くかを表しています。
permutation[0] = n-1, permutation[1] = n-2 のように permutation が生成され、元行列の0番目の軸が n-1 番目(末尾)に、1番目の軸が n-2 番目に、…、n-1番目の軸が0番目(先頭)にと、逆順に並べ替えるような入れ替え情報が生成されました。

軸情報のチェック (permute != NULL の場合)

permute が非 NULL の場合は、与えられた permute の値をチェックし、 permutation に値を入れていきます。
値チェックで使われているのが、 reverse_permutation です。

まず、いちばん最初の if 文では、与えられた軸情報の要素数が、配列オブジェクトの次元数と一致することを確かめています。
ここは必ず一致するはずですので、この時点でしっかり確認するのが大事です。
コードを書くときは異常系もキッチリ意識しなければいけませんね!

1つ目のfor文は、 reverse_permutation の初期化です。
reverse_permutation はスタックに置かれるローカル変数ですので、初期値は不定です。
ローカル変数の初期化漏れはよく発生し、またコードをパッと見ただけでは気づきづらい厄介なバグになりますので、注意が必要です。
reverse_permutation は、transpose 後の配列の軸が、元配列のどの軸に対応しているかを表します。
変換前と変換後では、軸の対応が1対1になるはずですので、元配列のある複数の軸が変換後の軸に対応しないように、チェックするのがこの変数の役割です。

2つ目のfor文が、処理のメインです。
ここで読んでいる check_and_adjust_axis() は、与えられた軸が配列の次元を超えていないかと、負のインデックスが渡された時に正のインデックスに変換する役割を持っています。
軸のインデックスをすべて正の数に変換した後、 permutationreverse_permutation を生成していきます。
もし reverse_permutation に先に値が入れられていた場合、元配列の複数の軸が変換後のある軸に対応している事になりますので、そういった場合は ValueError を出しています。

後半: Transpose 処理の本体

ようやく transpose 処理の本体にたどり着きました!
コメントによると、transpose 処理は以下の3つのパートに分かれています。
コードリーディングをする際、コメントは非常に貴重な手がかりになるので気をつけて読みます。
(そしてそれは、コードを書く時にも読まれることを意識してコメントを残そう、ということでもあります)

  1. transpose 後の配列のメモリ領域の確保
  2. メモリ領域のポインタの設定
  3. transpose 後メモリの、次元とストライドの設定

特に最後の 「transpose 後メモリの、次元とストライドの設定」 がこの処理の鍵のようなので、そこを見ていきたいと思います。

numpy/core/src/multiarray/shape.c
    flags = PyArray_FLAGS(ap);

    /*
     * this allocates memory for dimensions and strides (but fills them
     * incorrectly), sets up descr, and points data at PyArray_DATA(ap).
     */
    Py_INCREF(PyArray_DESCR(ap));
    ret = (PyArrayObject *)
        PyArray_NewFromDescr(Py_TYPE(ap),
                             PyArray_DESCR(ap),
                             n, PyArray_DIMS(ap),
                             NULL, PyArray_DATA(ap),
                             flags,
                             (PyObject *)ap);
    if (ret == NULL) {
        return NULL;
    }
    /* point at true owner of memory: */
    Py_INCREF(ap);
    if (PyArray_SetBaseObject(ret, (PyObject *)ap) < 0) {
        Py_DECREF(ret);
        return NULL;
    }

    /* fix the dimensions and strides of the return-array */
    for (i = 0; i < n; i++) {
        PyArray_DIMS(ret)[i] = PyArray_DIMS(ap)[permutation[i]];
        PyArray_STRIDES(ret)[i] = PyArray_STRIDES(ap)[permutation[i]];
    }
    PyArray_UpdateFlags(ret, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS |
                        NPY_ARRAY_ALIGNED);
    return (PyObject *)ret;
}

多次元配列のデータ構造について

transpose 処理を追う前に、多次元配列のデータである PyArrayObject のデータ構造を軽く見ておきたいと思います。
定義されている場所を探すために $ find -name '*.h' -exec grep -nH PyArrayObject {} \; を実行した結果の抜粋を、以下に示します。

./numpy/core/include/numpy/ndarraytypes.h:664:typedef struct tagPyArrayObject_fields {
./numpy/core/include/numpy/ndarraytypes.h:702:} PyArrayObject_fields;
./numpy/core/include/numpy/ndarraytypes.h:712: * PyArrayObject field access is deprecated as of NumPy 1.7.
./numpy/core/include/numpy/ndarraytypes.h:714:typedef PyArrayObject_fields PyArrayObject;
./numpy/core/include/numpy/ndarraytypes.h:716:typedef struct tagPyArrayObject {
./numpy/core/include/numpy/ndarraytypes.h:718:} PyArrayObject;

numpy/core/include/numpy/ndarraytypes.h を見ると、 PyArrayObject_fields という構造体がまず存在し、それに対する typedef として PyArrayObject が定義されていることがわかりました。

numpy/core/include/numpy/ndarraytypes.h
typedef struct tagPyArrayObject_fields {
    PyObject_HEAD
    /* Pointer to the raw data buffer */
    char *data;
    /* The number of dimensions, also called 'ndim' */
    int nd;
    /* The size in each dimension, also called 'shape' */
    npy_intp *dimensions;
    /*
     * Number of bytes to jump to get to the
     * next element in each dimension
     */
    npy_intp *strides;
    /*
     * This object is decref'd upon
     * deletion of array. Except in the
     * case of UPDATEIFCOPY which has
     * special handling.
     *
     * For views it points to the original
     * array, collapsed so no chains of
     * views occur.
     *
     * For creation from buffer object it
     * points to an object that should be
     * decref'd on deletion
     *
     * For UPDATEIFCOPY flag this is an
     * array to-be-updated upon deletion
     * of this one
     */
    PyObject *base;
    /* Pointer to type structure */
    PyArray_Descr *descr;
    /* Flags describing array -- see below */
    int flags;
    /* For weak references */
    PyObject *weakreflist;
} PyArrayObject_fields;

typedef PyArrayObject_fields PyArrayObject;

PyArrayObject_fields を見てみると、多次元配列オブジェクトを構成するものとして、以下の様なものが挙げられています。

  • 生データの先頭ポインタ (data)
  • 配列の次元数 (nd)
  • 各次元の大きさ、あるいは shape (dimensions)
  • 各次元のストライド (strides)

このように、配列の元データは data という1次元配列に格納されていて、配列の形が dimensions に対応しています。
データにアクセスする際は、各次元のインデックスに対応する次元のストライドをかけて、それを全次元分足し合わせたものが data からのオフセットに相当することになると推測されます。

ここまでくると、 dimensionsstrides を書き換えることで、 transpose 処理ができそうだということが、データ構造から推測できるようになってきますね!

transpose 後メモリの、次元とストライドの設定

それでは本丸の transpose 処理の本丸を追ってみます。

numpy/core/src/multiarray/shape.c
    /* fix the dimensions and strides of the return-array */
    for (i = 0; i < n; i++) {
        PyArray_DIMS(ret)[i] = PyArray_DIMS(ap)[permutation[i]];
        PyArray_STRIDES(ret)[i] = PyArray_STRIDES(ap)[permutation[i]];
    }

ここに書かれている通り、配列の各次元の大きさとストライドを書き換えているだけのようです。
念のため PyArray_DIMS()PyArray_STRIDES() の実装も見てみましょう。
コマンドはおなじみ $ find -type f -exec grep -nH PyArray_DIMS {} \;$ find -type f -exec grep -nH PyArray_STRIDES {} \; です。

./numpy/core/include/numpy/ndarraytypes.h:1483:PyArray_STRIDES(PyArrayObject *arr)

numpy/core/include/numpy/ndarraytypes.h に目当ての実装がありそうです。
見てみると、多次元配列構造体の strides (各次元の大きさ = shape) と strides にアクセスするためのラッパーでした。

numpy/core/include/numpy/ndarraytypes.h
static NPY_INLINE npy_intp *
PyArray_STRIDES(PyArrayObject *arr)
{
    return ((PyArrayObject_fields *)arr)->strides;
}

static NPY_INLINE npy_intp
PyArray_STRIDE(const PyArrayObject *arr, int istride)
{
    return ((PyArrayObject_fields *)arr)->strides[istride];
}

まとめ

ここまで非常に長かったですが、一通り numpy.transpose の処理の流れを追うことができました。
ポイントをまとめると以下のようになります

  • numpy.transpose 処理の実体は PyArray_Transpose() @ numpy/core/src/multiarray/shape.c で定義されている
  • transpose 処理は、単純に shape とストライドを書き換えるだけの実装となっている
  • numpy.ndarray の詳細は PyArrayObject_fields @ numpy/core/include/numpy/ndarraytypes.h で、配列の shape とストライドが記録されており、書き換え後の shape とストライド情報を元にデータにアクセスすることで、transpose された行列を見せることができる

コードリーディングをする前は、てっきり配列のコピーを作って頑張って transpose しているのだとばっかり思っていましたが、実際はごくわずかなデータを書き換えるだけで transpose が実現されていました。
配列の大きさにかかわらず、一定の短時間で transpose をすることができ、とても効率の良い実装だと感心しました。

また、この実装を読んだことで、transpose 後にどのようなデータが見えてくるかが納得行くようになりました。
実例として、下記の様な (2,3,4) の配列 a に対し np.transpose(a, (2,1,0)) をした結果をもとに考えてみます。

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

↓↓ np.transpose(a, (2,1,0)) ↓↓

[[[ 0 12]
  [ 4 16]
  [ 8 20]]

 [[ 1 13]
  [ 5 17]
  [ 9 21]]

 [[ 2 14]
  [ 6 18]
  [10 22]]

 [[ 3 15]
  [ 7 19]
  [11 23]]]

まず、shape が (2,3,4) から (4,3,2) になっていることがわかります。
また、ストライドが元の (12,4,1) から (1,4,12) になるので、

  • 0番目の軸を進むたびに値が1増加
  • 1番目軸を進むたびに値が4増加
  • 2番目の軸を進むたびに値が12増加

という、値の見え方に納得が行くようになりました。

これで、transpose とそれを使っている rollaxis の挙動はバッチリ理解できるようになりましたね!

ドキュメントを読んでもどうしても挙動がわからないときは、このようにコードを読んでいくのが理解を進める確実な方法です。
巷にはなかなかコードリーディングの進め方を手取り足取り書いたドキュメントがないため、今回一歩ずつなるべく手順を省略することなくまとめていきました。
この記事がコードリーディングの仕方に悩むどなたかの目に触れて、少しでもお役に立てば幸いです。

なお、今回は誰でも使える標準ツールだけでコードリーディングをすることを目指し、下記の2ツールのみを使って実施しました。

  • find & grep: 実装箇所の探索
  • less: 実装内容の確認

実際には、コードリーディングを加速させてくれるツールが色々とありますので、そのへんの紹介もできればいいなと思いました。

それでは、 Happy Code Reading! 楽しくコードを読んでいきましょうね!

では。